Compare commits

...

1198 Commits

Author SHA1 Message Date
Angelos Katharopoulos
870208eff5 Start sdpa vector 2025-06-16 17:38:39 -07:00
Awni Hannun
bc53f8293f Cuda bug fixes 2 (#2298)
* more bug fixes

* more bug fixes

* format
2025-06-16 13:14:46 -07:00
Awni Hannun
c552ff2451 [CUDA] Fix back-end bugs and enable corresponding tests (#2296)
* Fix some cuda back-end bugs and enable corresponding tests

* more fixes

* enable more tests

* format
2025-06-16 08:45:40 -07:00
Awni Hannun
4fda5fbdf9 add python testing for cuda with ability to skip list of tests (#2295) 2025-06-15 10:56:48 -07:00
Angelos Katharopoulos
580776559b RoPE for CUDA (#2293)
* First working CUDA rope

* Fix random
2025-06-15 06:08:07 -07:00
Awni Hannun
a14aaa7c9d Fix cuda arg reduce (#2291) 2025-06-14 17:54:00 -07:00
Awni Hannun
a6d780154f fix cuda gemm for bf16 (#2288) 2025-06-13 22:10:46 -07:00
Awni Hannun
6871e2eeb7 fix cuda jit (#2287) 2025-06-13 19:21:46 -07:00
Awni Hannun
8402a2acf4 Fix complex power and print (#2286)
* fix complex power and print

* fix complex matmul shape
2025-06-13 11:13:00 -07:00
Jagrit Digani
fddb6933e1 Collection of refactors (#2274)
* Refactor gemv into a function

* Refactor splitk step 1

* Refactor split k axpby

* Rearrange steel_gemm_regular

* Redirect steel_gemm_regular

* Add axpby routing to steel_matmul_regular

* Refactor AddMM step 1

* Redirect steel_gemm

* Update addmm

* Comments and format

* Some cleanup

* Add architecture gen to device

* Update no copy condition in normalization to account for axis size 1
2025-06-13 10:44:56 -07:00
Cheng
c8b4787e4e CUDA backend: indexing ops (#2277) 2025-06-12 21:44:19 -07:00
Awni Hannun
2188199ff8 [CUDA] ternary with select op (#2283)
* cuda ternary with select op

* comment + fix

* fix
2025-06-12 20:24:43 -07:00
Awni Hannun
aa07429bad Fix cuda build (#2284) 2025-06-12 17:48:05 -07:00
Awni Hannun
918761a25a [CUDA] RMSNorm and VJP (#2280)
* rms norm start

* nit
2025-06-12 17:09:49 -07:00
Cheng
a4fc671d3e CUDA backend: compile (#2276)
* CUDA backend: compile

* Rename kernels/ to device/
2025-06-12 17:08:39 -07:00
Awni Hannun
f5f65ef48c Make sliceUpdate general (#2282)
* Make sliceUpdate general

* fix
2025-06-12 16:48:54 -07:00
Cheng
c2dd81a8aa Fix warnings from latest CUDA toolkit (#2275) 2025-06-12 06:03:01 -07:00
Cheng
d7e680ffe4 CUDA backend: layernorm (#2271) 2025-06-11 15:48:32 -07:00
Cheng
c371baf53a CUDA backend: softmax (#2272) 2025-06-11 13:55:22 -07:00
Cheng
ccf78f566c CUDA backend: argreduce (#2270) 2025-06-11 13:26:17 -07:00
Cheng
c9fa68664a CUDA backend: reduce (#2269) 2025-06-11 11:22:25 -07:00
Awni Hannun
c35f4d089a start cuda circle config (#2256)
* rebase

* fix metal kernel linking issue on cuda

* start cuda circle config
2025-06-10 21:19:47 -07:00
Angelos Katharopoulos
8590c0941e Add load_safe to the general conv loaders (#2258) 2025-06-10 20:58:16 -07:00
Cheng
095163b8d1 Fix building cpp benchmarks on Linux (#2268) 2025-06-10 17:10:24 -07:00
Cheng
99c33d011d rebase + nit (#2260)
Co-authored-by: Awni Hannun <awni@apple.com>
2025-06-10 10:51:51 -07:00
Awni Hannun
62fecf3e13 fix conv export (#2265) 2025-06-10 09:34:01 -07:00
Cheng
7c4eb5d03e CUDA backend: random (#2261) 2025-06-10 08:59:56 -07:00
Cheng
bae9a6b404 CUDA backend: sort (#2262)
Co-authored-by: Awni Hannun <awni@apple.com>
2025-06-10 08:59:47 -07:00
Christopher Fleetwood
004c1d8ef2 Report number of missing parameters (#2264)
* chore: inform

* chore: format

---------

Co-authored-by: FL33TW00D <FL33TW00D@users.noreply.github.com>
2025-06-10 06:37:50 -07:00
Cheng
7ebb2e0193 CUDA backend: binary ops (#2259) 2025-06-10 06:37:40 -07:00
Awni Hannun
9ce77798b1 fix export to work with gather/scatter axis (#2263) 2025-06-09 20:37:27 -07:00
Cheng
f8bad60609 CUDA backend: unary ops (#2158) 2025-06-09 06:45:08 -07:00
Emmanuel Ferdman
5866b3857b Refactor the lu test (#2250)
Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
2025-06-07 06:12:08 -07:00
Awni Hannun
1ca616844b Fix unintuitive metal kernel caching (#2242)
* Fix unintuitive metal kernel caching

* alternative solution
2025-06-06 20:08:15 -07:00
Angelos Katharopoulos
2e8cf0b450 Change layernorms to two pass algorithm (#2246) 2025-06-06 13:34:56 -07:00
Cheng
24f89173d1 CUDA backend: matmul (#2241) 2025-06-06 12:24:04 -07:00
Awni Hannun
c6a20b427a Improve metal elementwise kernels (#2247)
* improve metal elementwise kernels

* compile and copy

* fix jit
2025-06-06 11:37:40 -07:00
Awni Hannun
a5ac9244c4 fix linux linking error (#2248) 2025-06-06 10:41:51 -07:00
Awni Hannun
c763fe1be0 default strict mode for module update and update_modules (#2239) 2025-06-05 15:27:02 -07:00
Cheng
52dc8c8cd5 Add profiler annotations in common primitives for CUDA backend (#2244) 2025-06-04 19:55:12 -07:00
Angelos Katharopoulos
aede70e81d Perf regression fix (#2243) 2025-06-03 17:55:12 -07:00
Cheng
85a8beb5e4 Avoid atomic updates across CPU/GPU in CUDA event (#2231) 2025-06-03 16:49:06 -07:00
Cheng
0bb89e9e5f Share more common code in Compiled (#2240)
* Share more common code in Compiled

* Remove build_lib_name
2025-06-03 16:48:50 -07:00
Cheng
5685ceb3c7 Avoid invoking allocator::malloc when creating CUDA event (#2232) 2025-06-03 16:48:40 -07:00
Suryash Malviya
0408ba0a76 Optimizing Complex Matrix Multiplication using Karatsuba’s Algorithm (#2220)
* Implementing Complex Matmul using Karatsuba Algorithm

* Implemented Karatsuba's Algorithm for complex matmul and pre-commit them

* fix

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-06-02 15:58:46 -07:00
Awni Hannun
cbad6c3093 version (#2237) 2025-06-02 15:58:33 -07:00
Cheng
1b021f6984 Fast primitives decide when to use the fallback (#2216) 2025-06-02 13:26:37 -07:00
Cheng
95b7551d65 Do not check event.is_signaled() in eval_impl (#2230) 2025-06-02 13:23:34 -07:00
Cheng
db5a7c6192 Add memory cache to CUDA backend (#2221)
* Move BufferCache out of allocator

* Add memory cache to cuda backend allocator

* Simplify BufferCache assuming buf can not be null
2025-05-30 12:12:54 -07:00
Awni Hannun
6ef2f67e7f 5bit quants (#2226)
* 5bit quants

* 5bit quants
2025-05-30 12:12:10 -07:00
Cheng
f76ee1ffd2 Move some dims utils to common (#2223) 2025-05-29 06:48:30 -07:00
Cheng
54a71f270a Remove unused defines (#2217) 2025-05-23 06:14:58 -07:00
Awni Hannun
55b4062dd8 copyright in docs (#2214) 2025-05-21 17:13:04 -07:00
Cheng
79071bfba4 Fix out-of-bounds default value in logsumexp/softmax (#2213) 2025-05-21 07:25:16 -07:00
Cheng
7774b87cbd Remove redundant simd_sum in logsumexp (#2210) 2025-05-21 07:25:03 -07:00
Cheng
35c87741cf Build for compute capability 70 instead of 75 (#2209) 2025-05-20 19:42:48 -07:00
Jack Wind
4cbe605214 Feat: Allow per-target Metal debug flags (#2201)
* feat: allow per-target Metal debug flags

* formatting fix
2025-05-20 10:22:26 -07:00
Clement Liaw
ab8883dd55 include mlx::core::version() symbols in the mlx static library (#2207) 2025-05-20 07:39:11 -07:00
Awni Hannun
eebe73001a fix large arg reduce (#2206) 2025-05-19 13:10:44 -07:00
Angelos Katharopoulos
0359bf02c9 Nearest upsample (#2202) 2025-05-19 11:23:38 -07:00
Cheng
237f9e58a8 Fix BEFORE keyword in target_include_directories (#2204) 2025-05-19 06:10:44 -07:00
Awni Hannun
8576e6fe36 fix conv2d bug + faster conv 1d (#2195)
* fix conv2d bug + faster conv 1d

* revert sort + flaky test
2025-05-18 06:05:11 -07:00
Angelos Katharopoulos
0654543dcc Add complex eigh (#2191) 2025-05-18 00:18:43 -07:00
Awni Hannun
48ef3e74e2 reduce vjp for all and any (#2193) 2025-05-16 08:38:49 -07:00
Cheng
7d4b378952 Include cuda_bf16.h for bfloat16 overloads (#2192)
* Include cuda_bf16.h for bfloat16 overloads

* Add NO_GPU_MULTI(Eig) in cuda backend
2025-05-16 06:44:42 -07:00
Jack Wind
7ff5c41e06 Add set_threadgroup_memory_length to CommandEncoder (#2183) 2025-05-16 00:28:03 -07:00
Awni Hannun
602f43e3d1 fix conv grad (#2187) 2025-05-15 19:20:36 -07:00
Awni Hannun
a2cadb8218 real and imag properties (#2189) 2025-05-15 18:17:50 -07:00
Awni Hannun
c1eb9d05d9 non-symmetric eig and eigh (#2188) 2025-05-15 13:01:44 -07:00
Angelos Katharopoulos
cf6c939e86 Fix some complex vjps (#2178) 2025-05-14 23:37:12 -07:00
Angelos Katharopoulos
130df35e1b Add random normal distribution for complex numbers (#2182) 2025-05-13 22:43:45 -07:00
Cheng
0751263dec Fix typo in row_reduce_small (#2179) 2025-05-13 20:19:54 -07:00
Cheng
eca2f3eb97 Add remove_index utility (#2173) 2025-05-13 17:09:56 -07:00
Angelos Katharopoulos
3aa9cf3f9e Fix put_along_axis for empty arrays (#2181) 2025-05-13 14:27:53 -07:00
Awni Hannun
8f3d208dce Close a couple edge case bugs: hadamard and addmm on empty inputs (#2177)
* handle hadamard and addmm on empty inputs

* fix
2025-05-12 10:48:57 -07:00
Ivan Fioravanti
caaa3f1f8c Small typos in mx.metal deprecations (#2176) 2025-05-11 06:03:47 -07:00
Awni Hannun
659a51919f patch bump (#2162) 2025-05-09 14:35:14 -07:00
Awni Hannun
6661387066 Fix fft for integer overflow (#2161) 2025-05-09 14:25:12 -07:00
ATurker
a7fae8a176 fix: conv_general differences between gpu, cpu (#2070)
* fix general_conv padding

* fix bugs

* add test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-05-09 10:26:52 -07:00
Cheng
0cae0bdac8 CUDA backend: backbone (#2075) 2025-05-06 21:26:46 -07:00
Awni Hannun
5a1a5d5ed1 fix input coherent kernel launch (#2153) 2025-05-05 17:30:50 -07:00
Cheng
1683975acf Move common gpu primitives to backend/gpu (#2145) 2025-05-05 13:45:29 -07:00
Awni Hannun
af705590ac fix batched vector sdpa (#2152) 2025-05-05 13:13:03 -07:00
Awni Hannun
825124af8f fix bw for elementwise ops (#2151)
* fix bw for elementwise ops

* add compile

* fix

* fix

* fix

* fix
2025-05-05 06:15:04 -07:00
Awni Hannun
9c5e7da507 fix compile merging (#2150) 2025-05-02 15:08:50 -07:00
Angelos Katharopoulos
481349495b GPU Hadamard for large N (#1879) 2025-05-01 17:19:17 -07:00
Awni Hannun
9daa6b003f fix shapeless export (#2148) 2025-05-01 15:02:02 -07:00
Angelos Katharopoulos
a3a632d567 Fix the launcher when ran locally (#2147) 2025-05-01 12:56:09 -07:00
Awni Hannun
e496c5a4b4 fix integer overflow in qmm (#2143) 2025-04-30 09:28:56 -07:00
Cheng
ea890d8710 Remove metal-only tests (#2139) 2025-04-30 09:08:39 -07:00
Awni Hannun
aa5d84f102 Allow quant layer to be unfrozen (#2142) 2025-04-30 09:08:29 -07:00
Awni Hannun
f1606486d2 Generalize gpu backend (#2138)
* generalize gpu backend

* fix no_gpu build

* fix no_gpu build

* generalize gpu backend
2025-04-30 09:08:17 -07:00
Cheng
87720a8908 Fix building with uv (#2141) 2025-04-30 06:04:07 -07:00
Aashiq Dheeraj
bb6565ef14 add fftshift and ifftshift fft helpers (#2135)
* add fftshift and ifftshift fft helpers

* address comments

* axes have to be iterable

* fix fp error in roll + add test

---------

Co-authored-by: Aashiq Dheeraj <aashiq@aashiq-mbp-m4.local>
2025-04-29 22:13:45 -07:00
Awni Hannun
7bb063bcb3 Enable vjp for quantized scale and bias (#2129)
* Enable vjp for quantized scale and bias

* higher tol
2025-04-29 13:03:09 -07:00
Alex Chi Z.
b36dd472bb return library if it is successfully loaded (#2131) 2025-04-29 07:30:36 -07:00
hdeng-apple
167b759a38 Fix typos (#2136) 2025-04-29 07:26:05 -07:00
charan-003
99b9868859 Clarify dimension notation in conv1d, conv2d, and conv3d docstrings (#2123)
* Clarify dimension notation in conv1d, conv2d, and conv3d docstrings

* Updating transposed convs in conv1d, conv2d, and conv3d

---------

Co-authored-by: Sai Charan Arvapally <saicharan@Sais-MacBook-Pro.local>
2025-04-25 12:18:30 -07:00
1ndig0
6b2d5448f2 Fix the error message in mx.right_shift and mx.left_shift (#2121)
* update right_shift and lef_shift

* simplify

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-04-25 09:14:28 -07:00
Awni Hannun
eaf709b83e patch (#2119) 2025-04-24 16:11:07 -07:00
Angelos Katharopoulos
f0e70afff0 Fix swift pm load (#2117) 2025-04-24 10:58:29 -07:00
hdeng-apple
86984cad68 Remove static initializers (#2059)
* Remove static initializers in device.cpp, load.cpp, pocketfft.h

* Remove static initializer InTracing::trace_stack

* Remove static initializer of CompilerCache cache

* Revert changes in pocketfft.h

* Remove duplicate private section of thread_pool()
2025-04-24 06:14:49 -07:00
Awni Hannun
fbc89e3ced fix pinv (#2110) 2025-04-23 13:08:28 -07:00
hdeng-apple
38c1e720c2 Search mlx.metallib in macOS framework "Resources" dir (#2061)
---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2025-04-23 09:53:13 -07:00
Param Thakkar
600e87e03c Added output_padding parameters in conv_transpose (#2092) 2025-04-23 09:26:33 -07:00
Hyunsung Lee
3836445241 Add broadcast_shapes in python API (#2091) 2025-04-22 18:57:39 -07:00
Yury Popov
1d2c9d6a07 Complex scan (#2094) 2025-04-22 18:56:28 -07:00
Awni Hannun
e8ac6bd2f5 irfft throws instead of segfaults on scalars (#2109) 2025-04-22 10:25:55 -07:00
Awni Hannun
fdadc4f22c Add more complex unary ops (#2101) 2025-04-21 13:04:54 -07:00
Awni Hannun
79b527f45f conv vmap (#2102) 2025-04-21 13:04:39 -07:00
Awni Hannun
dc4eada7f0 Use unordered map for kwargs in export/import (#2087)
* use unordered map for kwargs in export/import

* comment
2025-04-21 07:17:22 -07:00
Cheng
70ebc3b598 Return const ref in array::data_shared_ptr (#2100) 2025-04-21 07:17:09 -07:00
Cheng
b13f2aed16 Introduce macros for dispatching dynamic dtypes as static types (#2073) 2025-04-19 06:16:30 -07:00
Param Thakkar
5f04c0f818 Fixed shift operations issue (#2080)
* Fixed shift operations issue

* Added tests and fixes

* Fixed loop syntax error

* Added tests for bool

* Fixed typo
2025-04-18 14:28:33 -07:00
Awni Hannun
55935ccae7 fix py gc edge case (#2079) 2025-04-18 12:46:53 -07:00
Awni Hannun
b529515eb1 minor bump (#2081) 2025-04-17 14:57:11 -07:00
Angelos Katharopoulos
3cde719eb7 Route to gather qmm only for many tokens per expert (#2082) 2025-04-17 14:53:08 -07:00
Angelos Katharopoulos
5de6d94a90 Gather qmm batched kernel and refactoring of quantized (#2078) 2025-04-17 13:53:11 -07:00
Angelos Katharopoulos
99eefd2ec0 Gather mm new kernel and small refactoring (#2040) 2025-04-14 16:37:36 -07:00
Yury Popov
e9e268336b LogCumSumExp (#2069) 2025-04-13 01:27:29 -07:00
Awni Hannun
7275ac7523 Fix release build (#2072) 2025-04-12 20:41:58 -07:00
Angelos Katharopoulos
c4189a38e4 Add float mask to sdpa vector (#2068) 2025-04-11 17:29:40 -07:00
Awni Hannun
68d1b3256b nit: fix exception handling (#2066) 2025-04-11 14:12:08 -07:00
Awni Hannun
9c6953bda7 Fix stubgen (#2065)
* Fix stubgen

* add multi optim to docs
2025-04-11 12:02:54 -07:00
Awni Hannun
ef7ece9851 fix fft bug (#2062) 2025-04-10 19:41:27 -07:00
Angelos Katharopoulos
ddaa4b7dcb Fix the test and add custom min/max reductions for uncommon MPI types (#2060) 2025-04-10 17:01:17 -07:00
Cheng
dfae2c6989 Fix MSVC build due to use of M_LN2 (#2058) 2025-04-10 07:41:41 -07:00
Anastasiia Filippova
515f104926 Min / max reductions (#2041) 2025-04-09 23:22:20 -07:00
Angelos Katharopoulos
9ecefd56db Do not load the default lib if another is requested (#2055) 2025-04-09 13:31:38 -07:00
Awni Hannun
e5d35aa187 no sdpa in grad (#2054) 2025-04-08 19:13:54 -07:00
Awni Hannun
00794c42bc Fix causal mask sdpa vec (#2053)
* fix sdpa vector causal mask

* test
2025-04-08 09:11:23 -07:00
Cheng
08a1bf3f10 Remove Event::Signal() (#2052) 2025-04-08 06:20:27 -07:00
Awni Hannun
60c4154346 Only request residency once (#2051) 2025-04-07 10:47:51 -07:00
Awni Hannun
f2c85308c1 add a half simd gemm fallback (#2046)
* add a half simd gemm fallback

* nit
2025-04-07 09:31:29 -07:00
Awni Hannun
1a28b69ee2 only add to residency set once (#2049) 2025-04-06 17:38:25 -07:00
Cheng
ba09f01ce8 Remove test of converting negative float to uint (#2048) 2025-04-06 06:21:46 -07:00
Cheng
6cf48872b7 wait_for_one should wait for task to finish (#2047) 2025-04-05 20:05:16 -07:00
Angelos Katharopoulos
7b3b8fa000 Fix ci release (#2045) 2025-04-04 20:25:01 -07:00
Awni Hannun
ec5e2aae61 nit in doc (#2044) 2025-04-04 12:04:17 -07:00
Awni Hannun
86389bf970 patch bump (#2043) 2025-04-03 13:15:18 -07:00
Jagrit Digani
3290bfa690 Add new sdpa function overload (#2035)
* Add new sdpa function overload

* Address comments

* Remove std::varaint from cpp sdpa function
2025-04-03 11:58:28 -07:00
Jagrit Digani
8777fd104f Depthwise Conv2D optimization (#2036)
- Add new specialized kernel for small kernel (kernels size <= 7), small strides (strides <= 2) depthwise 2d convolutions
- Add related tests
2025-04-03 09:42:04 -07:00
Awni Hannun
c41f7565ed fix softmax / logsumexp (#2042) 2025-04-03 08:32:59 -07:00
Awni Hannun
9ba81e3da4 tune quant dispatch (#2031) 2025-04-02 20:05:54 -07:00
Awni Hannun
c23888acd7 Fix build warning (#2033) 2025-04-01 14:42:27 -07:00
Awni Hannun
f98ce25ab9 fix residency set for real (#2032) 2025-04-01 12:59:48 -07:00
Awni Hannun
de5f38fd48 Custom logsumexp (#2028)
* initial custom logsumexp

* more tests

* comments + fix
2025-03-31 07:36:55 -07:00
Angelos Katharopoulos
ec2854b13a Swap -inf for finite_minimum value (#2029) 2025-03-30 21:55:04 -07:00
Stephen Panaro
90823d2938 Add missing funcs to docs (#2021) 2025-03-30 18:29:33 -07:00
Jesper Stemann Andersen
5f5770e3a2 Fix CPU sign for unsigned ints (#2024)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2025-03-30 17:56:59 -07:00
Awni Hannun
28f39e9038 Log for complex numbers in Metal (#2025)
* Log for complex numbers in Metal

* fix log2
2025-03-30 17:04:38 -07:00
Awni Hannun
b2d2b37888 fix residency set clearing (#2027) 2025-03-30 16:27:26 -07:00
Awni Hannun
fe597e141c add pinv to doc (#2020) 2025-03-30 15:54:18 -07:00
Yi Wang
72ca1539e0 Remove unused variable in /setup.py (#2026)
This is a follow up of https://github.com/ml-explore/mlx/pull/2011
2025-03-30 12:52:33 -07:00
Awni Hannun
13b26775f1 use minimum deployment target (#2016) 2025-03-28 14:31:53 -07:00
Awni Hannun
05d7118561 causal vector sdpa (#2018)
* causal vector sdpa

* get rid of memory threshold
2025-03-28 12:36:13 -07:00
Awni Hannun
98b901ad66 enable complex gemm (#2017) 2025-03-28 10:45:13 -07:00
Awni Hannun
5580b47291 iinfo and scalar overflow detection (#2009) 2025-03-27 19:54:56 -07:00
Awni Hannun
bc62932984 sdpa specialization for head dim 256 (#2007) 2025-03-27 19:31:25 -07:00
Awni Hannun
a6b5d6e759 revise cmake minimum for doctest (#2014) 2025-03-27 19:30:58 -07:00
Yi Wang
a8931306e1 Remove unused variable in CMakeBuild (#2011)
Fix https://github.com/ml-explore/mlx/issues/2010
2025-03-27 16:00:51 -07:00
Yi Wang
fecdb8717e Polish CONTRIBUTING>md (#2005) 2025-03-25 19:06:34 -07:00
Awni Hannun
916fd273ea wire cache (#2006) 2025-03-25 18:54:01 -07:00
Yi Wang
0da8506552 Update docs for extensions (#2004) 2025-03-25 18:35:03 -07:00
Cheng
eda7a7b43e Do not join threads during process exit on Windows (#1738) 2025-03-25 06:33:08 -07:00
Chunyang Wen
022eabb734 Remove unused import (#1987) 2025-03-24 20:19:32 -07:00
Awni Hannun
aba899cef8 patch bump (#2000) 2025-03-24 12:47:05 -07:00
Jagrit Digani
6a40e1c176 Fix looping limit in causal attention (#1999) 2025-03-24 12:28:00 -07:00
Jesper Stemann Andersen
9307b2ab8b Fixed 32-bit platform support for distributed/ring implementation (#1996)
Replaced unsigned long integer literals with size_t literals in ring implementation, e.g., 1UL with size_t(1).
2025-03-24 08:08:40 -07:00
Jesper Stemann Andersen
522d8d3917 Added missing netinet/in.h include that fixes build on FreeBSD (#1997)
Defines IPPROTO_TCP.
2025-03-24 08:07:34 -07:00
Awni Hannun
a84cc0123f promote mask when needed (#1998) 2025-03-23 19:58:28 -07:00
Andrey Velichkevich
f018e248cd fix(backend): Include algorithm library in Allocator (#1992)
Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>
2025-03-22 21:27:51 -07:00
Awni Hannun
cfd7237a80 fix docs (#1991) 2025-03-21 19:58:53 -07:00
Angelos Katharopoulos
4eef8102c9 Distributed layers (#1270) 2025-03-21 13:52:17 -07:00
Angelos Katharopoulos
69e4dd506b Add a ring all gather (#1985) 2025-03-21 13:36:51 -07:00
Angelos Katharopoulos
25814a9458 Disable mpi on version mismatch (#1989) 2025-03-21 13:36:26 -07:00
Awni Hannun
2a980a76ce Add stats and limit to common allocator and enable tests (#1988)
* add stats to common allocator and enable tests

* linux memory and default

* fix
2025-03-21 12:28:36 -07:00
Angelos Katharopoulos
d343782c8b Cross platform libmpi loading (#1975) 2025-03-21 11:23:10 -07:00
Awni Hannun
4e1994e9d7 move memory APIs into top level mlx.core (#1982) 2025-03-21 07:25:12 -07:00
jiyzhang
65a38c452b update the formula of smooth_l1_loss (#1986) 2025-03-21 06:25:23 -07:00
Awni Hannun
7b7e2352cd fix malloc or wait deadlock (#1976) 2025-03-20 16:48:43 -07:00
Awni Hannun
1177d28395 patch bump (#1981) 2025-03-20 15:12:22 -07:00
Awni Hannun
005e7efa64 fix mask in sdpa (#1980)
* fix mask in sdpa

* fix attention mask

* Re-enable routing for array mask

---------

Co-authored-by: Jagrit Digani <digani@apple.com>
2025-03-20 14:53:12 -07:00
Jagrit Digani
b42d13ec84 Update attention tests to show diff, disable array masks (#1978) 2025-03-20 14:25:38 -07:00
Jagrit Digani
9adcd1a650 Support fused masking in Attention (#1924)
* Update API to allow mask='causal' in fast::sdpa

* Add fallback

* Update steel::AttnParams

* Fix typo

* WIP, basic causal

* Update tests

* Update benchmarking

* Update masking loop limits

* Add bool masking and update tests

* Update additive mask

* Update benchmarks

* Update benchmarks

* Update tests

* Update for bfloat error

* Update early exit

* Add random seed to tests
2025-03-20 11:01:32 -07:00
Awni Hannun
3c164fca8c Fix multistream GPU deadlock (#1969)
* fix multistream GPU deadlock

* comments
2025-03-20 07:19:47 -07:00
jiyzhang
95e335db7b Update smooth_l1_loss in losses.py (#1974)
According the definition of smooth_l1_loss, the line 

diff = predictions - targets

Should be updated to 

diff = mx.abs(predictions - targets)

After the modification, the result is consistent with PyTorch smooth_l1_loss
2025-03-19 20:19:02 -07:00
Awni Hannun
f90206ad74 Guard nullptr dereference (#1972)
* guard nullptr dereference

* comment
2025-03-19 16:24:10 -07:00
Chunyang Wen
3779150750 refactor: all use schedule (#1973) 2025-03-19 11:24:04 -07:00
Cheng
0a9777aa5c Do not define MLX_VERSION globally (#1966) 2025-03-18 07:12:40 -07:00
Chunyang Wen
45ad06aac8 Fix typo; Fix lint warning when reuse the same name (#1968)
* Fix typo; Fix lint warning when reuse the same name

* Add missing period
2025-03-18 07:12:24 -07:00
Awni Hannun
c6ea2ba329 Use same accumulation precision in gemv as gemm (#1962)
* use same accumulation precision in gemv as gemm

* faster

* fix compile
2025-03-16 07:13:24 -07:00
Awni Hannun
2770a10240 fix grad with inplace updates (#1961) 2025-03-13 19:13:09 -07:00
Awni Hannun
d2a94f9e6a Only compile warnings as errors for circle (#1957) 2025-03-12 13:08:19 -07:00
Awni Hannun
32da94507a fix vmap for flatten (#1955) 2025-03-11 10:42:22 -07:00
Awni Hannun
736a340478 reduce binary size (#1952) 2025-03-11 06:30:44 -07:00
Awni Hannun
117e1355a2 fix copy for large arrays (#1953) 2025-03-10 15:04:25 -07:00
Awni Hannun
3c3e558c60 Support transposed head/seq for kv (#1950)
* support transposed head/seq for kv

* fix flaky test

* nit
2025-03-10 10:53:45 -07:00
Chunyang Wen
cffceda6ee Add type hint for _extra_repr (#1948) 2025-03-10 06:05:36 -07:00
Chunyang Wen
048805ad2c Remove unused modules (#1949) 2025-03-10 06:05:26 -07:00
Chunyang Wen
d14c9fe7ea Add file info when raising errors in save (#1943) 2025-03-08 14:51:04 -08:00
Chunyang Wen
5db90ce822 Fix obsured warning (#1944) 2025-03-08 14:50:39 -08:00
Chunyang Wen
d699cc1330 Fix unreachable warning (#1939)
* Fix unreachable warning

* Update error message
2025-03-07 17:23:04 -08:00
Awni Hannun
c4230747a1 redesign for faster cpu/gpu synch (#1869)
* redesign for faster cpu/gpu synch

* load + more async CPU

* use command encoder API and move more ops to use it

* make fence back-end generic + CPU only fence

* faster build

* fix async eval

* fixes + handle temporaries

* fix / improve cpu conv

* remove unused status, fix siblings

* fix extensions

* fix

* fix no cpu build

* format

* comments

* fix perf regression, remove unecessary abort

* fix events, task limit cpu

* fix waiting

* fix donation / temporaries in normalization
2025-03-06 19:23:38 -08:00
Awni Hannun
5245f12a46 always use json (#1938) 2025-03-06 15:35:56 -08:00
Chunyang Wen
a198b2787e Remove unused modules (#1936) 2025-03-06 14:20:27 -08:00
Chunyang Wen
04edad8c59 Add doc string for path (#1937) 2025-03-06 14:20:09 -08:00
David Wisdom
392b3060b0 Fix typo in randint docstring (#1932)
This commit fixes a typo in the docstring for mlx.core.random.randint() by changing "roadcastable" to "broadcastable".
2025-03-05 21:48:00 -08:00
Chunyang Wen
85b34d59bc Clean unused sys (#1929) 2025-03-05 13:48:03 -08:00
Awni Hannun
f599c11bc8 bump (#1931) 2025-03-05 13:16:53 -08:00
Angelos Katharopoulos
0792ff02ff Only fail when 10 consecutive socket errors occur (#1928) 2025-03-05 13:16:19 -08:00
Alex Barron
fd0d63ba5b Affine quant always in fp32 (#1925)
* do affine quant in fp32

* static cast
2025-03-04 17:50:19 -08:00
Abe Leininger
3835a428c5 Adds nuclear norm support (#1894)
* adjust norm unit test tolerance
2025-03-04 13:26:02 -08:00
Angelos Katharopoulos
9680f72cca Add a multi optimizer (#1916) 2025-03-04 13:16:35 -08:00
Angelos Katharopoulos
a0737273d3 Allow debugging in distributed mode (#1920) 2025-03-04 13:01:10 -08:00
Awni Hannun
e613d0eaf0 SDPA support for small batch (over sequence) queries (#1922)
* batch query sdpa

* batch sdpa for query
2025-03-04 10:59:04 -08:00
Awni Hannun
6bcd6bcf70 fix donation in scan (#1917) 2025-03-03 11:30:59 -08:00
Awni Hannun
ba12e4999a Use a heap for small sizes (#1911)
* use a heap for small sizes

* check if VM
2025-03-03 06:50:57 -08:00
Awni Hannun
4e7cd31d12 Fix slice data size (#1913)
* fix slice data size

* add test
2025-03-02 21:50:42 -08:00
Angelos Katharopoulos
5e6c130d93 RMS norm without scaling (#1915) 2025-02-28 20:26:57 -08:00
Angelos Katharopoulos
5d68082881 Ring docs (#1829) 2025-02-28 11:34:21 -08:00
Angelos Katharopoulos
607181644f Add mlx.distributed_config script (#1902) 2025-02-28 11:16:39 -08:00
Jagrit Digani
89d327075f Enabling fused attention for head dim 128 (#1899)
* Share KV smem

* Fix bfloat error

* Unroll O = S @ V loop

* Perf upgrade

* Remove commented out function

* Add -Wno-c++17-extensions flag to metal flags

* Add -Wno-c++17-extensions flag to metal extension flags
2025-02-26 10:02:06 -08:00
Angelos Katharopoulos
6bf00ef631 Fix ring of 2 and allow scalars in API (#1906) 2025-02-25 17:03:01 -08:00
Awni Hannun
7d042f17fe Double for lapack (#1904)
* double for lapack ops

* add double support for lapack ops
2025-02-25 11:39:36 -08:00
Awni Hannun
28b8079e30 fix double type promotion (#1901) 2025-02-25 06:00:53 -08:00
Awni Hannun
7face5d9fd fix cpu compile (#1897) 2025-02-24 14:10:30 -08:00
Awni Hannun
a44dc4bdb0 fix leaking objc (#1898) 2025-02-24 13:57:59 -08:00
Awni Hannun
2d0f384b6f fix simd erf_inv (#1896) 2025-02-24 13:57:47 -08:00
Awni Hannun
8ff84b5c43 fix version and expose command queue getter (#1892) 2025-02-20 15:25:15 -08:00
Angelos Katharopoulos
10b271d963 Ring update (#1885) 2025-02-20 14:32:31 -08:00
Jesper Stemann Andersen
0ebc8a3d25 Fixed issue where Clang on FreeBSD failed to compile mlx/backend/cpu/quantized.cpp (#1890) 2025-02-20 12:02:12 -08:00
Awni Hannun
bbda0fdbdb Allow non-square lu (#1889) 2025-02-20 08:13:23 -08:00
Jesper Stemann Andersen
c86422bdd4 Added mlx::core::version() returning std::string(MLX_VERSION) (#1819)
* Added version.h providing mlx::core::version() returning std::string(MLX_VERSION)

Also, added MLX_VERSION_MAJOR, MLX_VERSION_MINOR, MLX_VERSION_PATCH, MLX_VERSION_NUMERIC, and accompanying functions.

* Added version.h to mlx.h

* Changed version int functions to be constexpr

* Formatting

* Added handling of MLX_VERSION where only the prefix has major.minor.patch format

* Changed version function to be constexpr
2025-02-19 20:30:19 -08:00
Awni Hannun
c707b2b0a6 Limit compile buffers (#1887)
* limit compile buffers

* maybe not flaky test
2025-02-19 20:28:13 -08:00
Angelos Katharopoulos
78ba24c37d Raise an exception in the rope op if input is integer (#1884) 2025-02-19 14:43:39 -08:00
Angelos Katharopoulos
1a2cb72030 Ensure linspace always contains start and stop (#1883) 2025-02-19 13:53:20 -08:00
Abe Leininger
344a29506e Enforce triangular matrix form in tri_inv (#1876)
* fix tri_inv bug

* Revert "fix tri_inv bug"

This reverts commit b74b290201.

* Make sure that tri_inv returns a triangular matrix

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2025-02-19 12:42:33 -08:00
Angelos Katharopoulos
71de73a668 Fix convs by reverting #1803 (#1882) 2025-02-18 14:36:34 -08:00
Alex Barron
4c1dfa58b7 xor op on arrays (#1875) 2025-02-17 00:24:53 -08:00
Awni Hannun
5274c3c43f compiler warnings are errors (#1870) 2025-02-17 00:07:49 -08:00
Angelos Katharopoulos
1762793989 Remove unused uniform (#1867) 2025-02-14 15:51:41 -08:00
Awni Hannun
6cec78d8f2 bump (#1866) 2025-02-14 13:09:34 -08:00
Jagrit Digani
2dc307f2e6 Winograd Update for Small batches (#1803)
* Build in padding to Winograd kernels
* Add new fused Winograd kernel
* Enable weight flipping in Winograd kernels
2025-02-14 13:08:13 -08:00
Awni Hannun
7aea5b1895 Allow dynamic ops per buffer based on dispatches and memory (#1864)
* Allow dynamic ops per buffer based on dispatches and memory

* add initial arch values
2025-02-13 19:18:22 -08:00
Ronan Collobert
9733e16496 fix function pointer (#1865) 2025-02-13 18:46:11 -08:00
Alex Barron
7f2d1024f3 add f8_e4m3 loading (#1859) 2025-02-13 17:10:03 -08:00
Awni Hannun
428f589364 Revert "More buffer donation in some cases (#1858)" (#1863)
This reverts commit d274ae77f2.
2025-02-13 14:21:44 -08:00
Alex Barron
5cd97f7ffe Bitwise Inverse (#1862)
* add bitwise inverse

* add vmap + fix nojit

* inverse -> invert

* add to compile + remove unused
2025-02-13 08:44:14 -08:00
Awni Hannun
e425dc00c0 Faster small batch qmv (#1861)
* faster small batch qmv

* swap batch and block dims for qvm and qmv regular
2025-02-12 22:02:36 -08:00
Awni Hannun
d274ae77f2 More buffer donation in some cases (#1858)
* more donation

* fix

* add test
2025-02-12 19:41:37 -08:00
Alex Barron
55c5ac7820 fix int64 bug (#1860) 2025-02-12 19:23:46 -08:00
Angelos Katharopoulos
0145911bea Fixes output donation for IO ops on the GPU (#1857) 2025-02-12 10:52:30 -08:00
Awni Hannun
0a5215693e Fix grad copies (#1854)
* fix grad with copies

* add test

* add test
2025-02-11 15:26:42 -08:00
Awni Hannun
2a45056ba8 Cycle leak break (#1856)
* detect and break leaks in custom function

* detect and break leaks in custom function
2025-02-11 14:45:02 -08:00
Cheng
142b77751d Fix compilation error on Windows (#1844) 2025-02-10 19:53:05 -08:00
Abe Leininger
a5ededf1c3 CPU LU factorization and linear solvers (#1451)
* linalg solve backend

* nits

* more nits + fix

* luf primitive and lu, solve, and solve_triangular backends

* changes / nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-02-10 12:32:24 -08:00
Franck Verrot
7df3f792a2 Ensure Conv2D and Conv3D's kernel sizes aren't trimmed (#1852)
Before the change, this snippet:

```
print(nn.Conv1d(1, 32, 3, padding=1))
print(nn.Conv2d(1, 32, (3, 3), padding=1))
print(nn.Conv3d(1, 32, (3, 3, 3), padding=1))
```

would output:

```
Conv1d(1, 32, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True)
Conv2d(1, 32, kernel_size=(3,), stride=(1, 1), padding=(1, 1), dilation=1, groups=1, bias=True)
Conv3d(1, 32, kernel_size=(3, 3), stride=(1, 1, 1), padding=(1, 1, 1), dilation=1, bias=True)
```

After the change, the output will be:

```
Conv1d(1, 32, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True)
Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), dilation=1, groups=1, bias=True)
Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), dilation=1, bias=True)
```
2025-02-10 06:27:01 -08:00
Angelos Katharopoulos
9eb7d7362f Fix Split::vmap (#1845) 2025-02-08 09:22:13 -08:00
Awni Hannun
1c0c118f7c Fp64 on the CPU (#1843)
* add fp64 data type

* clean build

* update docs

* fix bug
2025-02-07 15:52:22 -08:00
Awni Hannun
1a1b2108ec bump (#1840) 2025-02-06 11:53:24 -08:00
Jagrit Digani
b6c6552d20 Add missing #pragma once (#1838) 2025-02-06 11:11:22 -08:00
Awni Hannun
83a0340fa7 allow command (#1836) 2025-02-06 10:32:24 -08:00
Nripesh Niketan
a62fc1b39f chore: pre-commit bump (#1837) 2025-02-06 08:55:01 -08:00
Awni Hannun
af1b725fda Fix a couple of slicing bugs (#1827)
* fix a few bugs

* fix conv grad

* speedup test

* comment
2025-02-05 19:50:08 -08:00
Awni Hannun
9174606d4c fix sort (#1835) 2025-02-05 17:16:27 -08:00
Awni Hannun
ca305afdbe loading empty list is ok when strict = false (#1834) 2025-02-05 16:19:27 -08:00
Awni Hannun
fe5987b81d faster sort (#1831) 2025-02-05 06:10:22 -08:00
Awni Hannun
a229c8cef0 don't duplicate malloc with custom kernel init (#1830) 2025-02-04 13:20:57 -08:00
Jesper Stemann Andersen
f6c0499b8d Resolved ambiguity in mlx::core::take_along_axis (#1822)
* Resolved ambiguity in mlx::core::take_along_axis

Detected by GCC 10 on riscv64-linux-gnu.

* Formatted

* Removed superfluous parentheses in random_tests.cpp
2025-02-04 06:06:17 -08:00
Awni Hannun
1156c84e86 Refactor common into cpu specific and truly common (#1817)
* refactor

* fix extension example

* fix no-cpu
2025-02-03 15:58:02 -08:00
Awni Hannun
ec7c7def40 no line buffer for mpi jobs (#1825) 2025-02-03 12:02:15 -08:00
Jesper Stemann Andersen
2d8e667400 MinGW support (#1806)
* Changed /bin/bash to bash for generating compiling preamble

* Fix wrt jit_compiler mingw like msvc wrt. WEXITSTATUS

* Solved ambiguity wrt. bernoulli test shape

* Disabled distributed/ring on Windows

* Fixed jit_compiler command wrt. MinGW

* Extended jit_compiler patch wrt. WEXITSTATUS to FreeBSD
2025-02-01 12:40:06 -08:00
Awni Hannun
80c863b972 Remove accelerate/ (#1816)
* remove accelerate

* comments

* neon reduction
2025-02-01 07:18:26 -08:00
Angelos Katharopoulos
f5cc1eea72 Allow different value dimensions in sdpa_vector (#1811) 2025-01-31 20:58:59 -08:00
Awni Hannun
b7c9f1d38f scatter axis + gather axis primitives (#1813)
* scatter axis + gather axis primitives

* add transforms

* comment
2025-01-31 20:48:08 -08:00
Awni Hannun
c6fc07f1f4 Unify CPU matmuls, remove unused accelerate conv (#1814)
* unify matmuls

* Update mlx/backend/common/matmul.cpp

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2025-01-31 14:43:37 -08:00
Angelos Katharopoulos
ded914f442 Small distributed launch helper (#1810) 2025-01-29 17:55:04 -08:00
Awni Hannun
4758c8baa1 Start to cleanup/unify accelerate and common back-ends (Part 1/N) (#1777)
* start to cleanup/unify accelerate and common back-ends

* more progress

* simplify

* add half type and allow infs in simd exp

* unify softmax + quantized, more dispatches to simd quantized mm

* add sin/cos, use simd in vector-scalar ops

* faster CPU vectorize quant

* faster erf/erfinv
2025-01-29 14:34:49 -08:00
Awni Hannun
7064fed1b1 Minor update on MPI docs (#1805) 2025-01-28 11:00:08 -08:00
Awni Hannun
1017ac4a9e add dilation for conv 3d layers + test for 3d conv w/ dilation (#1802) 2025-01-28 06:17:07 -08:00
Angelos Katharopoulos
ccb61d7aae Ring distributed backend (#1784) 2025-01-27 22:15:01 -08:00
Awni Hannun
2235dee906 catch stream errors earlier to avoid aborts (#1801) 2025-01-27 14:05:43 -08:00
Awni Hannun
28091aa1ff allow build python lib without specifying path (#1799) 2025-01-27 11:22:35 -08:00
Awni Hannun
121d9a0702 Fix rope fallback to not upcast (#1797)
* fix rope fallback to not upcast

* Update mlx/fast.cpp

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2025-01-26 19:07:21 -08:00
Nick
0cea88bcc5 Use @ matrix multiplication syntax to document matrix-matrix multiplication (#1793)
Co-authored-by: Nick Thompson <nicholas_a_thompson@apple.com>
2025-01-25 16:02:36 -08:00
Angelos Katharopoulos
72146fc4cd Einsum ellipsis (#1788) 2025-01-25 01:28:03 -08:00
Awni Hannun
e6a7ab9675 non square qr (#1783) 2025-01-21 14:07:47 -08:00
Angelos Katharopoulos
1f4c127fb9 Move some kernels to get_template_definition (#1782) 2025-01-21 08:59:44 -08:00
Awni Hannun
90532b1f37 recompile when shapeless is different (#1776) 2025-01-20 21:07:10 -08:00
Awni Hannun
a8666a757a fix shapeless compile on ubuntu24 (#1775) 2025-01-18 06:04:36 -08:00
Awni Hannun
a4667da1eb Faster synchronization Fence primitive (#1773)
* try faster synchronization

move event

fixes

update bench

fix

fix

* non-functioning kernel

* try alternative fence

* cleanup barrier

* get rid of event_fence

* update benchmarks

* doc string in metal fence
2025-01-17 18:42:19 -08:00
Awni Hannun
0c259961ac matmul jvps (#1772) 2025-01-17 10:36:26 -08:00
Awni Hannun
f288db8d34 Fix synchronization bug for in stream async works (#1768) 2025-01-15 06:07:34 -08:00
Awni Hannun
33421c1dd3 Limit grad recursion depth by not recursing through non-grad inputs (#1764)
* limit grad recursion depth

* add grad of module test
2025-01-14 14:33:18 -08:00
Nripesh Niketan
5cc5201914 feat: Add orthogonal initializer and corresponding tests (#1651)
* feat: Add orthogonal initializer and corresponding tests

* lint

* Add acknowledgements

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-01-13 07:29:20 -08:00
Awni Hannun
252e423e81 fix and cleanup event signal/wait for metal (#1765) 2025-01-10 18:37:26 -08:00
wrmsr
a4a2764a52 Fix broadcast_arrays python sig (#1763) 2025-01-10 12:33:26 -08:00
Cheng
ab8e832c18 0ul is not size_t on MSVC (#1762) 2025-01-10 12:33:11 -08:00
Angelos Katharopoulos
1ce0c0fcb0 Bump version (#1761) 2025-01-09 13:48:20 -08:00
Awni Hannun
657f466402 use sdpa and exportable functions in transformer multi head attention (#1760) 2025-01-09 13:11:55 -08:00
Alex Barron
c7b0300af5 Fix batched qmv bug (#1758) 2025-01-09 11:45:57 -08:00
Awni Hannun
da8c885784 Simplify removes no-ops from the tape (#1759)
* simplify removes no-ops from the tape

* comment
2025-01-09 11:23:19 -08:00
Awni Hannun
1ccaf80575 Dynamic broadcasting for shapeless compile/export (#1722)
* working towards dynamic broadcast

* shapeless broadcast

* fix build + nits

* use broadcast arrays in quantize matmul

* some cleanup / consistency

* mend

* some comments

* add vjp, jvp for broadcast axes
2025-01-09 11:04:24 -08:00
Cheng
ec36bfa317 Include command stdout in error message (#1756)
* Include command stdout in error message

* On Windows pclose returns the exit code
2025-01-08 07:17:03 -08:00
Cheng
b8f76f717a Print exceptions in eval_cpu/eval_gpu and abort (#1754) 2025-01-08 06:31:09 -08:00
Awni Hannun
d1766f2c70 Add boolean mask support in vector SDPA (#1757) 2025-01-07 20:24:53 -08:00
Awni Hannun
516ded618b Dynamic slicing (#1741)
* dynamic slice and slice update

* python bindings + tests + fix set item

* fix compile issue

* comment

* fix jit
2025-01-07 14:02:16 -08:00
Jesper Stemann Andersen
c9c81d0584 Added additional missing unordered_map include that fixes build on FreeBSD (#1755) 2025-01-07 08:27:55 -08:00
Angelos Katharopoulos
545f84d905 Refactor distributed backend (#1752) 2025-01-06 17:33:15 -08:00
Awni Hannun
d5ec172c95 Allow boolean mask in sdpa (#1753)
* allow boolean mask in sdpa

* more permissive donation in ternary
2025-01-06 16:57:07 -08:00
Angelos Katharopoulos
25b3a3e541 Optionally specify names for arrays when exporting (#1749) 2025-01-06 13:07:46 -08:00
Awni Hannun
058d6ce683 mpi send use input as output (#1750)
* mpi send use input as output

* move earlier
2025-01-06 06:08:43 -08:00
Angelos Katharopoulos
eab93985b8 Update custom function docs (#1748) 2025-01-03 16:35:25 -08:00
Awni Hannun
b51d70a83c export docs (#1747) 2025-01-03 15:04:17 -08:00
Awni Hannun
259025100e Fix nd ternary on GPU (#1746) 2025-01-03 11:52:17 -08:00
Awni Hannun
c9d30aa6ac MLX in C++ example (#1736)
* MLX in C++ example

* nits

* fix docs
2025-01-02 19:09:04 -08:00
Angelos Katharopoulos
8544b42007 Add namespace (#1745) 2025-01-02 16:49:23 -08:00
Awni Hannun
6fa0501387 Fix concatenate/slice_update vjp + reduce binary size (#1735)
* fix concatenate vjp + reduce binary size

* also cast in slice update
2025-01-02 16:36:33 -08:00
Awni Hannun
ae69cb15e9 shapeless compile in docs and partially shapeless reshape (#1742) 2025-01-02 16:24:42 -08:00
Awni Hannun
a64a8dfe45 fix extension (#1740) 2025-01-02 16:16:16 -08:00
Venkata Naga Aditya Datta Chivukula
491fa95b1f Added Kronecker Product (#1728) 2025-01-02 16:00:34 -08:00
Danilo Peixoto
92ec632ad5 Fix Distributed Communication documentation (#1731)
* Add missing `size()` method call for group
2025-01-02 14:08:38 -08:00
Cheng
8ecdfb718b Fix export.cpp compilation with MSVC (#1737) 2024-12-29 06:56:30 -08:00
Awni Hannun
4ba0c24a8f Export / import functions to / from a file (#1642)
* export and import functions

* refactor + works for few primitives

* nit

* allow primitives with state

* nit

* nit

* simplify serialize / deserialize

* fix for constants

* python bindings

* maybe fix serialize failure case

* add example

* more primitives, training kind of works

* same result for python and c++

* some fixes

* fix export

* template it up

* some simplificatoin

* rebase

* allow kwargs and multiple functions

* exporter

* more primitives for exporting

* deal with endianness

* handle invalid stream

* add docstring
2024-12-24 11:19:13 -08:00
Cheng
935c8c4bb1 Make mx.compile work on Windows (#1697)
* Invoke MSVC on Windows in mx.compile

* Export kernel symbol on MSVC

* Remove unused template

* Parse env pairs in a robust way

* No need of cassert

* Remove unnecessary helpers

* Fix right trim

* Move command building to a separate file

* Missing header

* Do not pollute cwd with cl.exe

* Simplify str concat

* Pass output dir

* Fix styling
2024-12-24 07:02:33 -08:00
Valentin Roussellet
88f993da38 Explicit parentheses around some logical operators (#1732)
* fix some warnings

* format
2024-12-24 07:02:20 -08:00
Awni Hannun
ebfe64b92d shapeless slice update and broadcast when possible (#1727) 2024-12-23 11:25:15 -08:00
Awni Hannun
0308e9af71 Allow offset to be an mx.array for mx.fast.rope (#1724)
* allow offset for rope

* comment
2024-12-19 15:51:44 -08:00
Awni Hannun
c3628eea49 Add mx.finfo and use it when making causal mask (#1726)
* finfo

* fixes

* docs
2024-12-19 14:52:41 -08:00
Awni Hannun
e03f0372b1 More shape type (#1705)
* more shape type

* fix
2024-12-19 08:08:20 -08:00
Alex Barron
f17536af9c More lenient mask type check in SDPA (#1723)
* check mask type

* require promotion
2024-12-18 19:41:38 -08:00
Cheng
ed4ec81bca Link python extension with mlx statically on Windows (#1716)
* Link python extension with mlx statically on Windows

* More readable code
2024-12-18 19:26:04 -08:00
Awni Hannun
7480059306 track resource limit and throw if exceeded (#1718) 2024-12-18 18:45:58 -08:00
Awni Hannun
8bae22b0fa fix deletion of non-evaled arrays with siblings (#1714) 2024-12-18 18:45:36 -08:00
Alex Barron
49c34c4161 check mask type (#1721) 2024-12-18 14:25:18 -08:00
Awni Hannun
5548fcc96d fix synch race (#1719) 2024-12-18 12:25:16 -08:00
Cheng
070bd433ab Shorter kernel name for Windows (#1701)
* Shorter kernel name for Windows

* Only hash the clipped part
2024-12-17 18:51:38 -08:00
Cheng
c8fb54951a Define NOMINMAX before windows.h (#1715) 2024-12-17 18:51:24 -08:00
Awni Hannun
f110357aaa Bump nanobind to 2.4 + fix (#1710)
* bump nanobind to 2.4 + fix

* fix
2024-12-17 10:57:54 -08:00
Tomohiro Oga
a6b426422e add cubic to type hinting for upsample (#1709) 2024-12-17 07:30:23 -08:00
Awni Hannun
d03c01dfbc fix unflatten vjp (#1708) 2024-12-16 18:37:57 -08:00
Jesper Stemann Andersen
a82996e9fb io/load: Enabled pread implementation for mingw32 (#1706) 2024-12-16 07:20:45 -08:00
Cheng
af5a614aad Eval before cleanup so model file is unlocked (#1702) 2024-12-14 21:41:49 -08:00
Cheng
f9640e049d Install mlx.dll into the same dir with python bindings on Windows (#1690)
* Install mlx.dll into the same dir with python bindings on Windows

* Set BUILD_SHARED_LIBS for dlfcn-win32

* Update cmake requirements to 3.25

* Fix cmake style
2024-12-13 19:50:39 -08:00
Cheng
4768c61b57 Make sure gguf_ctx is closed when error happens (#1699) 2024-12-13 19:50:19 -08:00
Cheng
dfccd17ab9 Use psutil to get memory info on Windows (#1700) 2024-12-13 19:50:13 -08:00
Cheng
635117c5d4 Read/write files in binary mode (#1698) 2024-12-13 17:37:05 -08:00
Awni Hannun
50f3535693 Use expand_dims / unflatten / etc in more places (#1696)
* use expand_dims / unflatten in a couple more places

* few more

* few more

* fix
2024-12-12 17:00:44 -08:00
Awni Hannun
9111999af3 Fix small sort with metal validation (#1695) 2024-12-12 09:21:45 -08:00
Awni Hannun
6bd28d246e Allow no copy negative strides in as_strided and slice (#1688)
* allow no copy negative strides in as_strided and slice

* fix jit

* fix jit
2024-12-12 08:59:45 -08:00
Cheng
4d595a2a39 Make compiled preamble work in MSVC (#1675)
* Make compiled preamble work in MSVC

* Remove logging

* Only use powershell for MSVC
2024-12-12 08:55:49 -08:00
Awni Hannun
3a21f61772 Fix build (#1693) 2024-12-11 23:56:25 -08:00
Awni Hannun
4e1e9520e1 Flatten and unflatten (#1692)
* flatten and unflatten

* fix grad

* fix shape infer

* use squeeze + unsqueeze in get_item
2024-12-11 21:51:37 -08:00
Cheng
0bf19037ca Remove "using namespace mlx::core" in python/src (#1689) 2024-12-11 15:45:39 -08:00
Awni Hannun
f3dfa36a3a Fix x86 tests (#1691)
* fix x86 tests

* comment
2024-12-11 07:47:18 -08:00
Cheng
4f9b60dd53 Remove "using namespace mlx::core" in benchmarks/examples (#1685)
* Remove "using namespace mlx::core" in benchmarks/examples

* Fix building example extension

* A missing one in comment

* Fix building on M chips
2024-12-11 07:08:29 -08:00
Awni Hannun
f76a49e555 ExpandDims primitive (#1687)
* add squeeze primitive

* simplify squeeze, use in gather

* fix

* fix

* fix

* fix

* fix no cpu

* use squeeze in matmul and friends

* expand dims primitive

* comment
2024-12-10 16:39:07 -08:00
Cheng
310ad8d9db Build OpenBLAS from source code for MSVC (#1674)
* Download OpenBLAS binaries when building with MSVC

* Download dlfcn-win32

* Link with dlfcn-win32 correctly

* Build OpenBLAS from source code

* Link with openblas statically

* Link with BLAS privately
2024-12-10 16:14:44 -08:00
Cheng
56db268f47 Provide a pread implementation for MSVC (#1666) 2024-12-10 15:55:53 -08:00
Cheng
92ab6bdeb8 Fix shared library not exporting symbols on Windows (#1684)
* Fix shared library not exporting symbols on Windows

* Function name style
2024-12-10 13:59:14 -08:00
Cheng
0070e360a1 Disable MSVC warnings (#1680) 2024-12-09 19:41:14 -08:00
Amethyst Shen
9df8fed046 Metal-cpp version bump (#1668)
* Metal-cpp version bump

Apple has released the stable version of Metal-cpp for macOS 15 and iOS 18. CMakeLists.txt is updated to build with it instead of the beta one.

* Fix style with cmake-format
2024-12-09 19:40:35 -08:00
Cheng
a59fae040f Fix library output directory for MSVC (#1681) 2024-12-09 19:07:50 -08:00
Awni Hannun
29a620cab2 No reshapes in quantized embedding (#1682)
* no reshapes in quantized embedding

* fix inadvertant cast

* add tol
2024-12-09 18:57:38 -08:00
Cheng
87d7a2520e Use Py_ssize_t in python bindings (#1678)
* Use Py_ssize_t in python bindings

* Args passed to std::max must be same type
2024-12-09 12:59:19 -08:00
Awni Hannun
40c62c1321 Use int64 stride everywhere (#1671)
* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
2024-12-09 11:09:02 -08:00
Awni Hannun
35b412c099 Fix compile hasher for string constants. (#1677)
* fix hash

* add test

* nit
2024-12-09 09:26:18 -08:00
Cheng
d0f471cff7 Using math defines requires switch in MSVC (#1665)
* Using math defines requires switch in MSVC

* Fix more math macros

* Fix type

* Remove _MSC_VER guard for math defines
2024-12-08 08:16:28 -08:00
Cheng
6f316b8bf5 Use int64_t instead of ssize_t (#1673) 2024-12-07 20:10:44 -08:00
Cheng
7c10c93a1f Convert filesystem path to std::string explicitly (#1672) 2024-12-07 20:10:06 -08:00
Cheng
d92ea094f1 Use && instead of and (#1663)
* Use && instead of and

* Remove "and" in ops.cpp
2024-12-07 18:26:39 -08:00
Cheng
6ae5423b4a Do not pass integers to isnan (#1664) 2024-12-07 18:26:23 -08:00
Cheng
9635cffdc8 Include io.h in MSVC for IO functions (#1661) 2024-12-07 18:26:06 -08:00
Cheng
96986fb362 Use auto* for pointers (#1662) 2024-12-07 18:25:40 -08:00
Cheng
3ceb341a75 Use correct complex type for MSVC (#1660) 2024-12-07 18:25:22 -08:00
Awni Hannun
50fa705125 patch bump (#1656) 2024-12-06 13:16:19 -08:00
Awni Hannun
69a2991614 allow compiling lambdas in C++ (#1650)
* allow compiling lambdas in C++

* fix test

* more tests

* auto detect capture-less lambda
2024-12-06 13:13:21 -08:00
mt_caret
fd3377dd1f Support bias correction in Adam and AdamW optimizers (#1640) 2024-12-06 12:13:34 -08:00
Awni Hannun
d0b6cb0425 More primitives for compiling with shapeless (#1653)
* more shapeless and more Shape

* more shape

* fix

* fix
2024-12-06 11:29:18 -08:00
Alex Barron
95c4a2e3af add back conditionaltype (#1655) 2024-12-06 11:12:01 -08:00
Awni Hannun
bc2a29f033 fix (#1654) 2024-12-06 10:48:58 -08:00
Nripesh Niketan
3bb5b4a302 Chore: Add default language in pre-commit and bump hooks (#1652) 2024-12-06 07:54:29 -08:00
Awni Hannun
fc88fd9097 Shape and Strides 1 / N (#1645)
* shape and stride type def

* more shape
2024-12-05 12:53:43 -08:00
Awni Hannun
c5b0928c1f fix fallback (#1646) 2024-12-05 11:59:53 -08:00
Awni Hannun
e047fd977d compile changes if stream changes (#1644) 2024-12-03 14:37:44 -08:00
Jagrit Digani
9d40e521d7 Stop matrix copies with new attention kernel (#1639) 2024-12-02 14:12:38 -08:00
Alex Barron
1445dcaa60 let class predicate specify quantization parameters (#1638) 2024-12-02 14:09:28 -08:00
Jesper Stemann Andersen
e4eeb4e910 Added missing unordered_map includes (#1635)
* Added missing includes in mlx/io.h and mlx/backend/metal/metal.h

* Added additional missing unordered_map includes that fixes build on FreeBSD
2024-12-02 07:03:03 -08:00
Awni Hannun
aa86876813 fix transformer decoder post norm LN (#1637) 2024-12-02 07:02:17 -08:00
Jesper Stemann Andersen
974bb54ab2 CMake: Enabled using Accelerate on x86_64 / x64 (#1625)
* CMake: Enabled using Accelerate on x86_64 / x64

Cf. https://github.com/JuliaPackaging/Yggdrasil/pull/9761

* CMake: Removed superfluous MLX_BUILD_ARM
2024-11-28 10:55:45 -08:00
Ikko Eltociear Ashimine
9bc2183a31 docs: update device.cpp (#1632)
unecessary -> unnecessary
2024-11-27 20:58:26 -08:00
Awni Hannun
d4b222b6d3 Fix some leaks and races (#1629)
* fix leak and fix potential race

* more leak fixes

* fix one more
2024-11-27 20:01:20 -08:00
Jesper Stemann Andersen
af2af818a6 Enables build for *-linux-musl (#1627)
Also contributes to being able to build for *-w64-mingw32.

Cf. https://github.com/JuliaPackaging/Yggdrasil/pull/9761
2024-11-27 13:14:24 -08:00
Jesper Stemann Andersen
698e63a608 CMake: Build with dlfcn-win32 to have dlopen etc. on win32 (#1628)
Cf. https://github.com/JuliaPackaging/Yggdrasil/pull/9761
2024-11-27 13:14:13 -08:00
Awni Hannun
211411faf2 fix large ops (#1620) 2024-11-24 09:17:10 -08:00
Awni Hannun
bb303c45a5 version (#1617) 2024-11-22 12:00:03 -08:00
Alex Barron
6f7986d592 Cleaner qmv/qvm (#1616) 2024-11-22 11:14:08 -08:00
Awni Hannun
7cbb4aef17 Doc fix (#1615) 2024-11-22 11:12:25 -08:00
Jagrit Digani
02bec0bb6d Matrix Attention kernel (#1610)
* Rough INIT

* [WIP]: Loading and Matmuls added

* [WIP]: Reductions and min working aligned kernel at headdim = 64

* [WIP] Added headdim 80 for testing

* [WIP] Update dispatch params for testing

* [WIP] Add support for unaligned seq lengths - still looks messy

* Update sdpa_benchmarks

* Update sdpa_benchmarks

* Update sdpa_benchmarks

* Enable gqa support

* Update benchmark and switch off 128 headdim

* Update headdim 128 tuning

* Remove older fast attention code. Write out O strided

* Disable hd=128 until further optimizations

* Enable bf16

* Fix data size bug

* Enable attn build outside of jit
2024-11-22 10:34:05 -08:00
Alex Barron
c79f6a4a8c 3 and 6 bit quantization (#1613)
* Support 3 and 6 bit quantization
2024-11-22 10:22:13 -08:00
Awni Hannun
0c5eea226b Reduce specializations (#1607)
* start of reduce specializations

* fix all reduce

* fix many dims

* fix

* non-jit tests clear

* cleanup instantiations

* cpu merges

* change dim specializations

* optimize

* fix jit

* fix jit

* use higher precision for integer sum+prod

* fixes
2024-11-21 19:53:00 -08:00
Awni Hannun
dcca0d7477 contiguous op / prim (#1612) 2024-11-21 19:51:49 -08:00
Cocoa
0d5e7716ad fix typo: accross -> across (#1609)
Signed-off-by: Cocoa <i@uwucocoa.moe>
2024-11-20 15:30:51 -08:00
Angelos Katharopoulos
d8c824c594 Formatting fixes (#1606) 2024-11-20 15:30:36 -08:00
Saanidhya
cb431dfc9f Adds 3D pooling (#1526) 2024-11-19 16:45:24 -08:00
Awni Hannun
61d787726a Fix view scalar bug segfault (#1603)
* fix view scalar bug

* fix view scalar bug

* one more fix
2024-11-19 10:54:05 -08:00
Angelos Katharopoulos
5e89aace9b Fix concatenate vmap (#1600) 2024-11-19 10:44:04 -08:00
Awni Hannun
2af7e8a9a6 fix cmake version (#1601) 2024-11-19 08:45:05 -08:00
Awni Hannun
2419edd5b2 Faster indexing math in a few kernels (#1589)
* wip: faster compiled kernels

* faster general unary with uint specialization

* index type in compiled, unary, binary, ternary, copy

* fix jit

* jit fix

* specialize gather + scatter

* nit in docs
2024-11-18 19:52:00 -08:00
Awni Hannun
bf481e8e5d Fix sibling leak (#1590)
* add test

* fix + test

* fix fix
2024-11-18 19:17:01 -08:00
Awni Hannun
9d7fa6b8e6 Use osx deployment target to pick Metal version (#1595)
* choose metal based on deployment target rather than system version

* nit

* unused compile def
2024-11-18 19:16:49 -08:00
Angelos Katharopoulos
073076ac7d 2-Pass Sdpa Inference Kernel (#1597) 2024-11-18 17:31:53 -08:00
Awni Hannun
9bd03dd9b4 More buffer donation with no-ops (#1591)
* more donation

* fix test

* fix build
2024-11-18 08:35:41 -08:00
Awni Hannun
6931f84412 fix dispatch threads for a few kernels (#1594) 2024-11-18 08:35:25 -08:00
xnorai
16ec0556a0 Allocate raw JSON metadata buffer on the heap, and limit its size (#1596)
* Allocate raw JSON metadata buffer on the heap, and limit its size to 1GiB

* Set the upper size limit for the header to 100K as in Rust safetensors
2024-11-18 07:22:51 -08:00
Awni Hannun
610af352d4 Dispatch bf16 at run time when using the JIT (#1584)
* Dispatch bf16 at run time when using the JIT

* fix extension

* fix extension build

* fix extension build

* Update utils.h
2024-11-15 16:54:36 -08:00
Awni Hannun
b35f1e3c9c fix donation in sdpa (#1587) 2024-11-13 17:21:13 -08:00
Awni Hannun
dfa0b9aab4 Cpu fast quantize (#1578)
* cpu quantize

* fix
2024-11-08 20:10:39 -08:00
Alex Barron
a4c47b0276 OOB QMV fix (#1579)
* fix oob access in qmv

* skip more

* fix small case
2024-11-08 17:59:45 -08:00
Alex Barron
111fefd5e9 Fix OOB access in qmv (#1577)
* fix oob access in qmv

* skip more
2024-11-08 15:41:30 -08:00
Awni Hannun
c1fe1ef081 Bfs width limit (#1568)
* width limit

* fix

* large limit

* put env vars in env namespace
2024-11-08 15:00:46 -08:00
Awni Hannun
8c34c9dac4 throw for invalid case and remove test (#1575) 2024-11-08 12:04:03 -08:00
Awni Hannun
91c0277356 fix per-example mask + docs in sdpa (#1574) 2024-11-08 11:51:15 -08:00
Awni Hannun
9f0d5c12fc Fully wrap the command encoder (#1572)
* fully wrap the command encoder

* use consistent style + fix extensions
2024-11-08 11:50:21 -08:00
Awni Hannun
59247c2b62 add groups in conv2d (#1569) 2024-11-07 13:57:53 -08:00
Awni Hannun
9a3842a2d9 fix (#1566) 2024-11-06 17:10:33 -08:00
Alex Barron
726dbd9267 v0.20.0 (#1565) 2024-11-05 12:37:57 -08:00
Awni Hannun
54f05e7195 Fix gather vmap (#1563)
* fix gather

* fix
2024-11-05 11:29:20 -08:00
Alex Barron
26be608470 Add split_k qvm for long context (#1564)
* Add splitk qvm

* configurable splitk

* tuning

* remove extra instantiation

* remove refactor

* separate test

* cpu tolerance
2024-11-05 11:25:19 -08:00
Angelos Katharopoulos
248431eb3c Reductions update (#1351) 2024-11-04 22:25:16 -08:00
Awni Hannun
76f275b4df error in rms for wrong size (#1562) 2024-11-04 13:24:02 -08:00
Awni Hannun
f1951d6cce Use fewer barriers (#1561)
* use fewer barriers

* comment
2024-11-04 10:26:49 -08:00
Angelos Katharopoulos
62f297b51d Sdpa fix (#1558) 2024-11-02 21:25:46 -07:00
Awni Hannun
09bc32f62f No extra reshape (#1557)
* no extra reshape

* lint
2024-11-02 19:07:20 -07:00
Chris Offner
46d8b16ab4 Fix vmap example in docs (#1556) 2024-11-02 17:44:14 -07:00
Chris Offner
42533931fa Fix typo "it's" -> "its" (#1555) 2024-11-02 06:06:34 -07:00
Awni Hannun
9bd3a7102f add python 3.13 to circle (#1553) 2024-11-01 20:55:35 -07:00
Alex Barron
9e516b71ea Add dispatchThreads to custom kernel doc (#1551)
* add dispatchThreads info

* update

* add link
2024-11-01 13:07:48 -07:00
Awni Hannun
eac961ddb1 patch (#1550) 2024-10-31 16:10:14 -07:00
Awni Hannun
57c6aa7188 fix multi output leak (#1548) 2024-10-31 09:32:01 -07:00
Awni Hannun
cde5b4ad80 patch (#1546) 2024-10-30 19:31:22 -07:00
Awni Hannun
4f72c66911 improvements to scatter / gather (#1541) 2024-10-30 19:30:54 -07:00
Jagrit Digani
960e3f0f05 Gemm update (#1518) 2024-10-30 19:30:28 -07:00
Awni Hannun
884af42da2 Fix thread group for large arrays (#1543)
* fix thread group for large arrays

* comment

* one more
2024-10-30 16:25:12 -07:00
Alex Barron
048fabdabd Fix vmap constant output size (#1524)
* use inputs to determine output size

* remove noop vmap tests
2024-10-30 16:16:53 -07:00
Léo
917252a5a1 Add favicon to docs (#1545)
* add sphinx's html_favicon config

* removed unneeded newline

* ran pre-commit hooks
2024-10-30 13:54:13 -07:00
Carlo Cabrera
1a992e31e8 Skip using Residency sets in VMs (#1537)
* Skip using Residency sets in VMs

Attempting to use residency sets in a VM throws[^1]

    libc++abi: terminating due to uncaught exception of type std::runtime_error: [metal::Device] Unable to construct residency set.

Not quite sure if this is the best fix, but it does make the error go
away.

Note that it was previously possible to run simple programs that used
mlx in a VM prior to 0eb56d5be0. See
related discussion at Homebrew/homebrew-core#195627.

[^1]: https://github.com/Homebrew/homebrew-core/actions/runs/11525831492/job/32105148462#step:3:56

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* change residency check

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-10-29 19:37:23 -07:00
Awni Hannun
d2ff04a4f2 fix format (#1539) 2024-10-28 18:29:14 -07:00
Awni Hannun
015c247393 change wino dispatch conditoin (#1534) 2024-10-28 11:13:44 -07:00
Awni Hannun
d3cd26820e Faster bits and bernoulli (#1535)
* faster bits and bernoulli

* fix bernoulli
2024-10-28 11:11:00 -07:00
Awni Hannun
91f6c499d7 fix (#1529) 2024-10-25 19:25:35 -07:00
Awni Hannun
35e9c87ab9 patch bump (#1528) 2024-10-25 13:13:23 -07:00
Awni Hannun
8e88e30d95 BFS graph evaluation order (#1525)
* bfs order

* try fix event issue
2024-10-25 10:27:19 -07:00
Awni Hannun
0eb56d5be0 Wired (#1510)
* expose residency sets as wire/unwire

* returns wired size

* fix

* runtime support check

* fix os check

* fix test

* fix no metal build

* docs

* nit

* nits in docs

* nits
2024-10-25 09:35:33 -07:00
Paul Hansel
f70764a162 Fix typo in build docs (#1522) 2024-10-24 20:55:06 -07:00
Awni Hannun
dad1b00b13 fix (#1523) 2024-10-24 19:17:46 -07:00
Venkata Naga Aditya Datta Chivukula
430ffef58a [Feature] Added Sparse Initialization (#1498)
Co-authored-by: Saanidhyavats <saanidhyavats@gmail.com>
2024-10-24 12:31:24 -07:00
Alex Barron
3d17077187 Add mx.array.__format__ (#1521)
* add __format__

* actually test something

* fix
2024-10-24 11:11:39 -07:00
Angelos Katharopoulos
c9b41d460f Working 64-bit scans (#1506) 2024-10-24 11:05:46 -07:00
xnorai
32972a5924 C++20 compatibility for fmt (#1519)
* C++20 compatibility for fmt

* Address review feedback

* Remove stray string

* Add newlines back
2024-10-24 08:54:51 -07:00
Dhruv Govil
f6afb9c09b Remove use of vector<const T> (#1514) 2024-10-22 16:31:52 -07:00
Kashif Rasul
3ddc07e936 Eigenvalues and eigenvectors (#1334)
* initial eigvalsh

* add compute_vectors

* add compute_vectors_

* return a pair

* add eigh to return only eigenvectors

* fixed typo

* merge merge Eighvalsh and Eigh into a single primitive

* use the same primate with the flag

* fix primatives

* use MULTI

* fix eval_gpu

* fix decleration

* rename EighPrimitive to Eigh

* tests

* tests

* fix rebase and format

* cleanup lapack

* format

* add cblas.h

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-10-22 12:18:48 -07:00
Awni Hannun
c26208f67d Remove Hazard tracking with Fences (#1509)
* remove hazard tracking

* with fence map

* no hazard tracking with fences

* nits

* fix fence retain

* cleanup

* fix quantized rebase
2024-10-21 19:33:32 -07:00
Alex Barron
d15fa13daf Batched Quantized Matmul + Fast Small QMV (#1503)
* add fast qmv for small dims

* fix test

* batched cpu

* add batched template param

* refactor metal quantized.cpp
2024-10-21 16:23:17 -07:00
Awni Hannun
58a855682c v0.19.0 (#1502) 2024-10-18 11:55:18 -07:00
Awni Hannun
92d7cb71f8 Fix compile (#1501)
* fix compile

* fix space
2024-10-18 11:06:40 -07:00
Angelos Katharopoulos
50d8bed468 Fused attention for single query (#1497) 2024-10-18 00:58:52 -07:00
Awni Hannun
9dd72cd421 fix gumbel (#1495) 2024-10-17 13:52:39 -07:00
Awni Hannun
343aa46b78 No more 3.8 (#1493) 2024-10-16 17:51:38 -07:00
Awni Hannun
b8ab89b413 Docs in ci (#1491)
* docs in circle
2024-10-15 17:40:00 -07:00
Awni Hannun
f9f8c167d4 fix submodule stubs (#1492) 2024-10-15 16:23:37 -07:00
Awni Hannun
3f86399922 Real and Imag (#1490)
* real and imag

* fix

* fix
2024-10-15 16:23:15 -07:00
LastWhisper
2b8ace6a03 Typing the dropout. (#1479) 2024-10-15 06:45:46 -07:00
Awni Hannun
0ab8e099e8 Fix cpu segfault (#1488)
* fix cpu segfault

* nit in tests
2024-10-14 16:17:03 -07:00
Awni Hannun
020f048cd0 A few updates for CPU (#1482)
* some updates

* format

* fix

* nit
2024-10-14 12:45:49 -07:00
Awni Hannun
881615b072 Faster metal compiled kernels + some fixes (#1486)
* bump mac tests to use py39

* work per thread for compiled kernels

* fixe for large arrays

* fix
2024-10-14 12:45:38 -07:00
Awni Hannun
0eef4febfd bump mac tests to use py39 (#1485) 2024-10-14 10:40:32 -07:00
Awni Hannun
b54a70ec2d Make push button linux distribution (#1476)
* try again

* try again

* try again

* try again

* try again

* try again

* try again

* try again

* .circleci/config.yml

* one more fix

* nit
2024-10-14 06:21:44 -07:00
Awni Hannun
bf6ec92216 Make the GPU device more thread safe (#1478)
* gpu stream safety

* comment

* fix
2024-10-12 17:49:15 -07:00
Awni Hannun
c21331d47f version bump (#1477) 2024-10-10 13:05:17 -07:00
Awni Hannun
e1c9600da3 Add mx.random.permutation (#1471)
* random permutation

* comment
2024-10-08 19:42:19 -07:00
Awni Hannun
1fa0d20a30 consistently handle all -inf in softmax (#1470) 2024-10-08 09:54:02 -07:00
Awni Hannun
3274c6a087 Fix array is_available race cases (#1468) 2024-10-07 19:13:50 -07:00
Angelos Katharopoulos
9b12093739 Add the roll op (#1455) 2024-10-07 17:21:42 -07:00
Awni Hannun
f374b6ca4d Bump nanobind to 2.2 (#1461)
* bump nanobind

* extension version for tests
2024-10-07 16:52:40 -07:00
Awni Hannun
0070e1db40 Fix deep recursion with siblings (#1462)
* fix recursion with siblings

* fix

* add test

* increase tol
2024-10-07 06:15:33 -07:00
Awni Hannun
95d04805b3 Fix complex power on Metal (#1460) 2024-10-06 19:58:30 -07:00
Awni Hannun
e4534dac17 Conv grad with groups + bugfix (#1449)
* fix bug in flipped conv with groups, start of grad for groups

* fix

* fix

* fix + test
2024-10-06 07:08:53 -07:00
Angelos Katharopoulos
fef3c4ec1d Fix mpi test in CI (#1456)
* Fix mpi test in CI

* Set bind to none
2024-10-06 06:09:17 -07:00
Awni Hannun
1bdc038bf9 fix argpartition + faster {arg} sorts / partitions (#1453) 2024-10-03 14:21:25 -07:00
Awni Hannun
5523d9c426 faster cpu indexing (#1450) 2024-10-03 13:53:47 -07:00
Angelos Katharopoulos
d878015228 Fix normalization check_input (#1452) 2024-10-03 13:26:56 -07:00
Cheng
5900e3249f Fix building on Linux (#1446) 2024-09-30 07:00:39 -07:00
Angelos Katharopoulos
bacced53d3 Fix row reduce with very few rows (#1447) 2024-09-29 20:00:35 -07:00
Lucas Newman
4a64d4bff1 Add support for grouped 1D convolutions to the nn API (#1444)
* Fix the weight shape for grouped convolutions from the nn API.

* Add tests.

* Pre-commit formatting.

* Add input validation.

* Use integer division instead of casting.

* docs

* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-28 06:41:07 -07:00
Awni Hannun
b1e2b53c2d bump (#1445) 2024-09-27 13:53:02 -07:00
Awni Hannun
11354d5bff Avoid io timeout for large arrays (#1442) 2024-09-27 13:32:14 -07:00
Awni Hannun
718aea3f1d allow take to work with integer index (#1440) 2024-09-26 15:58:03 -07:00
Awni Hannun
5b6f38df2b Faster cpu ops (#1434)
* faster binary and cleaner copy

* use recursive template for other ops

* more cleanup

* fix from cleanup

* more clean

* fix binary

* use contiguous iterator

* add 3d

* nits

* fix

* fix?

* fix

* fix rebase
2024-09-26 09:19:13 -07:00
Awni Hannun
0b4a58699e Some overhead reductions in mx.fast.metal_kernel (#1437)
* some overhead reductions

* fix

* use +=

* use more +=
2024-09-25 17:25:21 -07:00
Awni Hannun
4f9f9ebb6f Faster Metal unary and binary for general case (#1431)
* faster unary and binary for general case

* update ternary + jit fix

* fix jit

* unary work per thread
2024-09-25 12:07:43 -07:00
Awni Hannun
afc9c0ec1b dtype is copy assignable (#1436) 2024-09-25 12:07:13 -07:00
Awni Hannun
195b429d99 Put along axis + fixe for partition grad (#1430)
* put along axis, fixes for partition grad

* zeros for arg reduce
2024-09-23 10:03:38 -07:00
Luke Carlson
2b878e9dd7 Create CITATION.cff (#1425) 2024-09-20 11:39:46 -07:00
Awni Hannun
67b6bf530d Optimization for general ND copies (#1421) 2024-09-17 17:59:51 -07:00
Nripesh Niketan
6af5ca35b2 feat: add cross_product (#1252)
* feat: add cross_product

* lint

* python binding

* refactor: Improve error message for cross_product function

* refactor: more close to numpy cross product

* refactor: improve error message for cross_product function

* finish

* fix acks

* allow old numpy

* doc

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-17 13:12:43 -07:00
Awni Hannun
4f46e9c997 More fixes for arrays with large sizes (#1405)
* compile works for big arrays when contiguous

* style

* nits in docs

* a bunch more stuff

* update jit

* update jit

* use constant for shapes and strides and remove elem_to_loc overload

* use kernel instantiation

* docs nits

* update binary and ternary

* comments
2024-09-17 12:46:31 -07:00
Awni Hannun
c6739ba7f3 Faster RNN layers (#1419)
* faster rnn

* use admm
2024-09-17 06:04:19 -07:00
Angelos Katharopoulos
914409fef9 Data parallel helper (#1407) 2024-09-16 18:17:21 -07:00
jjuang-apple
8d68a3e805 remove fmt dependencies from MLX install (#1417) 2024-09-16 13:32:28 -07:00
jjuang-apple
6bbcc453ef avoid using find_library to make install truly portable (#1416) 2024-09-16 13:21:32 -07:00
Awni Hannun
d5ed4d7a71 override class function (#1418) 2024-09-16 13:21:04 -07:00
Nripesh Niketan
669c27140d Chore: add pre-commit hook for cmake (#1362)
* reset and lint

* format

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-16 12:53:01 -07:00
Max-Heinrich Laves
adcc88e208 Conv cpu improvements (#1410) 2024-09-15 18:45:10 -07:00
Awni Hannun
d6492b0163 fix clip (#1415) 2024-09-14 16:09:09 -07:00
Awni Hannun
b3f52c9fbe ensure io/comm streams are active before eval (#1412) 2024-09-14 06:17:36 -07:00
c0g
bd8396fad8 Fix typo in transformer docs (#1414) 2024-09-14 06:05:15 -07:00
Angelos Katharopoulos
d0c58841d1 Patch bump (#1408) 2024-09-12 16:44:23 -07:00
Angelos Katharopoulos
881f09b2e2 Allow querying the allocator for the buffer size (#1404) 2024-09-11 21:02:16 -07:00
Awni Hannun
8b30acd7eb fix module attribute set, reset, set (#1403) 2024-09-11 16:30:42 -07:00
Awni Hannun
02efb310ca Xcode 160 (#1384)
* xcode 16.0 with debug tests

* limit nproc for builds

* vmap bug

* assert bug

* run python tests in debug mode

* fix view, bool copies preserve bits'

* actual view fix
2024-09-10 15:15:17 -07:00
Awni Hannun
e7e59c6f05 Fix copying scalars by adding fill_gpu (#1402)
* fix copying scalars by adding fill_gpu

* Another copy scalar changed to fill

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-09-09 15:54:08 -07:00
Awni Hannun
3ae6aabe9f throw for certain cases of non captured inputs in compile (#1401) 2024-09-09 14:54:31 -07:00
xnorai
dc627dcb5e Replace the use of result_of_t with invoke_result_t (#1397)
* Fix C++20 incompatibility

* Fix C++20 incompatibility
2024-09-06 19:52:57 -07:00
Max-Heinrich Laves
efeb9c0f02 Transposed Convolution (#1245)
* initial implementation for conv_transpose

ran pre-commit

implemented conv_transpose

updated conv_general docstring

updated conv_general docstring

updated code comments

removed commented run_conv_checks

updated acknowledgments

added missing entry to ops.rst

added op to nn.layers

resolved merge conflicts

* removed ConvolutionTranspose primitive as suggested by reviewer

removed ConvolutionTranspose primitive as suggested by reviewer

* remove transpose flag, add another test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-06 19:52:38 -07:00
Awni Hannun
ba3e913c7a Simplifications for MLX C (#1396)
* simplifications for MLX C

* use vectors instead of map

* update examples
2024-09-06 19:16:50 -07:00
Awni Hannun
7cca1727af Fix slice data size (#1394)
* fix slice data size and add tests

* fix contiguous flag

* simplify stride and perform copy for non-contiguous arrays

* fix cpu

* comment
2024-09-04 19:10:43 -07:00
Bhargav Yagnik
11371fe251 Test to prevent bugs like #1386 (#1391)
* updated test_array for missing ops

* formatting changes
2024-09-04 17:24:30 -07:00
Awni Hannun
41c603d48a fix jit reduce (#1395) 2024-09-04 14:03:10 -07:00
Angelos Katharopoulos
969337345f Fix reduce edge case (#1389) 2024-09-01 21:37:51 -07:00
Awni Hannun
9592766939 add std as method (#1387)
* add std as method

* add std as method
2024-09-01 19:49:16 -07:00
Angelos Katharopoulos
58dca7d846 Fix copy in the sort primitive (#1383) 2024-08-31 08:32:14 -07:00
Awni Hannun
0d302cd25b Fix compiel with byte sized constants (#1381) 2024-08-30 17:24:35 -07:00
Alex Barron
da691257ec Fix overflow in quantize/dequantize (#1379)
* add 2d indices to prevent overflow

* use nthreads not out size
2024-08-30 13:32:41 -07:00
Angelos Katharopoulos
1600092e92 Patch bump (#1376) 2024-08-29 16:54:30 -07:00
Awni Hannun
dba2bd1105 Even Even Faster IO (#1374)
* even more faster io

* make reader pool static

* make python reader thread safe

* one more optimization
2024-08-29 16:05:40 -07:00
Alex Barron
28be4de7c2 Fix JIT reductions (#1373) 2024-08-28 16:39:11 -07:00
Awni Hannun
a6c3b38fba Async load (#1372)
* async load

* async load
2024-08-28 14:21:55 -07:00
Awni Hannun
fcb65a3897 Even Faster I/O (#1369)
* try multithreading for faster IO

* smaller batch size

* Account for pread returning less than size

* nit

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-08-28 11:49:07 -07:00
Saanidhya
4e22a1dffe In continuation to PR1243 to solve issue #1240 (#1365)
* Solves issue #1240

* Correction

* Update python/mlx/utils.py

* Update python/mlx/utils.py

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-08-28 11:40:41 -07:00
Awni Hannun
291cf40aca Some fixes to typing (#1371)
* some fixes to typing

* fix module reference

* comment
2024-08-28 11:16:19 -07:00
Jeethu Rao
bd47e1f066 Fix neon_fast_exp and add more softmax tests (#1367) 2024-08-27 23:42:42 -07:00
Aditya Dhulipala
e6b223df5f Pinv (#875) 2024-08-27 23:06:12 -07:00
Angelos Katharopoulos
e64349bbdd Make eval just wait if all arrays are scheduled (#1368) 2024-08-27 17:01:22 -07:00
Angelos Katharopoulos
cdb59faea6 Adds send/recv ops in distributed (#1366) 2024-08-26 23:01:37 -07:00
Alex Barron
1d94ac3f90 Add optional headers to `mx.fast.metal_kernel` (#1358) 2024-08-26 21:45:45 -07:00
Awni Hannun
5f7d19d1f5 MPI ops in GPU stream for faster comms (#1356) 2024-08-26 15:12:50 -07:00
Awni Hannun
2fdf9eb535 Fix ternary for large arrays (#1359)
* fix ternary for large arrays

* fix
2024-08-26 11:22:27 -07:00
Awni Hannun
860d3a50d7 fix extension metal library finding (#1361) 2024-08-26 09:18:50 -07:00
Alex Barron
d1183821a7 int() and float() for mx.array (#1360) 2024-08-25 20:41:44 -07:00
Angelos Katharopoulos
8081df79be Fix boolean all reduce bug (#1355) 2024-08-24 10:09:32 -07:00
Nripesh Niketan
64bec4fad7 Chore: update pre-commit hooks (#1353)
* Chore: update pre-commit refs

* run pre-commit
2024-08-24 06:46:36 -07:00
Alex Barron
b96e105244 Add grid_sample example to metal_kernel docs (#1352)
* Add `zero_outputs` and `atomic_outputs` options to `metal_kernel`

* add grid sample to docs

* zero_outputs -> init_value

* add missing header for linux
2024-08-23 18:24:16 -07:00
Awni Hannun
3b4d5484c7 Bump extension MLX version (#1350)
* Bump extension MLX version

* fix some docs nits
2024-08-23 12:38:34 -07:00
Alex Barron
684e11c664 patch (#1347) 2024-08-23 10:42:02 -07:00
Angelos Katharopoulos
b57a52813b Further reduction tuning (#1349)
* More reduction tuning
* Forgotten pdb
* Small column long row specialization
2024-08-23 10:35:25 -07:00
Alex Barron
da8deb2b62 fix bug with multiple attributes (#1348)
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-08-23 10:06:15 -07:00
Awni Hannun
98b6ce3460 Refactor reductions and fix scatter atomics for large sizes (#1300)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-08-22 16:03:31 -07:00
Awni Hannun
f9e00efe31 fix nanobind and stub gen in circle (#1346) 2024-08-22 14:07:27 -07:00
Alex Barron
0fd2a1f4b0 Custom Metal Kernels from Python (#1325)
* start

* simple kernels working

* restructure

* inverse example working

* docs + fixes

* missing file

* fix imports

* address comments

* add docs + fix test

* Review comments + refactor to a single function

* update docs

* remove hashing

* fix contig bug in test

* back to a class

* trailing whitespace

* fix tests

* match c++ and python apis

* add link + make args kw_only
2024-08-22 13:46:29 -07:00
Awni Hannun
df3233454d 2d gather specialization (#1339) 2024-08-22 10:48:24 -07:00
Awni Hannun
82db84b899 bump nanobind + fix extension (#1344) 2024-08-21 16:05:07 -07:00
Awni Hannun
8ae751d3da fix io (#1343)
* fix io

* fix io

* comment
2024-08-21 13:14:46 -07:00
Awni Hannun
d40e76809f Fix rope (#1340)
* add test

* fix rope

* fix test
2024-08-20 17:37:52 -07:00
Awni Hannun
bb1b76d9dc RoPE with frequencies as optional input (#1337)
* start rope with freq input

* rope with frequencies

* nits

* fix bug

* fix bug + test

* cleanup

* optional base
2024-08-19 18:30:50 -07:00
Angelos Katharopoulos
9d26441224 Fix contiguity check (#1336)
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-08-19 16:05:06 -07:00
Awni Hannun
f12f24a77c fix compiling with space in paths (#1332) 2024-08-15 16:39:24 -07:00
Awni Hannun
ae5b5cabfd Fix optimizer reloading from checkpoint (#1329)
* fix optimizer reloading from checkpoint

* comment
2024-08-15 07:33:23 -07:00
Awni Hannun
d0630ffe8c Read arrays from files faster (#1330)
* read faster

* faster write as well

* set default permission for linux

* comment
2024-08-14 20:09:56 -07:00
Alex Barron
99bb7d3a58 GPU mx.sign for complex64 (#1326) 2024-08-14 07:54:53 -07:00
Awni Hannun
63ae767232 fix transformer (#1327) 2024-08-13 16:04:26 -07:00
Awni Hannun
eaaea02010 Add isfinite (#1318)
* isfinite

* remove reduce test since fix is not complete
2024-08-13 14:49:28 -07:00
Bhargav Yagnik
a098bc92e0 Fix: Preserve input dtype in Dropout layer output (#1323)
* Fix: Preserve input dtype in Dropout layer output

- Modified Dropout implementation to ensure that the output dtype matches the input dtype.
- This resolves the issue #1321

* Update test cases in test_nn.py

- Revised test cases to align with updated dropout code
- Fixed assertion method: replaced self.assertTrue with self.assertEqual for accurate comparisons in test_nn.py -> test_rope, test_alibi and test_dropout,

* updated dropout.py
2024-08-13 11:54:21 -07:00
Awni Hannun
1086dc4db0 patch (#1320) 2024-08-12 16:13:33 -07:00
Brian Keene
19fb69e2ed Add memory_efficient_threshold kwarg to sdpa kernel (#1319)
Allows opt-in to memory efficient GPU shader at proscribed sequence
length.  Otherwise, utilizes aggregate MLX primitives for best latency.
2024-08-12 12:57:09 -07:00
Awni Hannun
9231617eb3 Move to nanobind v2 (#1316) 2024-08-08 17:17:46 -07:00
Alex Barron
32668a7317 CPU mx.linalg.cholesky_inverse and mx.linalg.tri_inv (#1307)
* add cholesky inv + tri inv

* always run tri_inv on cpu

* consistent naming
2024-08-08 15:18:02 -07:00
Angelos Katharopoulos
780c197f95 Fix test tolerance and patch bump (#1315) 2024-08-08 14:51:09 -07:00
Angelos Katharopoulos
eb8819e91e Revert variance to be numerically stable (#1314) 2024-08-08 13:35:02 -07:00
Awni Hannun
30bbea2f08 Add gemv masked to JIT plus some fixes (#1310)
* add gemv masked to JIT plus some fixes

* some cleanup

* add utils

* fix

* fix 2

* more cleaning

* fix

* remove unused mps matmul support

* one more nit

* revert
2024-08-07 13:38:07 -07:00
Alex Barron
635ccd9e25 Add "edge" mode to mx.pad (#1309)
* Add edge padding mode

* fix pad in pooling

* string arg instead of enum
2024-08-06 11:23:10 -07:00
nicolov
8c9f0278b9 Add vmap to scatter (#1200)
* Add vmap to scatter

* updates

* vmap updates + a few more tests

* bug fix

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-08-05 20:12:27 -07:00
Awni Hannun
58d0e199e1 add bfloat conv for windograd (#1306)
* add bfloat conv for windograd

* accumulate in fp32

* accumulate in fp32

* accumulate in bf16
2024-08-05 15:51:13 -07:00
Awni Hannun
10b5835501 fix creating array from bf16 tensors in jax / torch (#1305) 2024-08-01 16:20:51 -07:00
Awni Hannun
6c8dd307eb faster group norm (#1304) 2024-08-01 12:49:23 -07:00
Awni Hannun
43ffdab172 fix rope and random (#1301)
* fix rope and random

* comment
2024-07-31 16:18:25 -07:00
Awni Hannun
40b6d67333 Fixes for large arrays with a few ops (#1299)
* fixes for large arrays with a few ops

* fix bug

* fix all of copy
2024-07-30 17:18:39 -07:00
Alex Barron
c52d1600f0 Fused Affine Quantize/Dequantize ops (#1282)
* Add fast affine dequantize

* add full quantize kernel

* fused kernel with scale/bias computation

* fix docstring

* fix no jit error

* fix test

* test fix

* reduce fast api to only affine_quantize
2024-07-29 15:11:38 -07:00
Awni Hannun
aa1d6cadad Fix docs latex build and nits (#1297)
* fix docs latex build and nits

* fix stub gen and try to clean up building
2024-07-29 11:44:06 -07:00
Atakan Tekparmak
6e06e3a904 feat: Added "tanh" option to GELU approximation (#1268) 2024-07-28 09:07:56 +02:00
Yaroslav
8cfb9fc0b8 Update requirements.txt (#1291) 2024-07-26 12:59:52 -07:00
Awni Hannun
7b456fd2c0 Array api (#1289)
* some updates for numpy 2.0 and array api

* some updates for numpy 2.0 and array api

* fix array api doc
2024-07-26 10:40:49 -07:00
Awni Hannun
e9e53856d2 patch bump (#1287) 2024-07-25 11:42:09 -07:00
Anton Belov
5029894662 [Issue #1187] Add nan_to_num function initial attempt (#1247)
* initial attempt, working with wrong types

* not compiling; mx.float16 and mx.bfloat16 tests added

* fix nan to num

* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-07-25 09:57:37 -07:00
Awni Hannun
baf9fa5f42 Einsum (#1269)
* einsum initial

* fix comma break

* sum axis was wrong

* small cleanups

* python binding

* changed bindings to resemble numpy

* remove todo comment

* comment changes

* add count of operands/inputs

* fail fast if operands list is empty

* ignore comma if no output

* einsum path matching numpy

* getting somewhere with path

* remove print

* it passes the first test

* moved einsum tests to seperate file

* seperated einsum path

* moved einsum naive

* remove space from equation

* fast fail if no operands passed

* update tests and remove printf

* small cleanup

* some more cleanups

* removed python helper file

* ack

* utilize std for finding min in vector

* duplicate def

* remove the tuple as it was unreadable

* moved einsum_naive back to ops

* remaining isn't needed

* avoid creating another set

* cleanup

* greedy path, start of naive einsum

* more einsum

* fix some bugs

* some more fixes, tests pass

* benchmark

* some simplify

* fix einsum and test

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>

* add a bunch more tests and fix a bunch more bugs

* some docs nits

---------

Co-authored-by: dc-dc-dc <dgcruz983@gmail.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-07-25 09:36:44 -07:00
Jagrit Digani
7f914365fd Fix GPU sort for large arrays (#1285)
* Fix GPU sort for large arrays
2024-07-24 14:37:10 -07:00
Paul Paczuski
ebd7135b50 Improve stability of BCE loss calculation for input probabilities close to or exactly 0 or 1 (#1280)
* Improve stability of BCE loss calculation

* Standardize comment

* Apply formatting with black via pre-commit

* Add usage recommendation to docstring

* Update python/mlx/nn/losses.py

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-07-24 08:38:22 -07:00
fgranqvist
50eff6a10a Implement sampling from laplace distribution. (#1279) 2024-07-24 15:15:37 +02:00
Alex Barron
c34a5ae7f7 Fix bfloat16 Hadamard (#1283)
* fix bfloat16 hadamard

* add scale

* review comments

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-07-23 14:54:43 -07:00
Awni Hannun
e2aa6ec8ae some fixes (#1281) 2024-07-23 11:49:05 -07:00
toji
6768c6a54a Adding missing type hints (#1243)
* added type hints for `run`, `tree_map` and `tree_map_with_path`

* fix lint

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-07-23 07:29:38 -07:00
Tim Gymnich
6307d166eb Fix overflow / underflow handling for expm1f (#1278)
* Fix overflow / underflow handling for expm1f

* update tests
2024-07-23 07:29:06 -07:00
Awni Hannun
1fba87b0df Fix leak with multi-output primitives (#1274)
* fix leak with multi-output primitives

* hopefully an actual fix
2024-07-23 06:34:18 -07:00
Awni Hannun
df124e018a fix gguf (#1273)
* fix gguf

* comment
2024-07-18 07:35:35 -07:00
Cheng
2f83d6e4b7 Do not release buffers on exit (#1142) 2024-07-15 15:12:24 -07:00
Feng Shijie
987785d8d7 Fix typo and missing header (#1266) 2024-07-15 08:20:24 -07:00
Awni Hannun
8c01a7893b minor fix in optimizer + docs (#1264) 2024-07-12 12:18:02 -07:00
Awni Hannun
218047c75a docs fixes (#1263) 2024-07-11 15:59:07 -07:00
Alex Barron
d0da74209b version bump (#1260) 2024-07-11 11:17:55 -07:00
Angelos Katharopoulos
5c1fa64fb0 Custom transforms (#1246) 2024-07-10 18:00:01 -07:00
Alex Barron
a3c287354f Fast Hadamard Transform (#1249)
* Working hadamard for powers of 2

* working for m*2^k

* add scale and check contiguity

* add size check

* clean up

* fix test

* add grads + vmap

* gpu only

* skip on linux

* test typo

* add cpu impl

* remove gpu only tests

* fix linux build + add is_equivalent
2024-07-09 20:39:01 -07:00
Angelos Katharopoulos
03cf033f82 Fix reshape copy bug (#1253) 2024-07-07 21:37:00 -07:00
Alex Barron
bdb36c9a63 add zero vjps for bitwise ops and gather w.r.t. index (#1256) 2024-07-07 21:34:59 -07:00
Awni Hannun
20bb301195 CPU binary reduction + Nits (#1242)
* very minor nits

* reduce binary

* fix test
2024-06-28 13:50:42 -07:00
Awni Hannun
d6383a1c6a version bump (#1239) 2024-06-27 10:43:13 -07:00
Angelos Katharopoulos
b05bcfd27f Fixes segfault when compiling checkpointed functions (#1235) 2024-06-26 16:14:45 -07:00
Alex Barron
2615660e62 Fix strided sort bug (#1236)
* Use output strides in sort kernel

* fix zero strides bug
2024-06-26 14:32:11 -07:00
Awni Hannun
5b0af4cdb1 fix donation condition for compilation (#1237) 2024-06-26 09:04:05 -07:00
Jagrit Digani
8c2e15e6c8 Accelerate import updates for iOS (#1227)
* Update veclib and bnns includes to #include <Accelerate/Accelerate.h> for compatibility with ios

* Mark float literals in softmax.cpp to be float16_t for errors in ios

* Add arm neon vector operation guards

* Redirect to common backend for consistency
2024-06-26 09:01:50 -07:00
Awni Hannun
56c8a33439 Get metal version from xcode (#1228)
* get metal version from xcode

* typo

* fix
2024-06-26 07:02:11 -07:00
David Koski
4eef1e8a3e fix typo (#1215) 2024-06-24 13:36:35 -07:00
Alex Barron
95d11bda06 Fix NumPy 2.0 pickle test (#1221)
* fix numpy version <2 temporarily

* typo

* better fix

* Fix just for bfloat16

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-23 05:47:22 -07:00
Awni Hannun
af9079cc1f version bump (#1212) 2024-06-14 11:28:51 -07:00
Jagrit Digani
2d6cd47713 Masked gemv (#1211) 2024-06-14 09:52:26 -07:00
Awni Hannun
fe3167d7ea smaller CPU binary (#1203)
* smaller CPU binary

* fix no cpu build
2024-06-14 09:46:55 -07:00
Awni Hannun
31e134be35 Build for macOS 15 (#1208)
* Build for macos 15

* metal32 as well

* comment

---------

Co-authored-by: Awni Hannun <Awni Hannun>
2024-06-13 13:31:44 -07:00
Awni Hannun
e84ba8056d only allow openmpi (#1209) 2024-06-13 12:14:44 -07:00
Fangjun Kuang
f20e97b092 minor fixes (#1194)
* minor fixes

* fix build errors
2024-06-12 22:06:49 -07:00
Alex Barron
934683088e Refactor JIT for unary/binary/ternary ops (#1206)
* refactor unary/binary/ternary ops

* get_primitive_string util

---------
2024-06-12 14:22:12 -07:00
Awni Hannun
de2b9e7d0a Fix kernel deps to reduce build times (#1205) 2024-06-12 11:17:39 -07:00
Alex Barron
dd7d8e5e29 Add Quantized Ops to the JIT (#1204)
* JIT for quantized ops

* remove unused imports

* address comments

* fix imports

* second attempt to fix imports

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-12 09:47:12 -07:00
Awni Hannun
df964132fb fix scatter + test (#1202)
* fix scatter + test

* fix test warnings

* fix metal validation
2024-06-11 14:35:12 -07:00
Awni Hannun
709ccc6800 install mpi for release build (#1199) 2024-06-10 10:09:32 -07:00
Awni Hannun
cf236fc390 version (#1191) 2024-06-06 17:16:40 -07:00
Alex Barron
27d70c7d9d Feature complete Metal FFT (#1102)
* feature complete metal fft

* fix contiguity bug

* jit fft

* simplify rader/bluestein constant computation

* remove kernel/utils.h dep

* remove bf16.h dep

* format

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-06 12:57:25 -07:00
nicolov
0e585b4409 Add docstring for scatter (#1189)
* Add docstring for scatter

* docs nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-06-06 11:51:25 -07:00
Angelos Katharopoulos
0163a8e57a Add docs for the distributed namespace (#1184) 2024-06-06 11:37:00 -07:00
Awni Hannun
578842954c fix jit scan when output doesn't have primitive (#1190) 2024-06-06 07:24:58 -07:00
Awni Hannun
496315fe1d Fix scan (#1188)
* fix scan

* improve grid size

* fix cpu cummax
2024-06-05 14:21:58 -07:00
Angelos Katharopoulos
0fe6895893 Fix the hard-shrink test (#1185) 2024-06-04 16:22:56 -07:00
Nikhil Mehta
0b7d71fd2f Add softmin, hardshrink, hardtanh (#1180)
---------

Co-authored-by: Nikhil Mehta <nikmehta@tesla.com>
2024-06-04 15:48:18 -07:00
Awni Hannun
83b11bc58d Fix Metal API validation for empty concat (#1183) 2024-06-04 13:17:08 -07:00
Alex Barron
375a8bbdcc Add some internal GPU apis (#1177)
* Add unary/binary/ternay/slice/concat internal GPU ops

* add pad internal op

* formatting + no_cpu fix
2024-06-04 09:24:26 -07:00
Awni Hannun
ea9090bbc4 Add view op (#1179)
* add view primitive

* nit

* fix view
2024-06-04 08:05:27 -07:00
nicolov
81def6ac76 Fix benchmark (#1175) 2024-06-04 07:50:46 -07:00
Angelos Katharopoulos
3de8ce3f3c In place all-reduce and forgiving init (#1178) 2024-06-03 16:47:47 -07:00
Alex Barron
4d485fca24 Add defines include (#1176)
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-03 09:50:10 -07:00
Brian Keene
1865299a30 Metal shaders for memory efficient self attention on large sequences (#964)
* Metal shaders for efficient self attention on large sequences

Updated fast attention: GEMM-ified with Steel primitives
Uses flash attention 1 for scale correction

* more compiler silencing

* Address rebase issues

* Templatize kernel instantiation, revise cpu bindings

* Safer writes to output

* Permit batch size > 1

* Numerical fixes for sdpa self attention

* Re-enable test, remove unused variable

* add benchmarking script

* Disable sdpa prior to perf tuning, and simplify tests for per-patch CI
2024-06-03 09:16:19 -07:00
Dominik Schlösser
3576b547c5 Doc error for default for scale in SinusoidalPositionalEncoding (#1174) 2024-06-02 13:42:45 -07:00
Awni Hannun
079882495d version bump (#1172) 2024-05-31 12:29:12 -07:00
K Venkat Ramnan
ab977109db feat: Added dlpack device (#1165)
* feat: Added dlpack device

* feat: Added device_id to dlpack device

* feat: Added device_id to dlpack device

* doc: updated conversion docs

* doc: updated numpy.rst dlpack information

* doc: updated numpy.rst dlpack information

* Update docs/src/usage/numpy.rst

* Update docs/src/usage/numpy.rst

---------

Co-authored-by: Venkat Ramnan Kalyanakumar <venkatramnankalyanakumar@Venkats-MacBook-Air.local>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-05-31 12:29:01 -07:00
Awni Hannun
fd1c08137b stable cumprod grad at 0 (#1167) 2024-05-31 12:28:42 -07:00
Jagrit Digani
76b6cece46 Fix multi-block sort stride management (#1169)
* Fix multi-block sort stride management

* Add seed to tests
2024-05-31 11:10:54 -07:00
Jagrit Digani
9f0df51f8d Fix matvec vector stride bug (#1168) 2024-05-29 12:18:28 -07:00
Awni Hannun
e7a2a3dcd1 Fix a couple bugs (#1161)
* fix jit reduce for RMS norm

* make strides a single buffer

* better eval error message

* fix compiling with inf and bf16

* fix cpu compile with bf16
2024-05-28 15:18:18 -07:00
Awni Hannun
a87ef5bfc1 fix broadcast bug in bitwise ops (#1157) 2024-05-24 11:44:40 -07:00
Awni Hannun
9f9cb7a2ef version bump (#1154) 2024-05-23 18:08:08 -07:00
Awni Hannun
7e26fd8032 Option to JIT steel gemm / conv (#1139) 2024-05-23 18:07:34 -07:00
Jagrit Digani
eab2685c67 Float mask update (#1152)
* Float mask update

* Update CPU impl
2024-05-23 17:20:44 -07:00
Angelos Katharopoulos
50dfb664db Comms (#1097)
* Start the communications branch using MPI
* Add ops and primitives
* Add python bindings for distributed
2024-05-23 17:04:02 -07:00
Awni Hannun
0189ab6ab6 More jitting (#1132)
* docs + circle min size build

* jit scan, arange, softmax

* add sort

* jit reductions

* remove print

* fix deps

* clean includes / nits
2024-05-23 16:23:44 -07:00
Rifur13
9401507336 Add groups to 2-D convolutions (#1129)
* Added groups to 2-D convolutions. Only implemented for **some** specializations.

Also fixed 1D grouped convs with different kernel strides and added more tests.

* fix channels condition
2024-05-22 20:01:44 -07:00
Awni Hannun
eb8321d863 list based indexing (#1150) 2024-05-22 15:52:05 -07:00
Abe Leininger
79ef49b2c2 add mx.trace (#1143) (#1147)
* working c++ trace implementation

* updated throw + added overloads

* added python binding for trace function

* pre-commit reformatting

* add trace to docs

* resolve comments

* remove to_stream call
2024-05-22 15:50:27 -07:00
Awni Hannun
e110ca11e2 Fix offset bug for device buffers (#1151)
* fix bug with large offsets for buffers

* add a test

* remove test as its too big for small machine
2024-05-22 15:50:05 -07:00
Awni Hannun
226748b3e7 JIT compile option for binary minimization (#1091)
* try cpp 20 for compile

* unary, binary, ternary in jit

* nits

* fix gather/scatter

* fix rebase

* reorg compile

* add ternary to compile

* jit copy

* jit compile flag

* fix build

* use linked function for ternary

* some nits

* docs + circle min size build

* docs + circle min size build

* fix extension

* fix no cpu build

* improve includes
2024-05-22 12:57:13 -07:00
Awni Hannun
d568c7ee36 Rename block sparse (#1149)
* block_sparse_mm to gather_mm

* rename

* nit

* nit
2024-05-22 07:48:34 -07:00
Awni Hannun
e6fecbb3e1 Some fixes in docs (#1141)
* fixes in docs

* nit
2024-05-20 11:51:47 -07:00
Angelos Katharopoulos
da83f899bb Improve qvm speed (#1140) 2024-05-20 09:20:44 -07:00
jlwitthuhn
7e5674d8be Treate 'minimum' differently in cosine decay (#1138) 2024-05-20 08:00:48 -07:00
Shixian Sheng
0a558577bf Update README.md (#1136) 2024-05-20 06:16:40 -07:00
Awni Hannun
fb71a82ada Fix copy bug with many dims (#1137) 2024-05-17 21:10:03 -07:00
Awni Hannun
23406c9e9e Choose the right MLX bf16 for extensions (#1135)
* default to custom bf

* choose right bf

* fix extensions

* fix circle conf
2024-05-17 15:09:28 -07:00
Luca Arnaboldi
b3ec792380 Implemented Cholesky on CPU (#1119) 2024-05-17 12:31:59 -07:00
Awni Hannun
6a9b584f3d patch bump (#1131) 2024-05-16 20:51:33 -07:00
Awni Hannun
81dd33af66 allow conversion to dlpack (#1120) 2024-05-16 16:11:37 -07:00
Awni Hannun
8b76571896 Fix extensions (#1126)
* fix extensions

* title

* enable circle

* fix nanobind tag

* fix bug in doc

* try to fix config

* typo
2024-05-16 15:36:25 -07:00
Angelos Katharopoulos
e78a6518fa Block sparse qmm (#1124) 2024-05-16 15:24:14 -07:00
Awni Hannun
1873ffda01 Detect metal version and propagate correctly for JIT (#1109)
* detect metal version and propagate correctly for JIT

* remove softmax

* fix versions
2024-05-15 17:42:09 -07:00
Jacket
c417e42116 [Fix] minor typo in default argument for argpartition's "axis" parameter (#1125)
According to the document, argpartition's axis parameter can be None, but due to a previous typo it can't really accepts a None value.
2024-05-15 15:25:25 -07:00
Jagrit Digani
358e1fd6ab Fused GEMM (#1123)
* Basic gemm working

* Update addmm

* Clear out steel_gemm and steel_addmm kernels

* Fuse and clear out gather gemm

* Update objc releases
2024-05-15 10:30:41 -07:00
Awni Hannun
631dfbe673 fix scatter index bug (#1122) 2024-05-14 15:04:58 -07:00
Cheng
56a4eaed72 Pass missing stream arg in array.flatten (#1111) 2024-05-14 06:50:16 -07:00
Cheng
bf925d9dc7 Move args in conv_general (#1118)
Also fix a typo that padding_lo is passed as padding_hi.
2024-05-14 06:50:09 -07:00
Cheng
1a7ed5dcb6 Fill vector with constructor instead of fill_n (#1113) 2024-05-14 06:28:55 -07:00
Cheng
5be5daa6ef Use compiled function in Sigmoid module (#1116) 2024-05-14 06:25:57 -07:00
Cheng
60cb11764e Use correct module type in quantized.py (#1115) 2024-05-14 06:25:42 -07:00
Cheng
cbd5445ea7 The tile op does not accept None as reps (#1117) 2024-05-14 06:25:25 -07:00
Cheng
2c7e9b5158 Add missing docs for some ops (#1110) 2024-05-14 06:09:05 -07:00
Mike Drob
2263e4b279 Experiment with medium machines for CI (#1000) 2024-05-13 19:40:19 -07:00
Awni Hannun
863039da4c Allow scatter type exception to be caught by checking in op (#1077)
* allow exception to be caught in main thread

* only for gpu

* more detailed scatter error
2024-05-13 17:43:53 -07:00
Awni Hannun
7178ac0111 No CPU option for binary minimization (#1105)
* no cpu build option

* docs

* fix
2024-05-13 16:08:11 -07:00
Ravindra R. Jaju
e7f9710499 Fix typo in a variable name in example code. (#1104)
* Fix typo in a variable name in example code.

* Rename df2dx2 to d2fdx2 - the appropriate naming for the second derivative

* Update CONTRIBUTING.md - add needed python packages, and a virtual-env hint

* Revert "Fix typo in a variable name in example code."

This reverts commit bc10a17534.

* Rename df2dx2 to d2fdx2
2024-05-13 06:04:23 -07:00
Max-Heinrich Laves
ff4223904d Conv3d (#993)
* added conv3d

added conv3d

implemented explicit_gemm_conv_ND_cpu and bounds checks for slow_conv_3D

* incorporated reviewer comments

* fixed test

* reduced tensor shapes in test for conv3d

* Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion
2024-05-11 06:15:02 -07:00
Awni Hannun
a9f80d60f6 improve error messaging in eval (#1101) 2024-05-10 10:04:07 -07:00
Alex Barron
2e158cf6d0 Add conjugate operator (#1100)
* cpu and gpu impl

* add mx.conj and array.conj()

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-05-10 07:22:20 -07:00
Awni Hannun
8bd6bfa4b5 version (#1099) 2024-05-09 17:52:39 -07:00
Awni Hannun
8b1906abd0 Add compiler flags to disable safetensors and gguf (#1098)
* with docs

* nit
2024-05-09 17:39:44 -07:00
Awni Hannun
06375e6605 Split encoders in non-concurrent context with a max ops per encoder (#1085)
* split encoders

* fix race
2024-05-09 16:21:02 -07:00
Awni Hannun
b21242faf1 Allow unary ops to accept array like (#1093) 2024-05-09 09:36:02 -07:00
Rahul Yedida
cc05a281c4 Added ArcTan2 operation (#1079)
* Added ArcTan2 operation

* Cleanup, bug fixes from code review

* Minor cleanup, fixed Linux tests
2024-05-08 08:35:15 -07:00
Jagrit Digani
fe96ceee66 Update block offset adjustment to be in size_t (#1087) 2024-05-08 08:10:23 -07:00
Awni Hannun
9814a2ae12 fix conversion to array (#1070) 2024-05-06 16:02:49 -07:00
Shubham
6992498e7a add keyword positonal (#1081) 2024-05-06 07:18:49 -07:00
Awni Hannun
21623156a3 Reset peak memory (#1074)
* reset peak memory

* fix linux

* nits in docs
2024-05-03 17:12:51 -07:00
Nripesh Niketan
79c859e2e0 feat: implement clip_grad_norm (#1043)
* feat: implement `clip_grad_norm`

* pre-commit

* Add test for clip_grad_norm function in test_optimizers.py

* small fixes

* fix

* lint

* Update tree_reduce

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Refactor clip_grad_norm function to include documentation and improve readability

* format docstring

* Add acknowlegements

* text wrap

* pre-commit

* nits in docs

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-05-03 09:07:02 -07:00
Awni Hannun
b00ac960b4 change initial memory limits and add memory size to device info (#1064) 2024-05-03 06:50:15 -07:00
Awni Hannun
02a9fc7bfa Patch bump (#1067)
* version

* use 0.12.2
2024-05-02 16:37:31 -07:00
Jagrit Digani
f390957685 Block sparse mm (#1058) 2024-05-02 14:03:58 -07:00
Angelos Katharopoulos
17f57df797 Improvements in the quantizer and dequantization kernel (#1061) 2024-05-01 18:19:11 -07:00
Awni Hannun
7f7b9662ea Fix leak for multi-output primitives which are never detached (#1059)
* fix multi output leak

* ignore arrays that will be detached

* add some comments

* stray print
2024-05-01 07:31:45 -07:00
Awni Hannun
19bef39f5c Add a mx.metal.device_info (#1060)
* device inof

* add variant

* fix linux

* fix doc
2024-04-30 15:47:27 -07:00
Nripesh Niketan
a30e7ed2da feat: metal formatting and pre-commit bump (#1038)
* feat: metal formatting and pre-commit bump

* add guards

* update

* more guards

* more guards

* smakk fix

* Refactor instantiation of ternary types in ternary.metal

* fix scan.metal
2024-04-30 07:18:09 -07:00
Angelos Katharopoulos
8db7161c94 Bug fix in quantize (#1054) 2024-04-29 20:55:04 -07:00
Awni Hannun
09f1777896 fix slice update indexing (#1053) 2024-04-29 12:17:40 -07:00
Jacket
490c0c4fdc [Fix] expand axes for dimension with integer indices in mlx_slice_update (#1035)
* Not sure if this is correct

* Format

* Edit tests

* Add negative test

* Format

* add one more test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-29 07:57:28 -07:00
Rifur13
c4a471c99d Add groups to Conv1d (#948)
* Add conv1d grouped convs on CPU

* Add GPU support

* Parallelize inside metal kernel

* clenaup

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* New unfold kernel + remove unused code

* Remove copy and refactor

* Update vjp and reuse steel gemm

* Fixed groups on cpu

* Fix metal validation

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-04-27 06:24:57 -07:00
Awni Hannun
86f495985b Add bitwise ops (#1037)
* bitwise ops

* fix tests
2024-04-26 22:03:42 -07:00
Awni Hannun
67d1894759 fix order device -> scheduler (#1039) 2024-04-26 13:46:41 -07:00
Awni Hannun
5bfe89bdb1 Cpp docs (#1036)
* start of C++ docs

* fix stream doc

* only include ops for now
2024-04-26 12:56:05 -07:00
Angelos Katharopoulos
82463e9938 Bump the version to 0.12 (#1034) 2024-04-25 14:18:08 -07:00
Awni Hannun
771575d27b Expose function to clear memory cache (#1032)
* expose function to clear memory cache

* fix linux build

* fix metal tests
2024-04-24 16:48:51 -07:00
Angelos Katharopoulos
20a01bbd9f Simplifying and improving qmm (#1030) 2024-04-24 13:07:45 -07:00
Angelos Katharopoulos
ec8578d41a Fix quantization of all 0s (#1028) 2024-04-24 00:40:42 -07:00
Aneesh Shetty
d0dbfe0b97 Adds radians and degrees (#1011) 2024-04-22 11:17:49 -07:00
Awni Hannun
3d405fb3b1 Add synchronize function (#1006)
* add synchronize function

* fix linux

* fix linux

* fix and fix docs

* fix test

* try synchronize in stream destroy

* synchronize works for both cpu and gpu
2024-04-22 08:25:46 -07:00
Angelos Katharopoulos
b0012cdd0f Bump the patch version for the quants (#1018) 2024-04-19 20:28:34 -07:00
Angelos Katharopoulos
84d61d27aa Make sure 0 is represented in the quantization (#1016) 2024-04-19 19:47:26 -07:00
Awni Hannun
ed83908931 fix gguf loading quants (#1014)
* fix gguf loading quants

* fix nanobind install

* actual fix
2024-04-19 12:24:07 -07:00
Angelos Katharopoulos
ef5f7d1aea Fix buffer protocol buffer size designation (#1010) 2024-04-19 06:06:13 -07:00
Awni Hannun
090ff659dc bump (#1007) 2024-04-18 13:18:43 -07:00
Jagrit Digani
85c8a91a27 Fix mask broadcasting bug and add relevant test (#1003) 2024-04-17 17:33:48 -07:00
Piotr Rybiec
581b699ac9 avgpool, not maxpool (#1002) 2024-04-17 08:26:22 -07:00
Awni Hannun
8a0677d56d Shared events for synchronization + async eval (#998)
* more async eval

* fix rebase

* try correct async eval

* fix async

* more tests for async eval

* use shared events for synchronization

* comment + cleanup

* with autorelease pool

* fix no metal build

* fix compile

* fix patch

* don't eval if asyn evale'd

* don't use is_evaled

* comments

* more multi stream tests

* try and cleanup use of is_evaled

* use a status flag
2024-04-17 06:16:02 -07:00
Jagrit Digani
b18468bf81 Masked mm (#978)
* Add block masked matmul op and primitive
2024-04-16 14:45:39 -07:00
Shiyu
107ba2891a gelu tanh approx (#989)
* gelu tanh approx

* gelu tanh approx

* replace gelu approx with tanh approach

* fix comments

* fix comment
2024-04-15 19:49:00 -07:00
Awni Hannun
cd9e184529 Quantize embedding (#994)
* quantize embedding

* rename as_linear + comment

* consistency in docs

* fix test
2024-04-15 16:42:10 -07:00
Alex Barron
2e7c02d5cd Metal FFT for powers of 2 up to 2048 (#915)
* add Metal FFT for powers of 2

* skip GPU test on linux

* fix contiguity bug

* address comments

* Update mlx/backend/metal/fft.cpp

* Update mlx/backend/metal/fft.cpp

* fix bug in synch

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-11 21:40:06 -07:00
Awni Hannun
ae18326533 No copy command encoder (#986)
* no copy command encoder

* up layer norm test tolerances
2024-04-11 21:15:36 -07:00
Alex Shepard
91eba8e485 fix for grammatical typo in docs (#988)
thanks for mlx!
2024-04-11 17:02:06 -07:00
Awni Hannun
d07e295c62 bumpity bump (#987) 2024-04-11 12:48:52 -07:00
Angelos Katharopoulos
dce4bd74a4 Add ArrayDesc destructor to avoid possible stack overflow (#982) 2024-04-11 11:37:02 -07:00
Nripesh Niketan
ffff671273 Update pre-commit hooks (#984) 2024-04-11 07:27:53 -07:00
Awni Hannun
12d4507ee3 Explicit barriers with concurrent dispatch (#977) 2024-04-10 21:45:31 -07:00
Awni Hannun
8580d997ff Try a stack-based DFS for eval (#980)
* rebase

* nit

* fix eval in vmap
2024-04-10 17:05:13 -07:00
Shiyu
061cf9a4ce Upsample with bicubic interpolation (#967) 2024-04-10 15:47:22 -07:00
Awni Hannun
99abb9eff4 Async eval (#972) 2024-04-09 18:34:00 -07:00
Luca Arnaboldi
fffe072028 Implementation of mlx.random.multivariate_normal (#502) (#877)
* Implementation of mlx.random.multivariate_normal (#502)

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Updated typo in docstring

* Restricted multivariate_normal to  float32

* Generic mean and variance shapes

* Review edits

* Update mlx/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Test for ndim of mean and cov

* nits

* smaller size for test

* fix broadcasted sampling

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-09 13:50:12 -07:00
Abe Leininger
a1a31eed27 Add mx.meshgrid (#961) 2024-04-09 11:43:08 -07:00
Awni Hannun
ae812350f9 use string (#976) 2024-04-09 11:22:00 -07:00
Awni Hannun
b63ef10a7f Extensions (#962)
* start to fix extensions

* mostly fixed extensions

* fix extension build

* couple more nits
2024-04-09 08:50:36 -07:00
Awni Hannun
42afe27e12 std and expm1 (#973)
* std and expm1

* actually add expm1

* fix linux

* fix vjp

* relax tol for linux test

* Add it to the compilable primitives

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-04-08 14:26:01 -07:00
Awni Hannun
76e63212ff Enable bfloat scan (#974)
* enable bfloat scan
* fix tests
2024-04-08 12:29:19 -07:00
Awni Hannun
aac2f9fb61 Improve profiling with gpu tracing (#969)
* improve profiling with gpu tracing

* fix for linux

* nit

* doc fix

* fix example
2024-04-07 21:47:43 -07:00
Awni Hannun
bddf23f175 patch bump (#956) 2024-04-04 11:56:37 -07:00
Awni Hannun
039da779d1 No quant reshape (#957)
* precise option on cpu

* remove print

* remove reshape in quant matmul

* no quant reshape
2024-04-04 11:52:12 -07:00
Awni Hannun
d88d2124b5 segfaut layer norm grad (#955) 2024-04-04 10:59:15 -07:00
Awni Hannun
e142aaf8a1 Option for precise softmax (#953)
* precise softmax

* Add an equivalency check

* Make the threadgroup memory definition fixed

* precise cpu softmax

* precise option on cpu

* remove print

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-04-04 08:32:35 -07:00
AmirHossein_Razlighi
0caf35f4b8 Better exceptions in case of invalid operations on mlx.core.array (#910) (#926)
* Nicer exceptions for ops on non-arrays
2024-04-02 21:11:24 -07:00
Angelos Katharopoulos
3fc993f82d Properly handle negative axes in python vmap (#944) 2024-04-02 18:07:23 -07:00
Awni Hannun
741eb28443 fix a couple bugs (#952) 2024-04-02 12:07:41 -07:00
Angelos Katharopoulos
1a87dc5ea8 Fix compile fusion for multi-output edge cases (#950)
* Fix compile fusion for multi-output edge cases

* Add a test for multi-output compile
2024-04-02 08:42:31 -07:00
Awni Hannun
2427fa171e Fix cpu compile (#934)
* fix one cpu bug, test for another

* format hooks

* simplify contiguity check for cpu compile

* fix

* add back donation

* comment
2024-04-01 17:37:12 -07:00
Jagrit Digani
639e06e1f3 Indexing bug fix (#947)
* Fix axes accounting

* Add tests
2024-04-01 12:18:50 -07:00
Angelos Katharopoulos
02fedbf1da Fix array initialization from list (#942)
* Fix array initialization from list

* Change the error message in the test
2024-04-01 06:27:52 -07:00
Angelos Katharopoulos
110d9b149d Layer norm grad fix donation bug (#941)
* add layer norm grad test

* Fix donation bug in layernorm vjp

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-01 06:15:50 -07:00
Angelos Katharopoulos
9cbff5ec1d Fix typo in qmm check (#940) 2024-03-31 19:15:44 -07:00
Suvan Kumar
433c0206b0 Update saving_and_loading.rst (#929)
Update saving / load docs.
2024-03-30 14:30:06 -07:00
Awni Hannun
8915901966 Donation bug (#933)
* donation

* buf

* fix bug in softmax

* comment

* remove print
2024-03-30 10:08:54 -07:00
AmirHossein_Razlighi
f48bc496c7 Comparing python objects (such as list/tuple) with mlx.core.array (#920)
* add implicit conversion of list to array for equality constraint

* add tests for array equality

* add test for tuple and array equality

* return False if __eq__ arg is list or tuple

* write tests for equality

* update the rule of comparison for __ge__/__gt__/__lt__/__le__

* add a helper function for detecting mlx.core.array

* return true in case fo inequality

* debug minor issue regarding detecting mlx array

* add tests for inequality comparisons

* add name for contribution

* reformat files using pre-commit

* update tests for float

* update tests for inequality

* raise exception in case of invalid comparisons

* use isinstance instead of string comparison

* replace "is_convirtable_to_array" with previous logic

* remove throwing exceptions for other operations

* just a comment

* minor changes for efficiency

* optimize a utils function

* change the function name

* Update ACKNOWLEDGMENTS.md

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-29 06:52:30 -07:00
Cheng
913b19329c Add missing && when forwarding args (#925)
Without the && args would be copied and perfect forwarding won't work.
2024-03-29 06:48:29 -07:00
Awni Hannun
d8cb3128f6 bump (#924)
* bump

* fix version
2024-03-28 16:14:55 -07:00
Angelos Katharopoulos
5f9ba3019f Fix qmm_t for unaligned cases (#923) 2024-03-28 15:34:57 -07:00
Cheng
46caf0bef0 Remove unnecessary string copies (#891)
1. Use string_view instead of string when there is no need for copy.
2. Otherwise move string when possible.
2024-03-28 13:14:59 -07:00
Jack Mousseau
45f636e759 Add Metal debug option and capture functions (#707)
* Add Metal debug option and capture functions

* Add brief Metal debugger documentation

* doc nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-28 09:40:31 -07:00
Cheng
a7b404ff53 Use uintptr_t instead of size_t to store funtion id (#916)
Also does some small cleanup of the compile cache code.
2024-03-28 06:37:59 -07:00
Angelos Katharopoulos
c4fd0e5ede Fixes #918 bug in compile_tests (#919) 2024-03-27 22:37:37 -07:00
Cheng
bab5386306 Make ops aware of rvalues: astype/as_strided/copy/full (#895)
When compositing transforms lots of temporary of arrays will be created
and passed to next primitive, and by making ops accepting args by value
we can avoid lots of copies of temporary arrays.
2024-03-27 22:35:55 -07:00
Angelos Katharopoulos
aca7584635 Fix OOB read in qmv when non-divisible by blocksize (#917) 2024-03-27 22:18:35 -07:00
AmirHossein_Razlighi
d611251502 Support Chaining for some of functionalities of nn.Module (#885) (#897)
* add chaining support for some of the functionalities of "nn.Module"

* reformat

* change the return types

* remove return types

* add return type with forward referencing

* add tests for chaining

* add name to contributors

* Update python/mlx/nn/layers/base.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/nn/layers/base.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* update docstring

* update docstrings

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-27 19:58:29 -07:00
Cheng
f30b659291 Make MLX build on x64 macOS (#901)
The arm64 macbook pros are heavy and I usually care my intel one for
mobile, it would be nice if I can play with MLX on it.

To build with x64, user must pass `MLX_ENABLE_X64_MAC` to cmake:
CMAKE_ARGS='-DMLX_ENABLE_X64_MAC=ON' python setup.py
2024-03-27 06:14:29 -07:00
Cheng
90dfa43ff1 Don't use make_unique to create shared_ptr (#902)
The code compiled because shared_ptr's constructor actually accepts
unique_ptr.
2024-03-27 06:13:29 -07:00
Awni Hannun
dc175f08d3 Fix race in multi-stream eval (#911)
* maybe fix race

* comment
2024-03-26 16:36:36 -07:00
Angelos Katharopoulos
29221fa238 Implement vjps for some primitives in the fast namespace (#883)
* Implement rope vjp in terms of rope
* RMSNormVJP primitive and kernel
* Add LayerNormVJP primitive and kernel
2024-03-26 16:35:34 -07:00
Cheng
a789685c63 Remove duplicate defines of StreamOrDevice and is_big_endian (#892) 2024-03-26 15:15:11 -07:00
Jagrit Digani
240d10699c Implement negative padding in conv with slicing (#907)
* Implement negative padding with slicing

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni@apple.com>

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-26 14:59:19 -07:00
Jagrit Digani
925014b661 Fix multiblock sort limits (#906)
* Fix multiblock sort limits

* Fix metal validation error
2024-03-26 14:00:00 -07:00
Abdussamet Türker
5611e1a95e Fix unsqueeze with None (#899)
* Fix unsqueeze with None

* Clean unnecessary files
2024-03-26 13:59:44 -07:00
Awni Hannun
570f2bf29e pick up preivously set attributes (#905) 2024-03-26 11:19:59 -07:00
Angelos Katharopoulos
9948eddf11 Fix nan and improve speed for qvm (#903) 2024-03-26 10:41:45 -07:00
Luca Arnaboldi
a3ee03da01 Fixing random.normal for half-precision dtype #642 (#904)
* Fixing random.normal for half-precision dtype #642

* Update python/tests/test_random.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-26 09:58:27 -07:00
Cheng
28fcd2b519 Add missing && when forwarding args (#894)
Without the && args would be copied and perfect forwarding won't work.

Also add template utils to make sure the function only forwards array
and not vector<array>.
2024-03-25 14:55:54 -07:00
Jack Mousseau
8e686764ac Ensure shape dimensions are within supported integer range (#566) (#704)
* Ensure shape dimensions are within supported integer range (#566)

* fix build

* fix rebase bug

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-25 13:29:45 -07:00
Daniel Strobusch
479051ce1c add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)
* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate

numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)).

Closes #285.

* nits in docs

* unify type category checking

* nits in docs

* nits in docs

* more docs nits

* fix callable type

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-25 12:32:59 -07:00
Awni Hannun
bfb5bad4f0 patch (#893) 2024-03-24 21:03:59 -07:00
Awni Hannun
1e16331d9c post nanobind docs fixes and some updates (#889)
* post nanobind docs fixes and some updates

* one more doc nit

* fix for stubs and latex
2024-03-24 15:03:27 -07:00
Awni Hannun
be98f4ab6b Reduce a little overhead (#871)
* some small overhead improvements

* use result_type in rms_norm

* remove release force

* fix + use non-vector version

* revert compile change

* fix ops

* a little more overhead

* a little more cleanup and overhead
2024-03-22 17:29:36 -07:00
Angelos Katharopoulos
6ee1112f30 Fix copy donation and add partial rope (#881) 2024-03-22 17:28:26 -07:00
Jagrit Digani
8e5a5a1ccd Set item bug fix (#879)
* set item shaping bug fix

* Add extra tests
2024-03-22 12:11:17 -07:00
Angelos Katharopoulos
fcda3a0e66 Increase test tolerance for fast.layer_norm (#880) 2024-03-22 12:10:27 -07:00
Cheng
9663c22fe9 Do not store iostream in shared_ptr (#872)
There is no need to store iostream in shared_ptr, doing so adds the cost
of a heap allocation.
2024-03-22 06:54:45 -07:00
Cheng
f0ae00da12 Reduce implicit copies in make_array (#874)
1. Move shapes into outputs instead of copying them.
2. Pass primitive by const ref as it is always copied into outputs, which
   removes a copy when calling make_array.
2024-03-22 06:29:16 -07:00
Awni Hannun
44390bd3d0 Bump (#869)
* bump

* fix none in a few ops
2024-03-21 13:56:56 -07:00
Angelos Katharopoulos
2225374060 Adds mx.fast.layer_norm (#870) 2024-03-21 13:55:51 -07:00
nicolov
105d236889 Add vmap for SVD and inverse (#849) 2024-03-21 13:18:27 -07:00
Angelos Katharopoulos
53e6a9367c Use reshape and transpose for non-overlapping pooling windows (#867) 2024-03-21 10:21:03 -07:00
Chime Ogbuji
f5a1582fe8 Add minimum for cosine decay function (#859)
* Add minimum for cosine decay function

* Update python/mlx/optimizers/schedulers.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-21 07:33:29 -07:00
Awni Hannun
a54f06b16f Fast RMS Norm (#862)
* fast rmsnorm

* no rms gpu

* kernel

* fix shared mem

* looped rms and donation in softmax

* Make the squaring in float32 to avoid underflow

* Fix the default StreamOrDevice for rope and rms_norm in fast

* nits

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-21 07:20:54 -07:00
Cheng
4650d94d98 Add missing && in eval (#864)
Without the && args would be copied and perfect forwarding won't work.

To avoid eval calling itself recursively, the vector version of eval is
changed to take by value instead, which will save a copy of array when a
rvalue is passed.
2024-03-21 06:15:48 -07:00
Jagrit Digani
a5681ebc52 Update set item (#861)
* Update mlx_set_item to handle regular slices without expanding

* Refactor ellipsis handling

* Route mlx_set_item to slice_update where possible

* Update mlx_scatter_args_slice

* Don't route to gather if no array indices
2024-03-21 02:48:13 -07:00
Cheng
e849b3424a Do not use static constexpr in header (#863)
Doing so results in each compilation unit (.cpp file) having its own
copy of the variable, while inline constexpr makes sure there is only
one copy.
2024-03-20 21:28:05 -07:00
Jagrit Digani
b219d12a6b Check edge case handling in row reduce med kernel (#858) 2024-03-20 11:37:58 -07:00
Jagrit Digani
cec8661113 Add a SliceUpdate op and primitive (#850)
* Enable copy to work with int64 strides
* Fix uniform buffer indices or copy kernel arguments
* Update utils.h
* Remove manual unrolling of elem to loc loop
* GPU copy updated to handle negative strides
* Add slice update primitive
2024-03-20 10:39:25 -07:00
Cheng
73a8c090e0 Pass shape and inputs by value in array's constructor (#853)
Since the shape and inputs are always saved as copy in ArrayDesc, we can
unify array's constructors to just take the arguments by value.

There are 2 cases:
1. When shape is a lvalue, it will be copied into array's constructor and
   then moved into ArrayDesc's member. So only 1 copy happens.
2. When shape is a rvalue, it will be moved into array's constructor and
   then moved into ArrayDesc's member. So no copy happens.

So having 1 constructor that takes by value is equivalent to having 2
constructors that const reference and rvalue separately.
2024-03-20 07:54:30 -07:00
Md. Rasel Mandol
db6796ac61 simple typo fille (#848) 2024-03-19 06:15:17 -07:00
Awni Hannun
9a8ee00246 Switch to nanobind (#839)
* mostly builds

* most tests pass

* fix circle build

* add back buffer protocol

* includes

* fix for py38

* limit to cpu device

* include

* fix stubs

* move signatures for docs

* stubgen + docs fix

* doc for compiled function, comments
2024-03-18 20:12:25 -07:00
Cheng
d39ed54f8e Some C++ code are not needed (#841)
1. Anonymous namespace means internal linkage, static keyword is not needed.
2. The default constructor of std::shared_ptr initializes the pointer to
   nullptr, you don't need to explicitly set it.
2024-03-18 17:04:10 -07:00
Awni Hannun
16546c70d8 No reshape rope (#838)
* no reshape rope

* no reshape rope
2024-03-18 17:03:07 -07:00
nicolov
eaba55c9bf Add matrix inversion primitive (#822) 2024-03-15 06:34:36 -07:00
Awni Hannun
19ec023256 vmap matmul and admm (#836) 2024-03-14 14:38:22 -07:00
Awni Hannun
63ab0ab580 version (#835) 2024-03-14 12:20:40 -07:00
Jagrit Digani
8dfc376c00 Strided reduce specialization for small reductions (#826)
* Add small column / general reduction specialization
2024-03-14 09:16:53 -07:00
Angelos Katharopoulos
1efee9db09 Add types and order in kernel name (#831) 2024-03-13 20:34:06 -07:00
Awni Hannun
43abc402d8 route to fallback (#828) 2024-03-13 19:56:04 -07:00
Angelos Katharopoulos
3f8b1668c4 Make reshape faster for row_contiguous cases (#829) 2024-03-13 16:22:03 -07:00
Angelos Katharopoulos
76c919b4ec NumberOfElements for shapeless compile and vmap fixes (#802) 2024-03-13 10:34:14 -07:00
Angelos Katharopoulos
29d0c10ee5 Reshape improvement (#818) 2024-03-12 17:54:31 -07:00
Jagrit Digani
5ad133f8bb No copy gems (#801)
* Enable collapsing batch dims in gemm
* Update gemm to only make copies when neither of the last 2 axes are contiguous
* Update addmm to support gemv shapes
* Update addmm to support irregular batch strides
* Update tests
2024-03-12 13:13:41 -07:00
nicolov
d0c544a868 Add SVD primitive (#809)
Add SVD op using Accelerate's LAPACK following
https://developer.apple.com/documentation/accelerate/
compressing_an_image_using_linear_algebra

Co-authored-by: Nicolo Valigi <nvaligi@apple.com>
2024-03-12 12:30:11 -07:00
Daniel Falbel
ffb19df3c0 Fix docstring for correctly rendering (#820) 2024-03-12 11:46:44 -07:00
Awni Hannun
8b7532b9ab fix scatter (#821) 2024-03-12 11:42:07 -07:00
Awni Hannun
366478c560 fix modules with dict (#819) 2024-03-12 08:54:06 -07:00
Justin Deschenaux
8e5600022a Implement RNN, GRU, LSTM (#268)
* RNN base implementation

* Address comments+format

* nits in docs

* add tests for prb

* fix test

* add a couple tests

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-11 21:14:44 -07:00
Awni Hannun
0e95b64942 Fix bug in tape order during simplify (#816)
* fix bug in tape order during simplify

* properly fix compile

* last bug
2024-03-11 17:29:05 -07:00
nicolov
0ae22b915b Remove code duplication in reduce ops (#793)
* Remove code duplication in reduce ops

* Remove the unnecessary lambda

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-11 10:57:07 -07:00
Awni Hannun
7c441600fe Compile stride bug (#812)
* fix compile stride bug

* revert sdpa fix

* fix cpu

* fix bug with simplifying outputs
2024-03-11 06:31:31 -07:00
Awni Hannun
a4d290adb9 Remove depth traversal (#813)
* no depth traversal

* counter outside loop
2024-03-09 20:21:32 -08:00
Awni Hannun
28301807c2 Version bump and os error (#807) 2024-03-07 13:57:58 -08:00
Awni Hannun
74ed0974b3 Support 13.0+ with xcode 14.3 (#806)
* Support 13.0+ with xcode 14.3

* revert revert
2024-03-07 13:27:57 -08:00
Jagrit Digani
ec8a4864fa Fix SDPA kernel bug on Mac OS 13.3 SDK (#805)
* Move sdpa kernel to allocate tgp mem statically and allow macOS 13.3 SDK builds

* Style
2024-03-07 10:18:09 -08:00
Awni Hannun
b7588fd5d7 fix inplace to not make a shallow copy (#804) 2024-03-07 09:34:11 -08:00
Awni Hannun
f512b905c7 Minimum xcode / sdk (#800)
* minimum xcode /sdk

* try multiple xcode versions in CI

* update python

* metal validation for python tests
2024-03-07 08:19:43 -08:00
Awni Hannun
afd5274049 route to fallback for bfloat (#794) 2024-03-06 15:39:12 -08:00
Awni Hannun
1074674e32 Add a maximum graph depth (#797)
* add a maximum graph depth

* remember how to use C++
2024-03-06 15:39:00 -08:00
AlexCheema
7762e07fde Update function_transforms.rst (#796)
Fix typo in function_transforms.rst
2024-03-06 12:03:37 -08:00
Luca Arnaboldi
cbefd9129e Implementation of pickle, copy and deepcopy for Python arrays (#300 & #367). (#713)
* Implemented pickling and copy for Python arrays(#300 & #367)

* Fixing typos

* Pickle with NumPy arrays

* Pickle: workaround for bfloat16

* Revert "Pickle: workaround for bfloat16"

This reverts commit 25afe6bc09.

* Added an error when pickling bfloat16

* Update python/tests/test_array.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/tests/test_array.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/array.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/array.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* clang-format applied

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-06 08:02:41 -08:00
Angelos Katharopoulos
e39bebe13e Fix reshaping of empty arrays (#791) 2024-03-05 23:33:22 -08:00
Angelos Katharopoulos
14b4e51a7c Improved quantized matrix vector product (#786) 2024-03-05 17:32:19 -08:00
Awni Hannun
cbcf44a4ca Some fixes in cache / thread safety (#777)
* some fixes in cache / thread safety

* speed up no cache case

* fix opt test

* optimizer docs

* otpimizer docs

* fix adafactor

* fix adafactor
2024-03-05 13:30:50 -08:00
Awni Hannun
859ae15a54 Fix test (#785) 2024-03-04 23:02:27 -08:00
Brian Keene
0787724c44 Fast Inference SDPA op (#735)
* Fast Inference SDPA op

Implements metal shaders for:

o = mx.fast_inference_sdpa(queries, keys, values, scale, mask)

Supports fp16, fp32 dtypes; assumes d_k = 128.

Generic op support / prompt encoding supported via mlx primitives.
Metal implementation is for the inference use case only.

Majority of performance benefits appears to results from GQA & reduced
bandwidth requirements; there is approximate performance parity for the
MHA use case (from some measurements on M3 Max).

* Flush shared memory to zero before unprotected reads for (scores @ values)

* Move to fast:: namespace, address reviewer comments

... also attempt to revert formatter auto-change for files not relevant
to this change

* Shared memory flush to top of kernel

* Resolve compiler warnings

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update docstring per PR feedback

* Softmax in higher precision, ...

* route to fallback for more use cases - batch size > 1, head_dim other
  than 128, etc.
* Address linux build failure
* Address other reviewer comments

* Remove extraneous eval_cpu function per review

---------

Co-authored-by: Atila Orhon <64497909+atiorh@users.noreply.github.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: atila <atiorh@icloud.com>
2024-03-04 21:06:11 -08:00
Awni Hannun
7b463ffb07 Ios compile (#784)
* try to fix build for ios

* skip cpu compile

* fix namespace

* fix namespace

* Use CMake for platform specific cpu compile

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-04 20:02:26 -08:00
Jagrit Digani
6686e61ca4 Reduce update (#783)
* Split reduction files to reduce compile times

* Add small and medium axis size specializations for row reductions

* Add non-row-reduction options for small and med kernels
2024-03-04 19:09:51 -08:00
Awni Hannun
c096a77b9b revision bump (#778) 2024-03-04 13:41:53 -08:00
Awni Hannun
5121f028d9 nice tensordot for mlx c (#782) 2024-03-04 09:51:02 -08:00
Piotr Rybiec
6a665ea6ed Dilation for convolutional layers (#766)
* add dilation parameter to Conv1d layer

* space here too

* add conv1d dilation test

* add dilation parameter for Conv2d layer

* conv2d dilation test
2024-03-04 06:43:00 -08:00
Awni Hannun
bc06cb9ff6 Pickle + dtype fix for numpy conversion (#763)
* pickle + dtype fix for numpy conversion

* fix getattribute on Module base

* remove unused function

* fix tests

* add topk to ops

* fix doc
2024-03-02 06:09:29 -08:00
Angelos Katharopoulos
8e281c76c3 Fix the top-k op (#768) 2024-03-01 22:08:43 -08:00
Awni Hannun
d5964a2710 bindings for memory info (#761)
* bindings for memory info

* update api

* keep cache low if requested

* fix default

* nit in ops error
2024-03-01 19:51:58 -08:00
Ikko Eltociear Ashimine
cf3eb87e52 Fix typo in transforms.cpp (#764)
occuring -> occurring
2024-02-29 22:23:46 -08:00
Awni Hannun
ab3a466711 bump (#760) 2024-02-29 11:58:54 -08:00
Awni Hannun
4494970f47 avoid nested closures in module (#759) 2024-02-29 09:39:52 -08:00
Jagrit Digani
776c3d226d Convolution update (#651)
* Init steel conv and update Conv primitive

* Update slow CPU implementation to support flipping and input dilation winograd conv routing

Co-authored-by: Awni Hannun <awni@apple.com>
2024-02-28 20:11:16 -08:00
Awni Hannun
f5f18b704f fix temporary bug (#752) 2024-02-27 17:44:39 -08:00
Awni Hannun
420ff2f331 Add back compiled function signatures and docstrings (#749)
* try to add back compiled function signatures and docstrings

* add indentation to docstring
2024-02-27 13:18:59 -08:00
Awni Hannun
56ba3ec40e fix cpu compile on older OS (#747) 2024-02-26 22:20:53 -08:00
Noah Kasmanoff
de3d2467a3 Update: Fast GeLU Approximation (#744)
* add: fast gelu approx

* fix docs

* Update gelu_fast_approx function documentation

* Update python/mlx/nn/layers/activations.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fix: test gelu

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-02-26 21:08:50 -08:00
Awni Hannun
fe1dabf272 Fix compile with non standard types (#745)
* refactor tree utils

* fix compile + tree code refactor

* Add an extra test

* add a few missing activations to docs

* hash structure

* Encode the full argument structure

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-26 19:28:53 -08:00
Hinrik Snær Guðmundsson
08226ab491 added atleast *args input support (#710)
* added atleast list(array) input support

* function overloading implemented

* Refactoring

* fixed formatting

* removed pos_only
2024-02-26 11:17:59 -08:00
Chime Ogbuji
3b661b7394 Add linear warmup and schedule joining for use with existing schedules (#721)
* Add linear warmup to schedules for use with existing schedules

* Changed parameters for simplicity of most common case (0 initial value)

* Added ScheduleJoiner and updated documentation

* ScheduleJoiner -> join_schedules (ala optax #)

* black compliance

* Different evaluation of schedules

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-02-26 07:28:48 -08:00
Awni Hannun
e6418781ab Fix logsumexp edge case (#740)
* fix logsumexp

* fix inf constant

* also fix power grad

* fix ternary dispatch
2024-02-25 08:39:55 -08:00
Awni Hannun
ac02cf33bd Fix some issues using MLX in C++ (#739)
* fix preamble build

* fix some issues with using MLX as a dep in C++
2024-02-24 22:20:57 -08:00
Gabrijel Boduljak
22364c40b7 Upsample2d (#414)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-02-23 09:55:04 -08:00
Noah Farr
d729a1991b Fix arange with inf step (#686)
* Fix case for step=inf in arange and add inf check for start/stop

* Add test cases for arange

* Update ops.cpp to include climits header

* Fix arange

* Fix formatting

* Refactor

* Add missing include
2024-02-23 06:18:15 -08:00
Rifur13
126c9869c8 Implement the 'where' primitive for conditional selection (#664) 2024-02-22 15:10:48 -08:00
Angelos Katharopoulos
ad4a45e615 Fix the release builds in CI (#729) 2024-02-22 14:09:13 -08:00
Awni Hannun
04fc896016 version bump (#727) 2024-02-22 11:54:17 -08:00
Jagrit Digani
884b4ed43b Fix threadgroup memory in arg reduce (#723) 2024-02-21 19:42:16 -08:00
Vijay Krish
972d9a3aea Up to 10x faster scatter. (#709)
* Faster scatter.

Add specialization for 1-d index tensors.

* Address review comments.

- Check for row contiguity of index, update tensors
  instead of checking strides.
- Add support for 1d specialization with col contiguous update
  tensor, along with a test.

* Nit1

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Nit2

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-02-21 11:09:30 -08:00
Angelos Katharopoulos
7dcdd88e27 Change the logo and add a dark option (#716) 2024-02-20 10:57:02 -08:00
Awni Hannun
8120a3b65c link to other APIs (#715)
* link to other APIs

* remove sec
2024-02-20 09:54:49 -08:00
Awni Hannun
5798256fcf Shapeless compilation for some graphs (#687)
* shapeless compilation for some graphs

* update compile benchmark

* default compile a few activations

* buffer donation

* bugfix

* shapeless fix

* update tests to work for cpu and gpu fusion

* test kwargs

* add kwargs to compile

* Recompile when python arguments change

* no compile for tanh

* some constant tests

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-19 21:43:54 -08:00
Awni Hannun
d0fda82595 fix tolist for half types (#702) 2024-02-19 09:44:27 -08:00
Hinrik Snær Guðmundsson
f883fcede0 Added support for atleast_1d, atleast_2d, atleast_3d (#694) 2024-02-19 09:40:52 -08:00
Diogo
e1bdf6a8d9 discover doctests in cmake (#703) 2024-02-19 07:03:56 -08:00
Awni Hannun
1a4f4c5ea6 Refactor CPU compile preamble (#708)
* refactor cpu preamble

* fix include order

* fix some issues'

* fixes for linux

* try to fix includes

* add back warning suppression

* more linux fixes
2024-02-19 06:12:53 -08:00
Jack Mousseau
0925af43b0 Remove unused variables (#706) 2024-02-18 12:50:10 -08:00
Awni Hannun
dc937b8ed3 CPU compile (#691)
* build and load shared object for cpu compile

* nits

* cpu compile tests pass

* cpu compile tests pass

* fix preamble for g++

* donation

* fix gpu buffer donation

* reuse prebuilt libraries

* faster contiguity conditoins

* fix test

* rid compiler warning

* fast erf

* Fix float16 for compile and add more types to cpu compile

* Remove a forgotten comment

* use cached libs

* nits

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-17 06:54:32 -08:00
Awni Hannun
c3965fc5ee Separate fast ops and primitives (#699) 2024-02-16 19:16:39 -08:00
Awni Hannun
bf7cd29970 version bump (#698) 2024-02-16 08:44:08 -08:00
Nripesh Niketan
a000d2288c feat: update black pre-commit hook to 24.2.0 (#696) 2024-02-16 06:01:59 -08:00
Mike Drob
165abf0e4c Auto-run PRs from contributors (#692) 2024-02-15 17:30:35 -08:00
Srimukh Sripada
818cda16bc Support LR schedulers (#334)
* Add a few LR schedulers

* Move parents's constructor call to the top

* Fix docstring

* refactor optimizers into two files

* add docs

* nit

* Fix Callable type annotation for python 3.8

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-15 11:26:20 -08:00
toji
85143fecdd improved error msg for invalid axis(mx.split) (#685)
* improved error msg for invalid axis(`mx.split`)

* Apply suggestions from code review

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fixed formatting issue

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-02-15 07:25:38 -08:00
Diogo
35431a4ac8 Adds device context manager (#679) 2024-02-14 14:14:58 -08:00
Awni Hannun
ccf1645995 Custom primitive + RoPE fat op (#676)
* extensions start

* rope custom op

* fix build

* docs + rope benchmark

* fix test

* Add a Metal kernel for RoPE

* Fix position of traditional

* transform tests

* Move rope computation to float and fix tests

* Fix the test and a typo

* change to fast

* fix no metal build

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-14 14:04:25 -08:00
Jagrit Digani
1a48713d32 Update gather and scatter to not use Argument Encoder (#683)
* Replace argument encoder usage for gather and scatter

* Use constant address space for shapes and strides

* Split gather and scatter to improve compile times

* Enable the GPU tests

* Update the CI config

* Fix scatter dispatch for scalar indices

* Remove arg encoder utils

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-14 13:42:13 -08:00
Awni Hannun
1eb04aa23f Fix empty array construction in cpp (#684) 2024-02-13 23:34:17 -08:00
Noah Farr
0c65517e91 Return empty array when repeats is 0 in mx.repeat (#681)
* Return empty array when repeats is 0

* Add test case for repeats = 0
2024-02-13 17:49:31 -08:00
Vijay Krish
2fdc2462c3 Faster gather and scatter. (#682)
Reduce unnecessary integer ops, especially since
there kernels are integer bound.

Increase number of iterations for benchmarks for
better smoothing.

Github Issue #506

Co-authored-by: Vijay Krishnamoorthy <vijay_krish@apple.com>
2024-02-13 17:47:41 -08:00
Hinrik Snær Guðmundsson
be6e9d6a9f Fixed wording in extensions.rst (#678)
changed "learn how add" -> "learn how to add"
2024-02-13 08:39:02 -08:00
Gabrijel Boduljak
e54cbb7ba6 Pooling layers (#357)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-02-12 22:08:13 -08:00
Angelos Katharopoulos
40c108766b Quantized matmul fix (#677)
* Fix qmv for small or unaligned matrices

* Fix qmm
2024-02-12 18:54:21 -08:00
Mike Drob
4cc70290f7 PR Builder Workflow (#659) 2024-02-12 17:47:21 -08:00
Awni Hannun
74caa68d02 nit in readme (#675) 2024-02-12 12:25:04 -08:00
Awni Hannun
3756381358 Faster bfloat quantized mat-vec and vec-mat (#663) 2024-02-11 21:53:16 -08:00
Awni Hannun
d12573daa6 quote file name (#670) 2024-02-11 10:33:30 -08:00
Nripesh Niketan
0dbc4c7547 feat: Update pre-commit-config.yaml (#667) 2024-02-11 06:08:20 -08:00
Vijay Krish
06072601ce Scatter optimization : Eliminate 64b integer divide. (#662)
Launch 2D grid to eliminate divide and mod in device code,
since 64b integer division is very expensive.

Github Issue #506

Co-authored-by: Vijay Krishnamoorthy <vijay_krish@apple.com>
2024-02-10 08:49:51 -08:00
Angelos Katharopoulos
11d2c8f7a1 Linux build for CI of other packages (#660) 2024-02-09 18:17:04 -08:00
Awni Hannun
7f3f8d8f8d Fix the softmax fix (#661) 2024-02-09 17:02:13 -08:00
Awni Hannun
b96be943dc bug fix (#658) 2024-02-09 16:50:45 -08:00
Abdussamet Türker
b670485185 Remainder negative numerator bug fixed (#641)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-09 16:49:14 -08:00
Diogo
b57bd0488d Metadata support for safetensors (#639)
* metadata support for safetensors

* aliases making it alittle more readable

* addressing comments

* python binding tests
2024-02-08 19:33:15 -08:00
Angelos Katharopoulos
221f8d3fc2 Bump the version to 0.2 (#656) 2024-02-08 11:27:12 -08:00
Awni Hannun
5c03efaf29 Compile docs (#653)
* compile docs

* docs nits + comments
2024-02-08 11:21:50 -08:00
LeonEricsson
7dccd42133 updated calls to use loc &scale (#643) 2024-02-08 09:01:59 -08:00
Awni Hannun
1b97b2958b Compile with capture (#629)
* Simple kernel generation

* Remove the generate kernel from graph_utils

* fix multi-output with compile

* fuse with stopgrad

* v1 input, output capture in compile

* cleanup tree update with visitor update

* nit

* remove todo

* state for model, optional explicit init and more pure optimizer steps

* move learning rate to state

* add lr to opt state, some fixes in capture

* fix optim

* update tuple of containers as well

* fix stream for compiled output

* rng state for compile

* nit

* updates and comments

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-07 17:29:22 -08:00
Awni Hannun
e5e816a5ef fix sequential with empty modules at end (#647) 2024-02-07 13:22:27 -08:00
Angelos Katharopoulos
28eac18571 Kernel generation (#614)
Generate reusable element-wise kernels given a computation graph.
2024-02-07 13:15:59 -08:00
Noah Farr
5fd11c347d Add loc and scale to random.normal (#638)
* Add loc and scale to random.normal

* Add tests for loc and scale for random.normal

* Run pre-commit hooks

* Fix code review
2024-02-07 11:49:59 -08:00
Aryan Gupta
ef73393a19 Feat: Add weights argument in BCE Loss and tests (#620) 2024-02-07 09:39:52 -08:00
Angelos Katharopoulos
ea406d5e33 CI change (#645)
* CI update

* Skip large binary test for now

* Upgrade pip

* Add proper env variable skipping

* Update the CI

* Fix workflow name

* Set the low memory flag for the tests

* Change build process

* Add pip upgrade

* Use a venv

* Add a missing env activate

* Add setuptools

* Add twine upload back

* Re-enable automatic release builds
2024-02-07 06:04:34 -08:00
Awni Hannun
146bd69470 Skip compile when transforming (#635)
* skip compile when transforming

* simplify message
2024-02-05 21:28:37 -08:00
Jagrit Digani
316ff490b3 Remove masks from BlockLoader and clear out load case for invalid thread (#634) 2024-02-05 16:00:17 -08:00
Awni Hannun
d40a04f8dc minor fixes (#631)
* minor fixes

* var with ddof >= nelements
2024-02-05 13:27:49 -08:00
Awni Hannun
d75ae52ecd Compile primitive (#571)
* Compiled primitive with basic binary, unary graph-level fusion
2024-02-05 06:51:22 -08:00
Avikant Srivastava
31fea3758e feat: enhancement of the error message for mlx.core.mean (#608)
* add error message
2024-02-05 01:21:49 -08:00
Awni Hannun
e319383ef9 Faster gather (#626)
* faster gather

* update copyright
2024-02-04 17:25:44 -08:00
Awni Hannun
5c3ac52dd7 fix test (#627) 2024-02-04 16:18:03 -08:00
David Koski
ebfd3618b0 fixes for building and running on iOS (#619)
* fixes for building and running on iOS

* per suggestion just use Accelerate
2024-02-04 12:29:17 -08:00
Avikant Srivastava
11a9fd40f0 fix: handle linspace function when num is 1 (#602)
* fix: handle linspace function when num is 1

* add comment

* fix test case

* remove breakpoint
2024-02-04 11:03:49 -08:00
Daniel Strobusch
4fd2fb84a6 make python array SupportsAbs conform (like numpy) (#624) 2024-02-04 09:31:02 -08:00
Daniel Strobusch
9852af1a19 fix "shape" docstring. (#623) 2024-02-04 09:21:22 -08:00
minghuaw
16750f3c51 Fix typo in CMakeLists.txt (#616) 2024-02-03 05:59:26 -08:00
Awni Hannun
95b5fb8245 minor changes (#613) 2024-02-02 11:48:35 -08:00
AtomicVar
83f63f2184 Add Margin Ranking Loss (#536) 2024-02-02 10:57:31 -08:00
Awni Hannun
cb6156d35d Fix eval in trace bugs (#612)
* Fix eval in trace bugs

* comment nit
2024-02-02 09:57:12 -08:00
Piotr Rybiec
506d43035c typo fix (#607) 2024-02-01 17:39:55 -08:00
Angelos Katharopoulos
36cff34701 Bump the version (#604) 2024-02-01 11:41:38 -08:00
Awni Hannun
e88e474fd1 Reduce vmap + some fixes (#601) 2024-02-01 11:30:28 -08:00
David Koski
601c6d6aa8 Fix for AdaDelta (#603)
- state was being read from parameter "s"
- but being stored in parameter "u"
2024-02-01 09:56:27 -08:00
Angelos Katharopoulos
ba8d6bf365 Change the transformer to norm_first by default (#599) 2024-01-31 12:55:30 -08:00
Sugato Ray
4a5f3b21bb Add py.typed to support PEP-561 (type-hinting) for mlx (#588)
* Add `py.typed` to support PEP-561 (type-hinting)

This adds support for type-hinting information as laid in [PEP-561](https://peps.python.org/pep-0561/).

* add py.typed to MANIFEST.in
2024-01-31 12:05:42 -08:00
Vijay Krish
fcc5ac1c64 Add GPU support for uint64/int64 reductions (#569) 2024-01-31 11:18:04 -08:00
nathan
bad67fec37 Added TeX line breaks to mlx.optimizers.Lion docstring (#595)
Fixes the "misplaced &" MathJax error in documentation.
2024-01-30 19:37:34 -08:00
Angelos Katharopoulos
199aebcf77 Change the variance computation (#319) 2024-01-30 19:28:56 -08:00
Angelos Katharopoulos
0de5988f92 Custom VJP and checkpointing (#541)
* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
2024-01-30 16:04:45 -08:00
Jacket
143e2690d5 Fix SGD implementation (#473) 2024-01-30 15:50:46 -08:00
Jagrit Digani
375446453e Update Compute Pipeline Creation API (#581)
* Add option to specialize metal functions on function constants
* Update Compute Pipeline Creation API
* Add options to make libraries from source and stitching
* Update function specialization name options
2024-01-30 15:42:36 -08:00
Angelos Katharopoulos
1895d34c20 Fix log1p with inf inputs (#592) 2024-01-30 14:02:50 -08:00
Awni Hannun
09b9275027 Make shape a tuple (#591)
* shape tuple

* also remove simplify from docs

* rebase
2024-01-30 13:11:01 -08:00
Andre Slavescu
d3a9005454 Softshrink mapping + op (#552)
* Added Softshrink mapping + op

* formatting

* docs + nits in docstring

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-30 12:56:28 -08:00
Jacket
3f7aba8498 Implement diagonal operator (#562)
* Implement diagonal operator

This implements mx.diagonal in operator level, inspired by
@ManishAradwad.

* added `mx.diag` with tests

* corrected few things

* nits in bindings

* updates to diag

---------

Co-authored-by: ManishAradwad <manisharadwad@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-30 09:45:48 -08:00
Angelos Katharopoulos
65d0b8df9f Fix binary op dispatch (#584) 2024-01-29 19:36:17 -08:00
Awni Hannun
3c2f192345 Propagate nans in binary ops (#579)
* propagate nans in binary ops

* handle empty matmul

* cpu minimum/maximum propagate nan

* benchmark maximum

* add min as well

* throw on negative indices with full

* verbose on linux

* fix matmul for zero K
2024-01-29 11:19:38 -08:00
Angelos Katharopoulos
37d98ba6ff No gil eval (#565) 2024-01-26 22:03:52 -08:00
Awni Hannun
8993382aaa Buffer Donation (#519)
* buffer donation

* fix to move shared pointer

* format

* gpu in place for copy and binary

* revert ops test

* cpu in place

* a little cleanup

* remove useless bench
2024-01-26 16:30:33 -08:00
Awni Hannun
07f35c9d8a Fix a few issues: docs for flatten, erf, dequantize validation (#560)
* doc flatten

* erf doc

* check values for dequantize

* format
2024-01-26 15:16:46 -08:00
Jagrit Digani
bf17ab5002 Add more checks and clearer error messages to conv operations (#563)
* Add more checks and clearer error messages to conv operations
2024-01-26 15:13:26 -08:00
Awni Hannun
8fa6b322b9 Compile front-end (#476)
* fix tests for linux

* make a move on compile

* basic compile scaffold works

* compile binding

* clean

* fix

* fix grad, more tests

* basic python tests

* fix segfault on python exit

* compile works with python closures

* fix test

* fix python globals bug, and erase

* simplify

* more cpp tests

* bug fix with move function and compile at exit

* simplify inputs also

* enable and disable compiler

* remove simplify

* simplify tests use compile now

* fix multi-output with compile

* clear output tree from cache when function goes out of scope

* ../python/src/transforms.cpp

* remove closure capture

* comments
2024-01-26 13:45:30 -08:00
David Koski
874b739f3c Fix cache key in RoPE (#561) 2024-01-26 13:10:02 -08:00
taher
077c1ee64a QR factorization (#310)
* add qr factorization

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-26 09:27:31 -08:00
Rifur13
2463496471 [Fix] mx.allclose bug with infinite values (#539)
* Added isclose op and fixed comparison with inf values

* Added 'equal_nan' to match numpy

* format

* Add test

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Addressed CR comments

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* nits

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-25 20:47:06 -08:00
Angelos Katharopoulos
87b7fa9ba2 Bump the version (#554) 2024-01-25 11:01:05 -08:00
Danilo Peixoto
624065c074 Fix package installation for CI (#521)
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-01-25 09:43:34 -08:00
Awni Hannun
f27ec5e097 More helpful error message in vjp transform + concate bug (#543)
* more helpful message in vjp transform

* fix concatenate on mismatch dims

* typo

* typo
2024-01-24 09:58:33 -08:00
Awni Hannun
f30e63353a Minor updates to address a few issues (#537)
* docs on arg indices return type

* arange with nan

* undo isort
2024-01-23 22:24:41 -08:00
Juarez Bochi
4fe2fa2a64 GGUF: Avoid dequantization when format is compatible (#426)
* GGUF: Don't dequantize q4_1

* Fix weight order. First in low bits

* Add unpacking for q4_0

* Don't dequantize q8_0

* rebase quants and split file

* don't quantize every weight

* reapply patch

* error handling

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-23 15:43:57 -08:00
Hazem Essam
37fc9db82c Added Adafactor (#415)
* Added adafactor

* Added Adafactor and ran pre-commit

* modified operations

* Added docstrings

* Switched two ops to fix a bug

* added underscore for internal functions and removed the plus sign in the last return statment

* Removed parameter rms from the optimizer state because its not needed

* Added simple MNIST test for Adafactor and temporary training log

* remove test files

* nits in docs

* comment nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-23 15:11:27 -08:00
AtomicVar
755dcf6137 Enable cross_entropy loss to handle dense targets (#517)
* Enable cross_entropy loss to handle dense targets

Dense targets means probabilities or one-hot encodings.

* better shape check of weights

* nits in docstring

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-23 12:17:22 -08:00
LeonEricsson
6b4b30e3fc Common neural network initializers nn.initializers (#456)
* initial commit: constant, normal, uniform

* identity, glorot and he initializers

* docstrings

* rm file

* nits

* nits

* nits

* testing suite

* docs

* nits in docs

* more docs

* remove unused template

* rename packakge to nn.innit

* docs, receptive field

* more docs

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-23 06:47:20 -08:00
Awni Hannun
86e0c79467 remove stale benchmarks (#527) 2024-01-22 22:17:58 -08:00
Awni Hannun
98c37d3a22 use axes in tensordot (#525) 2024-01-22 21:17:00 -08:00
Sugato Ray
f326dd8334 Update README.md (#524)
Add conda install option in docs.
2024-01-22 20:53:54 -08:00
Jagrit Digani
6d3bee3364 Fix oob reads in gemv kernel (#523) 2024-01-22 12:06:04 -08:00
Danilo Peixoto
ecb174ca9d Type annotations for mlx.core module (#512) 2024-01-21 12:53:12 -08:00
Awni Hannun
7a34e46677 Quantize with groups of 32 (#511)
* allow quantize with group sizes of 32

* missing cpu dispatch

* remove print

* Fix qvm for group_size 32

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-01-21 06:19:05 -08:00
Nripesh Niketan
92c22c1ea3 feat: Update isort version to 5.13.2 (#514) 2024-01-21 06:11:48 -08:00
Awni Hannun
d52383367a format (#510) 2024-01-20 10:33:46 -08:00
Arda Orçun
363d3add6d Add ValuError message for Adamax (#508)
* ValuError message added

* beta errors added

* some corrections and testing

* Learning rate limitation deleted
2024-01-20 07:56:15 -08:00
Awni Hannun
b207c2c86b Power VJP fix for 0 (#505) 2024-01-20 01:17:40 -08:00
Awni Hannun
6bf779e72b fix array from list for > 32 bit types (#501) 2024-01-19 15:49:25 -08:00
Juarez Bochi
ddf50113c5 GGUF: Load and save metadata (#446)
* gguf metadata
---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-19 14:06:05 -08:00
Arda Orçun
6589c869d6 Added MSE message (#500)
* Added MSE message

* changed wrong line.

* Update examples/python/linear_regression.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-01-19 06:27:50 -08:00
Anchen
f6feb61f92 feat: add support for saving safetensors in the save_weights (#497)
* feat: add save safetensors support in module save_weights

* chore: checking missing changes

* Update python/mlx/nn/layers/base.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* chore: update docstring for load_weights

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-01-19 06:19:33 -08:00
Awni Hannun
c4ec836523 fix isinf for integer types (#494) 2024-01-19 05:31:10 -08:00
AtomicVar
550d4bf7c0 Update binary_cross_entropy function to handle both logits and probabilities (#492) 2024-01-18 19:22:23 -08:00
Awni Hannun
f6e911ced0 version bump (#490)
* version bump

* Fix the dev version string

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-01-18 12:00:24 -08:00
Awni Hannun
3d99a8d31d Fix format / build (#489) 2024-01-18 10:01:59 -08:00
Ethan
a749a91c75 Support disable metal buffer cache to prevent performance degradation caused by large memory caching (#390)
* support disable metal buffer cache, due to large unused memory buffered when llm generated long context tokens

* Run format and add "cache_enabled" feature tests
2024-01-18 08:33:34 -08:00
toji
49a52610b7 Added formatter structure and a boolean value formatter (#354)
* added formatter structure and a boolean value formatter

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-18 07:49:41 -08:00
AtomicVar
d1fef34138 Add Gaussian NLL loss function (#477)
* Add Gaussian NLL loss function

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-18 06:44:44 -08:00
Angelos Katharopoulos
9c111f176d Fix split optimization for array iterator (#484) 2024-01-18 05:50:25 -08:00
Awni Hannun
78e5f2d17d usage doc for function transformations (#481) 2024-01-17 17:10:53 -08:00
Angelos Katharopoulos
90c234b7ac Fix round to round half-cases to even (#482) 2024-01-17 15:27:23 -08:00
Angelos Katharopoulos
135fd796d2 Fix detach for multi-output primitives (#480) 2024-01-17 14:08:07 -08:00
Jagrit Digani
78102a47ad Update GEMM (#424)
* Organize and collect metal subroutine templates and elements in `metal/kernels/steel/`
* Update gemm elements for better performance 
* Add split-K specialization for gemm
* Add `addmm` primitive, op and bindings for fused matmul and bias addition 
* Update tests and benchmarks as needed
2024-01-17 12:42:39 -08:00
Diogo
556cdf0e06 Resolves build issues with the extension example (#419)
* resolved extension build issues and added test to ci

* missing gguflib

* rebased

* force mlx install from fix branch

* linux build issue

* point to git install and comment out ci tests
2024-01-17 12:07:05 -08:00
Awni Hannun
275db7221a Command buffer reports errors (#479)
* command buffer reports errors

* typo

* simplify
2024-01-17 11:53:30 -08:00
AtomicVar
4a9012cba0 Sort some APIs docs by names (a-z) (#472) 2024-01-16 19:37:50 -08:00
Awni Hannun
a2bf7693dd Primitive's VJP takes outputs as input (#475)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-01-16 19:03:53 -08:00
Angelos Katharopoulos
d8fabaa12b Split multi output (#461)
* Multi-output split primitive
* Add the multi-output split to the ArrayIterator
* Add some grad tests for split
2024-01-16 13:33:55 -08:00
Avikant Srivastava
4e290d282f feat: add time based seed to random.h (#457)
* random seed from time

* fix: chrono

* refactor: snake case
2024-01-16 07:32:28 -08:00
Yashraj Singh
e72458a3fa implemented isposinf and isneginf in one PR (#470)
* ran precommit

* updated docs
2024-01-16 06:48:07 -08:00
Awni Hannun
a2ffea683a Fix eye for larger matrices (#463)
* fix eye
* fix scatter for <32bit (non native atomic) types
* fix int overflow
2024-01-16 00:51:24 -08:00
Angelos Katharopoulos
c15fe3e61b Allow arbitrary first dimension in quantization kernels. (#458)
* Allow arbitrary first dim on qmm_t and qmv
* Allow arbitrary first dim on qmm and qvm
* Specialized aligned vs unaligned case
* Add more checks for valid quantizations
2024-01-16 00:46:21 -08:00
Tristan Bilot
f44c132f4a Add scatter_min VJP (#462) 2024-01-16 00:37:40 -08:00
Matthew Ernst
92a2fdd577 Adds isinf (#445)
* adds isinf

Signed-off-by: matthewfernst <matthew.f.ernst@gmail.com>

* use stream + nits

* typo

---------

Signed-off-by: matthewfernst <matthew.f.ernst@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-15 19:50:44 -08:00
Tristan Bilot
6022d4129e scatter_max vjp + bindings + tests (#431)
Co-authored-by: DjamelMesbah <djamel.mesbah@adservio.fr>
2024-01-14 14:12:15 -08:00
Awni Hannun
4bc446be08 Use a dummy primitive to only sync with one output (#453)
* Use a dummy primitive to only sync with one output
* Fix test and choose stream with slight care
2024-01-14 14:09:40 -08:00
Awni Hannun
41cc7bdfdb Fix stub generation, change graph exporting for arrows to go to outputs (#455) 2024-01-14 14:06:16 -08:00
Awni Hannun
6e81c3e164 Sync only with outputs we need to sync with (#447) 2024-01-13 01:47:25 -08:00
Diogo
2e29d0815b Add tile op (#438) 2024-01-12 23:03:16 -08:00
Awni Hannun
1b71487e1f docs (#444) 2024-01-12 13:34:16 -08:00
Ayush Shridhar
1416e7b664 Add isnan (#423) 2024-01-12 11:16:48 -08:00
davidkoski
29081204d1 array.swapaxes should point to swapaxes free function (#441) 2024-01-12 11:06:16 -08:00
Angelos Katharopoulos
006d01ba42 Fix packaging of gguflib (#435) 2024-01-11 13:56:03 -08:00
Awni Hannun
46dc24d835 version bump (#433) 2024-01-11 12:29:35 -08:00
Awni Hannun
c9934fe8a4 Metal validation (#432)
* tests clear metal validation

* add cpp test with metal validation to circleci

* nit
2024-01-11 11:57:24 -08:00
Avikant Srivastava
975e265f74 feat: Add numpy constants (#428)
* add numpy constants

* feat: add unittests

* add newaxis

* add test for newaxis transformation

* refactor
2024-01-11 06:47:29 -08:00
Awni Hannun
c92a134b0d more docs (#421)
* more docs

* fix link

* nits + comments
2024-01-10 14:04:12 -08:00
Awni Hannun
3b4f066dac Correct types for vjp + tests (#418)
* correct types for vjp + tests

* fix build + comment
2024-01-10 13:32:37 -08:00
Juarez Bochi
b7f905787e GGUF support (#350)
* Initial GGUF support for tensor fields.

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-10 13:22:48 -08:00
Chunyang Wen
e3e933c6bc Add type hint for Module (#412) 2024-01-10 11:23:42 -08:00
Awni Hannun
1d90a76d63 in place ops behave in place, fix some overloads (#411) 2024-01-09 16:05:38 -08:00
Angelos Katharopoulos
961435a243 Scatter vjp (#394)
* Add a first scatter vjp
* Implement the scatter_add vjp
* Add array.at to implement user friendly scatters
2024-01-09 13:36:51 -08:00
Awni Hannun
e9ca65c939 Fix BN stats to not expand shape (#409)
* fix BN stats to not expand shape

* nit
2024-01-09 11:54:51 -08:00
Dwayne Robinson
753867123d Fix data_types.rst uint64 (#406)
uint64 correctly says 8 bytes, but the description is copy pasta.
2024-01-09 06:40:10 -08:00
Awni Hannun
f099ebe535 Multi output primitives (#330)
* Multi-output primitives

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-01-08 16:39:08 -08:00
BigsnarfDude
f45f70f133 Update mlx-example link for llms llama in llama-inference.rst (#405) 2024-01-08 16:29:53 -08:00
YUN, Junwoo
0b8aeddac6 Additoinal losses (#336)
* cosine similarity loss

---------

Co-authored-by: Awni Hannun <awni@apple.com>

* Docstring nits
2024-01-08 14:01:13 -08:00
Jagrit Digani
432ee5650b Update cpp tests with allclose and doctest::Approx for numerical tolerance (#401) 2024-01-08 09:35:05 -08:00
Nripesh Niketan
73321b8097 feat: add logicalAnd and logicalOR (#386)
* feat: add logicalAnd and logicalOR

* run pre-commit

* Refactor logical_and and logical_or functions

* Add acknowledgement

* Add logical AND and logical OR operators

* Refactor logical_and and logical_or functions

* Add support for logical operators on bool arrays

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Add logical AND and OR operators for arrays and scalars

* Refactor vjp and jvp methods in primitives.cpp

* Add overloaded operators for logical AND and OR

* format

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-08 07:00:05 -08:00
Hazem Essam
022a944367 Added GLU activation function and Gated activation function (#329)
* Added GLU activation function and gated activation function

* Ran pre-commit

* Ran pre commit

* Removed old sigmoid implementation to match with main

* Removed gated activation from __init__.py

* Removed unused test cases

* Removed unused imports

* format / docstring

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-08 06:13:16 -08:00
Chris Costes
026ef9aae4 Update Install Instructions (#397)
* Add note to install instructions for building from source to ensure native arm64 environment and tools.

* Add troubleshooting info.

* remove cmake bits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-07 19:11:04 -08:00
Angelos Katharopoulos
a611b0bc82 Removes the retain_graph flag (#385)
* Adds global tracing flag
* Removes retain_graph in favor of is_tracer
2024-01-07 15:16:51 -08:00
Diogo
449b43762e Add inner / outer op (#348)
* inner / outer impl

* python tests

* ops list and ack

* updated descriptions

* use test helper

* removed dtype check and flatten outer to 1-D

* updated docs

* just use the reshape to flatten
2024-01-07 09:01:09 -08:00
Angelos Katharopoulos
6ea6b4258d Fix style check (#395) 2024-01-07 05:54:58 -08:00
Anchen
48f6ca8c3a Add theta cache for Rope and mask cache for ALiBi (#375) 2024-01-07 00:22:58 -08:00
Awni Hannun
c6d2878c1a safely divide for 0 size inputs (#388) 2024-01-07 00:19:54 -08:00
Awni Hannun
b34bf5d52b fix saving for non-contiguous arrays (#389) 2024-01-06 12:44:02 -08:00
Angelos Katharopoulos
608bd43604 Move the matmul type check in the op (#384) 2024-01-05 19:10:13 -08:00
Angelos Katharopoulos
4c48f6460d Fix segfault from buffer protocol and tests (#383)
* Fix segfault from buffer protocol and tests

* Fix tf test
2024-01-05 18:17:44 -08:00
Daniel Strobusch
1331fa19f6 Make array conform to the Python Buffer Protocol (#323) 2024-01-05 15:58:33 -08:00
Daniel Strobusch
dfdb284e16 make behaviour of dtype arguments consistent and compliant to numpy (#379)
All functions that take an optional dtype should

* have a default dtype visible in the generated docs (accomplished via `"dtype"_a = std::optional{float32}`)
* behave identical when `dtype=None` or no dtype is passed

This important when passing kw args down from a numpy function like:

```
def f(x, dtype=None):
  mx.random.uniform(dtype=dtype)
  # ...
```

NumPy functions behave like this.

It also fixes a minor bug in `tri`: #378

Closes #378
2024-01-05 09:37:46 -08:00
mutexuan
d8f41a5c0f support python mlx.array creation from list of mlx.array's (#325)
* support python mlx.array creation from list of mlx.array's

* include bfloat16 in UT

* refactor so that sub array made of all python primitive types gets initialized by fill_vector

* address PR comment: arr.shape().size() -> arr.ndim()

* address PR comment: get back Dtype constness and let stack to handle type promotions automatically
2024-01-04 18:53:33 -08:00
Awni Hannun
b9e415d19c bump pre commit and fix format (#373) 2024-01-04 16:28:52 -08:00
davidkoski
c82a8cc526 move all ObjC (via metal-cpp) interaction until post static initializers (#370)
* move all ObjC (via metal-cpp) interaction until post static initializers

- metal-cpp relies on static initializers to cache class and selector pointers
- code in mlx was using metal-cpp to set up NSAutoreleasePools during its own static init time
- but this code was silently failing as the class and selector pointers from metal-cpp were still nil

- defer the creation of NSAutoreleasePools until after static init time
- ensure that we have coverage where autorelease pools are needed

* Update device.cpp

remove commented code

* Update device.cpp

remove commented out code

* Update scheduler.h

update comment

* per discussion use the pool inside the task() -- this will be metal only, not needed for cpu

* Update allocator.cpp

move pool to release/alloc area
2024-01-04 16:12:00 -08:00
Angelos Katharopoulos
75dc537e44 Fix the sigmoid module (#371) 2024-01-04 13:16:36 -08:00
Awni Hannun
cf88db44b5 revert copy (#366) 2024-01-04 10:43:29 -08:00
Chunyang Wen
16856a0160 Remove useless pass (#364)
Co-authored-by: Chunyang Wen <chunyang_wen@apple.com>
2024-01-04 06:34:01 -08:00
Awni Hannun
d752f8e142 Fix CI (#359)
* fix ci

* check for linux for fp16
2024-01-04 06:33:08 -08:00
toji
d2467c320d Added support for python copy (#335)
* Added support for python copy

* precommit changes

* removed `_compiled_call_impl` line

* added tests and suggested changes

* ACK changes
2024-01-03 20:59:40 -08:00
Diogo
0d31128a44 use union instead of | (#358) 2024-01-03 19:33:19 -08:00
Diogo
1ac18eac20 simple numpy helper for tests (#352) 2024-01-03 19:19:19 -08:00
Awni Hannun
526466dd09 version bump (#355)
* version bump

* one more
2024-01-03 14:48:24 -08:00
Angelos Katharopoulos
e7f5059fe4 Support for quantized matmul with w and w^T (#349)
* Add the metal qvm implementation
* Add qmm_n
* Add gradient wrt to input for quantized_matmul
2024-01-03 14:22:36 -08:00
Nripesh Niketan
d7ac050f4b feat: Add contributors graph to README (#332)
* Fix: typo in README.md

* feat: Add contributors graph to README

* Update acknowledgments and contributors
2024-01-03 13:03:11 -08:00
Gabrijel Boduljak
c7edafb729 implemented InstanceNorm (#244)
* implemented instancenorm

* implemented vector_norm in cpp

added linalg to mlx

* implemented vector_norm python binding

* renamed vector_norm to norm, implemented norm without provided ord

* completed the implementation of the norm

* added tests

* removed unused import in linalg.cpp

* updated python bindings

* added some tests for python bindings

* handling inf, -inf as numpy does, more extensive tests of compatibility with numpy

* added better docs and examples

* refactored mlx.linalg.norm bindings

* reused existing util for implementation of linalg.norm

* more tests

* fixed a bug with no ord and axis provided

* removed unused imports

* some style and API consistency updates to linalg norm

* remove unused includes

* fix python tests

* fixed a bug with frobenius norm of a complex-valued matrix

* complex for vector too

* addressed PR review comments

* fixed import order in __init__

* expected values in instancenorm tests are simple lists

* minor return expression style change

* added InstanceNorm to docs

* doc string nits

* added myself to individual contributors

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-03 12:21:15 -08:00
Awni Hannun
dff4a3833f Module checks the weight on load_weights (#337)
* update module to check weights on load, also fix docs and reorganize tests

* nits + rebase

* a few more docs updates for Module

* use manual module file

* comment
2024-01-02 18:55:42 -08:00
Diogo
0782a4573a Add Tensordot op (#344) 2024-01-02 17:15:00 -08:00
Diogo
af66a09bde Adds issue template with common questions (#345)
* added template

* remove label
2024-01-02 16:52:20 -08:00
Angelos Katharopoulos
436bec9fd9 Fix the implementation of the Bilinear layer (#347) 2024-01-02 16:46:18 -08:00
Awni Hannun
99c80a2c8b Memory allocation (#292)
* try alternative gc

* try no cache

* add forced swap

* remove cache for now

* add cache back

* change fit crtieria

* remove unused function

* nit in comment

* tune / fix allocation

* increase block limit to original
2024-01-02 11:59:19 -08:00
Asaf Zorea
295ce9db09 Feature expand nn linear (#315)
* Added an identity and bilinear layers
Added a reset_parameters option
Added normal init for bias

* pre-commit run

* add type hints for parameters and the return type
change Bilinear math to x_1 and x_2
change __call__ arguments to x and y instead of input and output
add explanation to the Initialization

* Remove unnecessary reshape

* Added 'i' to bilinear formula

* Changed bilinear computation to two matrix multiplications

* avoid saving intermediate results, kept y in bilinear for better clarity (can be replaced with x1)

* Changed math formula in Linear
Added more explanation to math formulas
Changed x1, x2 reshape to support all inputs sizes
2024-01-02 06:08:53 -08:00
Josh Soref
44c1ce5e6a Spelling (#342)
* spelling: accumulates

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: across

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: additional

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: against

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: among

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: array

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: at least

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: available

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: axes

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: basically

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: bfloat

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: bounds

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: broadcast

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: buffer

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: class

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: coefficients

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: collision

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: combinations

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: committing

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: computation

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: consider

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: constructing

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: conversions

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: correctly

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: corresponding

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: declaration

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: default

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: dependency

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: destination

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: destructor

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: dimensions

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: divided

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: element-wise

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: elements

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: endianness

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: equivalent

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: explicitly

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: github

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: indices

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: irregularly

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: memory

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: metallib

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: negative

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: notable

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: optional

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: otherwise

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: overridden

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: partially

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: partition

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: perform

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: perturbations

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: positively

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: primitive

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: repeat

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: repeats

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: respect

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: respectively

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: result

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: rounding

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: separate

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: skipping

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: structure

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: the

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: transpose

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: unnecessary

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: unneeded

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: unsupported

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

---------

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
2024-01-01 21:08:17 -08:00
Chunyang Wen
144ecff849 Remove useless import (#340)
Co-authored-by: Chunyang Wen <chunyang_wen@apple.com>
2024-01-01 19:25:49 -08:00
mutexuan
350095ce6e fix type cast error in item() for bfloat16 (#339)
Co-authored-by: xuan <xuan@apple.com>
2024-01-01 19:02:04 -08:00
Nripesh Niketan
e09bf35b28 feat: Add Dropout3d layer to nn.layers (#313)
* feat: Add Dropout3d layer to nn.layers

* acknowledgement

* Add dropout tests to test_nn.py

* run pre-commit

* Add activation functions and dropout3d ops

* Add dropout tests for bfloat16 and float16
2023-12-31 14:01:21 -08:00
Daniel Strobusch
99c20f523e fix typos (#327) 2023-12-31 06:06:47 -08:00
Hazem Essam
e3b8da2a49 Added implementation for Scaled RoPE. (#261)
* Added scale for RoPE

* Ran pre-commit

* Added RoPE scaling test

* Added docstring for scale parameter

* Modified docstrings
2023-12-31 06:06:01 -08:00
Angelos Katharopoulos
a020a2d49d Improve repeat using broadcasting and reshape (#318) 2023-12-29 21:40:20 -08:00
Nripesh Niketan
930b159885 Fix: typo in README.md (#316) 2023-12-29 12:58:00 -08:00
Nripesh Niketan
5ad8fb7268 feat: add softsign, softmax, hardswish, logsoftmax activation function (#309)
* feat: add softsign activation function

* run pre-commit

* Add Softsign activation function

* Add Softsign activation function

* Add documentation for ReLU6, Softplus, and Softsign activations

* Update activation functions in neural network layers

* Add LogSoftmax and Hardswish activations

* run pre-commit

* Update activations.py

* Added acknowledgements

* Fix activation function comments

* Fix activation functions in neural network layers
2023-12-29 11:49:36 -08:00
Chunyang Wen
2aedf3e791 Minor refactor for tree_map and tree_unflatten (#311)
* Minor refact for tree_map and tree_unflatten

* Remove the if statement

---------

Co-authored-by: Chunyang Wen <chunyang_wen@apple.com>
2023-12-28 20:55:10 -08:00
Chunyang Wen
473b6b43b4 Use defaultdict (#307)
Co-authored-by: Chunyang Wen <chunyang_wen@apple.com>
2023-12-28 14:46:13 -08:00
Angelos Katharopoulos
d29770eeaa Update batchnorm to have the running stats in parameters (#305) 2023-12-28 14:31:10 -08:00
Chunyang Wen
040c3bafab Add missing f str (#306)
Co-authored-by: Chunyang Wen <chunyang_wen@apple.com>
2023-12-28 06:09:34 -08:00
Chunyang Wen
05767b026f Add information for dropout probability (#304)
Co-authored-by: Chunyang Wen <chunyang_wen@apple.com>
2023-12-27 21:51:30 -08:00
Diogo
a83d5d60bd Addition in acknowledgements (#302) 2023-12-27 13:46:47 -08:00
Bahaa
ff2b58e299 Add support for repeat (#278)
* add repeat function

* fix styling

* optimizing repeat

* fixed minor issues

* not sure why that folder is there xD

* fixed now for sure

* test repeat not repeat test

* Fixed

---------

Co-authored-by: Bahaa Eddin tabbakha <bahaa@Bahaas-MacBook-Pro.local>
2023-12-27 13:11:38 -08:00
YUN, Junwoo
4417e37ede Transformer fix (#167)
* add transformer with dropout, fix transformer ffm, layernorm order

* precommit changes

* precommit changes

* add docstring, activation, norm_first

* run precommit

* run precommit

* add doctstring

* precommit

* style nits in docs

---------

Co-authored-by: junwoo-yun <junwoo.yun@bagelcode.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-27 08:48:36 -08:00
Angelos Katharopoulos
79c95b6919 Fix load compilation (#298) 2023-12-27 06:20:45 -08:00
Diogo
1f6ab6a556 Safetensor support (#215)
Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-27 02:06:55 -08:00
Gabrijel Boduljak
6b0d30bb85 linalg.norm (#187)
* implemented vector_norm in cpp

added linalg to mlx

* implemented vector_norm python binding

* renamed vector_norm to norm, implemented norm without provided ord

* completed the implementation of the norm

* added tests

* removed unused import in linalg.cpp

* updated python bindings

* added some tests for python bindings

* handling inf, -inf as numpy does, more extensive tests of compatibility with numpy

* added better docs and examples

* refactored mlx.linalg.norm bindings

* reused existing util for implementation of linalg.norm

* more tests

* fixed a bug with no ord and axis provided

* removed unused imports

* some style and API consistency updates to linalg norm

* remove unused includes

* fix python tests

* fixed a bug with frobenius norm of a complex-valued matrix

* complex for vector too

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-26 19:42:04 -08:00
Angelos Katharopoulos
447bc089b9 Fix tolerance in de-/quantization test (#295) 2023-12-26 19:21:05 -08:00
Yutaka Kondo
fc4e5b476b Fix llama link in README.md (#289) 2023-12-25 20:53:20 -08:00
Daniel Strobusch
d58ac083f3 expose itemsize and nbytes as for numpy arrays (#284)
see:
  * https://numpy.org/doc/stable/reference/generated/numpy.ndarray.nbytes.html
  * https://numpy.org/doc/stable/reference/generated/numpy.ndarray.itemsize.html

relates to https://github.com/ml-explore/mlx-examples/pull/174
2023-12-25 10:34:28 -08:00
__mo_san__
a123c3c7d2 implement-batch-norm-layer (#217)
- Add batch normalization layer

---------

Co-authored-by: Robert McCraith <mccraithrobert@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-25 07:32:53 -08:00
Angelos Katharopoulos
9e6b8c9f48 Refactor the reduction kernels (#277) 2023-12-24 14:47:57 -08:00
Zach Schillaci
22fee5a383 Remove redundant assert in losses.py (#281) 2023-12-24 08:39:08 -08:00
Daniel Strobusch
7365d142a3 random.uniform must respect dtype, even if lower precision than "low" (#280)
Fix an edge case where random uniform returns a float32 array, even if a lower precision dtype is wanted due to adding the float32 "low" array.
2023-12-24 07:04:43 -08:00
Awni Hannun
8b227fa9af fix no metal build (#276) 2023-12-23 19:18:10 -08:00
Vidit Agarwal
8c3da54c7d Fix failing test for log cosh loss (#275)
* fix assert statement in log_cosh_loss

* reformatted by pre-commit black
2023-12-23 16:26:46 -08:00
Vidit Agarwal
acf1721b98 Corrected the example of value_and_grad (#274)
* Corrected the example for mx.value_and_grad

* Reformat through pre-commit/black
2023-12-23 11:06:38 -08:00
Finn Voorhees
f91f450141 Fix argmax returns documentation (#263) 2023-12-22 20:33:17 -08:00
Ronan Collobert
cd3616a463 Revisit autorelease memory pools (#260)
* make general autorelease pool part of metal device

* make things simpler

* no metal backend support

* new_memory_pool -> new_scoped_memory_pool
2023-12-22 11:01:26 -08:00
Nicholas Santavas
d35fa1db41 Add Hinge, Huber and LogCosh losses (#199) 2023-12-22 10:28:10 -08:00
Justin Deschenaux
e8deca84e0 Add dropout2d (#250) 2023-12-22 08:02:29 -08:00
Angelos Katharopoulos
8385f93cea Bumping the version (#256) 2023-12-21 18:33:14 -08:00
Awni Hannun
2118c3dbfa fix (#255) 2023-12-21 18:18:41 -08:00
Awni Hannun
a002797d52 A temporary fix (#254) 2023-12-21 17:59:15 -08:00
Angelos Katharopoulos
1d053e0d1d Fix the alibi test that was left unchanged (#252) 2023-12-21 14:59:25 -08:00
Hazem Essam
0aa65c7a6b Added ALiBi implementation (#232) 2023-12-21 14:36:38 -08:00
Daniel Strobusch
794feb83df support arange for bfloat16 (#245) 2023-12-21 14:33:43 -08:00
Angelos Katharopoulos
2c7df6795e Make sure that arrays are freed when saving (#247) 2023-12-21 14:08:24 -08:00
Angelos Katharopoulos
b3916cbf2b Improve names of quantization arguments (#235)
* Change the default quantization group_size to 64
* Rename groups to group_size and width to bits
2023-12-20 16:53:53 -08:00
Angelos Katharopoulos
57fe918cf8 Adds C++ and nn quantization utilities (#230)
* Add C++ de-/quantize ops
* Add quantize functions to the docs and tests
* Add a QuantizedLinear module
2023-12-20 14:17:38 -08:00
Justin Deschenaux
4912ff3ec2 Add Lion optimizer (#209)
* Add Lion optimizer
* Update acknowledgements also with past contributions
2023-12-20 13:54:58 -08:00
Awni Hannun
f40d17047d Indexing bug (#233)
* fix

* test
2023-12-20 10:44:01 -08:00
Angelos Katharopoulos
2807c6aff0 Implements divide for integer types and adds floor_divide op (#228)
* Add floor_divide
* Add floor_divide to the tests
* Add floor_divide to the docs
2023-12-19 20:12:19 -08:00
davidkoski
de892cb66c fix for non-macos build issue on cblas.h (#227) 2023-12-19 17:01:59 -08:00
davidkoski
37024d899c fixes for building with swiftpm (#225)
- clbas is part of veclib (compile failure)
- add SWIFTPM_BUNDLE #define to allow loading the metallib from a swiftpm resource bundle
2023-12-19 16:22:10 -08:00
Diogo
137f55bf28 fail early if readinto does not exist (#221) 2023-12-19 13:27:17 -08:00
Emircan Erol
e549f84532 Triplet Loss (#211)
* Triplet Loss

* Requested Changes

* Margin to alpha
2023-12-19 12:37:12 -08:00
Angelos Katharopoulos
dfa9f4bc58 An initial quantized matmul implementation (#205)
* Add quantized matvec
* Add quantized matrix matrix with 2nd matrix transposed
* Add quantized matmul tests
* Add a slow cpu quantized matmul
* Add a slightly faster vectorized cpu version
2023-12-18 23:18:57 -08:00
Abe Leininger
e6872a4149 Added linspace (#181)
* linspace ops support

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-18 19:57:55 -08:00
Juarez Bochi
f4f6e17d45 Fix cross-attention (#210)
* Fix cross-attention

With the current code, ln2 is a no-op. Its output should be passed to the cross-attention layer

* Add name to contributors
2023-12-18 12:27:27 -08:00
Angelos Katharopoulos
4d4af12c6f Adds round op and primitive (#203) 2023-12-18 11:32:48 -08:00
Awni Hannun
477397bc98 Citation + Contributor acknowledgment section (#207)
* cite

* nits

* nits

* comment
2023-12-18 10:07:00 -08:00
jojopuppet
18cca64c81 Add smoothed L1 loss and enhancements to cross entropy loss (#166)
* Add smooth_l1_loss
* Add labels moothing for cross entropy loss

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-18 07:26:21 -08:00
Awni Hannun
0e5807bbcb include optional (#202) 2023-12-17 22:01:35 -08:00
Cyril Zakka, MD
8eb56beb3a Added clip function (#159)
* Added clip

* Added Python bindings

* Formatting

* Added cpp tests

* Added Python tests

* python bindings work

* rebase

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-17 20:00:29 -08:00
Awni Hannun
ee0c2835c5 Docs updates (#198)
Reorganize NN docs + a few other tidbits.
2023-12-17 13:20:55 -08:00
Awni Hannun
90d04072b7 fix build w/ flatten (#195) 2023-12-17 11:58:45 -08:00
__mo_san__
52e1589a52 implemented Flatten Module (#149)
* implemented flatten op

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-16 21:54:37 -08:00
YUN, Junwoo
eebd7c275d Add optimizers (AdaMax, AdaDelta, RMSprop) and ordering optimizer classes (#142)
* Add AdaMax, AdaDelta, RMSprop
2023-12-16 21:43:15 -08:00
Austin Liu
a67bbfe745 Update docs (#177) (#190)
* update docs (fix #177)

* reorder

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-16 06:52:18 -08:00
Awni Hannun
104c34f906 setite negative indexing bug (#189) 2023-12-16 06:44:47 -08:00
Diogo
dc2edc762c added tri / tril / triu (#170)
* added tri / tril / triu

* fixed tests

* ctest tests

* tri overload and simplified tests

* changes from comment

* more tests for m

* ensure assert if not 2-D

* remove broadcast_to

* minor tweaks

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-15 17:30:34 -08:00
Awni Hannun
2e02acdc83 add base kwarg to rope (#186) 2023-12-15 16:47:59 -08:00
Ronan Collobert
83f266c44c Lazy metal_device_ initialization (#185)
This ensures it is defined when the Scheduler needs it.
2023-12-15 16:06:46 -08:00
Víctor Aguilar
f24200db2c accross -> across (#183) 2023-12-15 13:46:50 -08:00
Jason
e28b57e371 Added mx.stack c++ frontend impl (#123)
* stack C++ operation + python bindings
2023-12-14 13:21:19 -08:00
Awni Hannun
e5851e52b1 Add move and swap axis, and vmap for slice, concat, and gather (#158)
* add move and swap axis, and vmap for slice, concat, and gather
2023-12-14 12:59:12 -08:00
Diogo
f55908bc48 Added stubs for python files generated from C++ (#136)
* added pybind11-stubgen

* docs for generating stubs

* added line to readme
2023-12-14 12:58:45 -08:00
Luca Arnaboldi
b93c4cf378 Floor and Ceil (#150)
* Implements Floor and Ceil Ops
2023-12-14 10:00:23 -08:00
Stv.X
1e0c78b970 Fixed typo in some proprietary terms. (#161) 2023-12-13 19:48:00 -08:00
Awni Hannun
76e1af0e02 bump version (#157) 2023-12-13 14:28:26 -08:00
Ikko Eltociear Ashimine
c3272d4917 Update conv.cpp (#145)
Peform -> Perform
2023-12-12 11:27:49 -08:00
SputNikPlop
50f5d14b11 fix: tidy pull request template (#143)
* fix: tidy pull request template

* fix: feedback from awni
2023-12-12 08:14:39 -08:00
noahsmartin
d14a0e4ff9 Docs update (#144) 2023-12-12 07:53:42 -08:00
Diogo
fb675de30d Run lint check for prs (#139) 2023-12-12 00:23:33 -08:00
Awni Hannun
25f70d4ca4 Fix divide types + floor divide (//) (#138)
* divide types

* fix black + test
2023-12-11 20:20:58 -08:00
Diogo
02de234ef0 Activations LeakyReLU / PReLU / Softplus / Mish (#109)
* Leaky_relu / prelu / softplus / mish

* added tests

* updated bench

* remove torch refs, add init to PReLU

* added arvix reference to mish

* added missing docs
2023-12-11 19:40:57 -08:00
Nicholas Santavas
f5df47ec6e Add Step, ELU, SELU, Swish activation functions (#117)
* Add Step, ELU, SELU, Swish activation functions

This commit adds the Step, ELU, SELU and Swish activations functions

* add to the docs

* review
2023-12-11 17:04:07 -08:00
Awni Hannun
b9226c367c Fix CI format + build issue (#137)
* fix ci

* Fix python bindings build

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2023-12-11 15:01:41 -08:00
Angelos Katharopoulos
3214629601 Mlx array accessor (#128)
* Add an accessor to interoperate with custom types
* Change the docs to custom signatures
2023-12-11 13:42:55 -08:00
__mo_san__
072044e28f fix and update binary cross entropy loss tests (#133)
* fix conflicts

* updated tests
2023-12-11 12:42:17 -08:00
Cyril Zakka, MD
e080290ba4 Added eye/identity ops (#119)
`eye` and `identity` C++ and Python ops
2023-12-11 12:38:17 -08:00
Awni Hannun
69505b4e9b fixes (#131) 2023-12-11 09:26:49 -08:00
__mo_san__
f4ddd7dc44 Add Binary Cross Entropy loss (#122)
* update BCE added tests for it ...

* added binary cross entropy loss to docs

* resolving conflicts for merge
2023-12-11 07:55:18 -08:00
Jason
b0cd092b7f Added activation functions: leaky_relu relu6 softplus elu celu logsigmoid (#108)
* added leaky_relu relu6 softplus elu celu logsigmoid
* minor fixes for docstring and benchmark imports
* fixed elu implementation and added tests
* added tests for optional param, changed leaky_relu param to fit pytorch documentation
2023-12-10 16:31:38 -08:00
Awni Hannun
71d1fff90a Bug fix in metal binary kernel dispatch for large arrays (#125)
* bug fix

* format
2023-12-10 16:12:31 -08:00
Yiyang(Steven) Yu
0cfbfc9904 Update README.md (#121) 2023-12-10 14:47:37 -08:00
Awni Hannun
2d0130f80f fix loss tests (#118)
* fix loss tests

* use none as default
2023-12-10 10:08:19 -08:00
__mo_san__
c1e1c1443f Added Adagrad optimizer (#102) 2023-12-10 09:22:39 -08:00
Henry Ansah
68bf1d7867 add nn module for sigmoid activation (#111)
* add nn module for sigmoid activation

* update .gitignore with .cache folder generated by jetbrains fleet ide

* remove .cache folder
2023-12-10 07:00:39 -08:00
Angelos Katharopoulos
600db7d754 Fix build on Xcode 14 (#116)
* Fix build on Xcode 14

* Style fixes
2023-12-10 06:58:52 -08:00
__mo_san__
ef7b8756c0 Add tanh activation function (#115)
* added Adagrad optimizer ...

* added Tanh activation function ...

* reformatted file ...

* remove unrelated stuff ...

* Update activations.py
2023-12-09 19:25:38 -08:00
Enoch Kan
0b28399638 added mse_loss, nll_loss and kl_div_loss (#98)
* added mse_loss, nll_loss and kl_div_loss

* fixed axis not defined error in nll_loss

* fixed axis not defined in kl_div_loss

* added tests for mse, nll and kl_div

* modified docstrings and added reduce helper func

* updated docstring in kl_div_loss and moved helper func

* added new kl divergence implementation

* added reduction to test

* updated docstring of kl_div_loss with correct spelling

* added losses to nn.rst in docs
2023-12-09 14:25:03 -08:00
Joe Barrow
ac6dc5d3eb Adding optional bias param to MultiHeadAttention (#104)
* Adding optional  param to

* Run style-checker
2023-12-09 11:04:28 -08:00
Awni Hannun
89b90dcfec Pr template (#99)
* pr template
* format fix
2023-12-09 09:36:56 -08:00
Angelos Katharopoulos
fd836d891b Hashable dtype and mlx.core prefixed repr (#89)
* Make dtype hashable
* Add mlx.core prefix to our dtypes' repr
* Update the dtype test
2023-12-09 09:35:28 -08:00
AtomicVar
976e8babbe Use compiled black as the pre-commit formatter (#94) 2023-12-09 07:06:46 -08:00
Awni Hannun
2520dbcf0a add losses to the docs, fix black failur (#92) 2023-12-09 06:06:52 -08:00
Abe Leininger
430bfb4944 Adds Nesterov momentum to SGD (#87) 2023-12-08 23:23:36 -08:00
ShiJZ
08d51bf232 Make it easier to test new optimizers implemented: no need to change test file manually (#90)
* add helper function get_all_optimizers() in test_optimizers.py

* remove unused import
2023-12-08 21:39:08 -08:00
Kai Ma
cb9e585b8e Style fix for loss functions (#91)
* MLE and L1 loss functions

* logsoftmax change and tests

* subtract max logit for numerical stability

* l1 name change

* cross entropy reduction + unit tests

* docstrings

* l1 test name change

* old loss impl + default none

* style
2023-12-08 21:11:56 -08:00
Kai Ma
641d316484 MLE and L1 loss functions (#88)
* MLE and L1 loss functions

* logsoftmax change and tests

* subtract max logit for numerical stability

* l1 name change

* cross entropy reduction + unit tests

* docstrings

* l1 test name change

* old loss impl + default none
2023-12-08 20:21:37 -08:00
Angelos Katharopoulos
2b714714e1 Add the remainder op (#85)
* Add remainder in the C++ backend
* Add the python binding and test
2023-12-08 15:08:52 -08:00
Joe Barrow
69a24e6a1e AdamW implementation (#72)
* AdamW implementation without bias correction
* Makes use of the underlying Adam implementation
2023-12-08 14:45:34 -08:00
Zach Schillaci
5b9be57ac3 Add isort pre-commit and run (#68) 2023-12-08 11:31:47 -08:00
Jagrit Digani
e89c571de7 Update cmake to detect and throw warnings if not on a arm system (#81)
* Update cmake to detect and throw warnings if not on a arm system
2023-12-08 11:03:25 -08:00
Angelos Katharopoulos
209404239b Fix the accelerate dispatch for the power op (#70)
- The exponent and base were swapped because accelerate is using
  exponent-base instead of base-exponent
- Fix also the test for binary ops as it was testing op(x, x) which
  couldn't catch ordering errors like that
2023-12-08 10:58:03 -08:00
Awni Hannun
4e3bdb560c random generation fix (#80)
Random generation fix
2023-12-08 10:40:57 -08:00
Gautam krishna R
86b614afcd added long_description for pypi readme (#69) 2023-12-08 03:33:29 -08:00
Awni Hannun
cfc39d84b7 Some docs on unified memory (#62)
* doc on unified memory
2023-12-07 19:42:24 -08:00
Zach Schillaci
d11d77e581 Spelling fixes in transformer.py (#59) 2023-12-07 13:32:09 -08:00
Jagrit Digani
bf410cb85e Update CMake to not try and build metallib if Metal framework not found (#55) 2023-12-07 09:48:42 -08:00
rushyam
2e126aeb7e Feature Addition: Encoder-Decoder Transformer Architecture (#50)
* Implemented decoder-transformer-layer, decoder-transformer  and introduce encoder-decoder transformer

* added relu layer

* add src, tgt, memory mask

---------

Co-authored-by: rushyam <rushyam@rushyams-MacBook-Air.local>
2023-12-07 07:37:36 -08:00
Awni Hannun
dfbc52ce56 Install docs + python versions (#53)
* install + python versions

* add link in install docs

* add link
2023-12-07 07:29:17 -08:00
Angelos Katharopoulos
43e336cff2 Bump the version (#47) 2023-12-07 06:40:55 -08:00
Awni Hannun
d895e38f2e Nits (#38)
* include 3.12, black format

* circle ci badge

* format
2023-12-06 13:32:41 -08:00
Diogo
d15dead35e add extra_require with libs for running tests (#36) 2023-12-06 12:21:48 -08:00
Jagrit Digani
2440fe0124 NPY loading segfault bug (#34)
* Fixed Gil semantics in loading and saving from python file streams
2023-12-06 12:03:47 -08:00
Awni Hannun
170e4b2d43 fix links (#32) 2023-12-06 08:12:06 -08:00
Jagrit Digani
2629cc8682 Install docs update (#29)
* Add notes about MacOS version restrictions for mlx in install docs 
* Add notes about Xcode version requirements for building from source in install docs
* Let make detect the macosx sdk version being used 
* Throw error if trying to build metal kernels with macOS <= 13.4 
* Add metal-cpp for macOS 14.2
2023-12-06 08:10:51 -08:00
Ikko Eltociear Ashimine
9f4cf2e0fe Update extensions.rst (#26)
unecessary -> unnecessary
2023-12-06 07:18:28 -08:00
Markus Enzweiler
2ffaee0c0d Updated default argument for stride to 1 in Conv2d() in the docstring (#22) 2023-12-06 07:17:58 -08:00
Yingbo Ma
36b245b287 Fix benchmark example (#11) 2023-12-06 07:17:16 -08:00
Esakkivel Esakkiraja
8c96b9a890 Update README.md (#9)
- Fixed typo and other minor errors
2023-12-05 21:31:27 -08:00
Angelos Katharopoulos
07897a346d Bump the version (#8)
* Bump the version
* Change the version in the docs as well
2023-12-05 17:46:08 -08:00
Jagrit Digani
d518b3b6a5 Fix gemv broadcasting bug (#6)
* Fix broadcasting bug in gemv
* Add relevant tests in test_blas.py
2023-12-05 14:15:43 -08:00
Awni Hannun
49cda449b1 apple mlr (#7) 2023-12-05 14:10:59 -08:00
Awni Hannun
6449a8682a Doc theme (#5)
* change docs theme + links + logo

* move mlx intro to landing page
2023-12-05 12:08:05 -08:00
697 changed files with 129082 additions and 23563 deletions

View File

@@ -1,5 +1,8 @@
version: 2.1
orbs:
apple: ml-explore/pr-approval@0.1.0
parameters:
nightly_build:
type: boolean
@@ -7,8 +10,65 @@ parameters:
weekly_build:
type: boolean
default: false
test_release:
type: boolean
default: false
linux_release:
type: boolean
default: false
jobs:
build_documentation:
parameters:
upload-docs:
type: boolean
default: false
macos:
xcode: "16.2.0"
resource_class: m2pro.medium
steps:
- checkout
- run:
name: Install
command: |
brew install python@3.9
brew install doxygen
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install -r docs/requirements.txt
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
- when:
condition:
not: << parameters.upload-docs >>
steps:
- run:
name: Build documentation
command: |
source env/bin/activate
cd docs && doxygen && make html O=-W
- when:
condition: << parameters.upload-docs >>
steps:
- add_ssh_keys:
fingerprints:
- "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
- run:
name: Upload documentation
command: |
source env/bin/activate
git config user.email "mlx@group.apple.com"
git config user.name "CircleCI Docs"
git checkout gh-pages
git rebase main
cd docs
git rm -rf build/html
doxygen && make html O=-W
git add -f build/html
git commit -m "rebase"
git push -f origin gh-pages
linux_build_and_test:
docker:
- image: cimg/python:3.9
@@ -25,176 +85,303 @@ jobs:
name: Install dependencies
command: |
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install nanobind==2.4.0
pip install numpy
sudo apt-get update
sudo apt-get install libblas-dev
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
- run:
name: Build python package
name: Install Python package
command: |
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py develop
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py develop
- run:
name: Run the python tests
name: Generate package stubs
command: |
python3 -m unittest discover python/tests
echo "stubs"
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Run Python tests
command: |
python3 -m unittest discover python/tests -v
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
- run:
name: Build CPP only
command: |
mkdir -p build && cd build && cmake .. -DMLX_BUILD_METAL=OFF && make -j
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j `nproc`
- run:
name: Run CPP tests
command: ./build/tests/tests
mac_build_and_test:
machine: true
resource_class: ml-explore/m-builder
parameters:
xcode_version:
type: string
default: "16.2.0"
macosx_deployment_target:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
resource_class: m2pro.medium
steps:
- checkout
- run:
name: Install dependencies
command: |
eval "$(conda shell.bash hook)"
rm -r $CONDA_PREFIX/envs/runner-env
conda create -y -n runner-env python=3.9
conda activate runner-env
brew install python@3.9
brew install openmpi
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install nanobind==2.4.0
pip install numpy
pip install torch
pip install tensorflow
pip install unittest-xml-reporting
- run:
name: Build python package
name: Install Python package
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py build_ext --inplace
CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py develop
source env/bin/activate
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
pip install -e . -v
- run:
name: Run the python tests
name: Generate package stubs
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
DEVICE=gpu python -m xmlrunner discover -v python/tests -o test-results/gpu
source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Run Python tests
command: |
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
- run:
name: Build example extension
command: |
source env/bin/activate
cd examples/extensions
pip install -r requirements.txt
python setup.py build_ext -j8
- store_test_results:
path: test-results
- run:
name: Build CPP only
command: |
source env/bin/activate
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
- run:
name: Run CPP tests
command: |
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
- run:
name: Build small binary
command: |
source env/bin/activate
cd build/
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
make -j `sysctl -n hw.ncpu`
- run:
name: Run Python tests with JIT
command: |
source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
pip install -e . -v
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
METAL_DEBUG_ERROR_MODE=0 \
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
cuda_build_and_test:
machine:
image: linux-cuda-12:default
resource_class: gpu.nvidia.small.gen2
steps:
- checkout
- run:
name: Install Python package
command: |
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
python -m venv env
source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
pip install -e ".[dev]"
- run:
name: Run Python tests
command: |
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
build_release:
machine: true
resource_class: ml-explore/m-builder
parameters:
python_version:
type: string
default: "3.9"
macos_version:
xcode_version:
type: string
default: "14"
default: "16.2.0"
build_env:
type: string
default: ""
macosx_deployment_target:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: m2pro.medium
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
steps:
- checkout
- run:
name: Install dependencies
command: |
eval "$(conda shell.bash hook)"
rm -r $CONDA_PREFIX/envs/runner-env
conda create -y -n runner-env python=<< parameters.python_version >>
conda activate runner-env
brew install python@<< parameters.python_version >>
brew install openmpi
python<< parameters.python_version >> -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install nanobind==2.4.0
pip install --upgrade setuptools
pip install numpy
pip install twine
pip install build
- run:
name: Build pacakge
name: Install Python package
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
PYPI_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python setup.py bdist_wheel
twine upload dist/* --repository mlx
source env/bin/activate
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
pip install . -v
- run:
name: Generate package stubs
command: |
source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Build Python package
command: |
source env/bin/activate
<< parameters.build_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
python -m build -w
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload package
command: |
source env/bin/activate
twine upload dist/*
- store_artifacts:
path: dist/
build_dev_release:
machine: true
resource_class: ml-explore/m-builder
build_linux_release:
parameters:
python_version:
type: string
default: "3.9"
macos_version:
extra_env:
type: string
default: "14"
default: "DEV_RELEASE=1"
docker:
- image: ubuntu:20.04
steps:
- checkout
- run:
name: Install dependencies
name: Build wheel
command: |
eval "$(conda shell.bash hook)"
rm -r $CONDA_PREFIX/envs/runner-env
conda create -y -n runner-env python=<< parameters.python_version >>
conda activate runner-env
PYTHON=python<< parameters.python_version >>
apt-get update
apt-get upgrade -y
DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
apt-get install -y apt-utils
apt-get install -y software-properties-common
add-apt-repository -y ppa:deadsnakes/ppa
apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
apt-get install -y libblas-dev liblapack-dev liblapacke-dev
apt-get install -y build-essential git
$PYTHON -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install nanobind==2.4.0
pip install --upgrade setuptools
pip install numpy
pip install auditwheel
pip install patchelf
pip install build
pip install twine
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
pip install . -v
pip install typing_extensions
python setup.py generate_stubs
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python -m build --wheel
auditwheel show dist/*
auditwheel repair dist/* --plat manylinux_2_31_x86_64
- run:
name: Build pacakge
name: Upload package
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python setup.py bdist_wheel
twine upload dist/* --repository mlx
source env/bin/activate
twine upload wheelhouse/*
- store_artifacts:
path: dist/
build_package:
machine: true
resource_class: ml-explore/m-builder
parameters:
python_version:
type: string
default: "3.9"
macos_version:
type: string
default: "14"
steps:
- checkout
- run:
name: Install dependencies
command: |
eval "$(conda shell.bash hook)"
rm -r $CONDA_PREFIX/envs/runner-env
conda create -y -n runner-env python=<< parameters.python_version >>
conda activate runner-env
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install numpy
pip install twine
- run:
name: Build pacakge
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python setup.py bdist_wheel
- store_artifacts:
path: dist/
path: wheelhouse/
workflows:
build_and_test:
when:
and:
- matches:
pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >>
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- mac_build_and_test:
matrix:
parameters:
macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test
- mac_build_and_test
- cuda_build_and_test
- build_documentation
build_pypi_release:
when:
and:
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- build_release:
filters:
tags:
@@ -203,21 +390,238 @@ workflows:
ignore: /.*/
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11"]
macos_version: ["13", "14"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["PYPI_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- build_documentation:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
upload-docs: true
prb:
when:
matches:
pattern: "^pull/\\d+(/head)?$"
value: << pipeline.git.branch >>
jobs:
- hold:
type: approval
- apple/authenticate:
context: pr-approval
- mac_build_and_test:
requires: [ hold ]
matrix:
parameters:
macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test:
requires: [ hold ]
- cuda_build_and_test:
requires: [ hold ]
nightly_build:
when: << pipeline.parameters.nightly_build >>
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.nightly_build >>
jobs:
- build_package:
- build_release:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11"]
macos_version: ["13", "14"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
weekly_build:
when: << pipeline.parameters.weekly_build >>
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.weekly_build >>
jobs:
- build_dev_release:
- build_release:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11"]
macos_version: ["13", "14"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
linux_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.linux_release >>
jobs:
- build_linux_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
extra_env: ["PYPI_RELEASE=1"]

28
.github/ISSUE_TEMPLATE/bug_report.md vendored Normal file
View File

@@ -0,0 +1,28 @@
---
name: Bug report
about: Create a report about an issue you've encountered
title: "[BUG] "
labels: ''
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**To Reproduce**
Include code snippet
```python
```
**Expected behavior**
A clear and concise description of what you expected to happen.
**Desktop (please complete the following information):**
- OS Version: [e.g. MacOS 14.1.2]
- Version [e.g. 0.7.0]
**Additional context**
Add any other context about the problem here.

12
.github/pull_request_template.md vendored Normal file
View File

@@ -0,0 +1,12 @@
## Proposed changes
Please include a description of the problem or feature this PR is addressing. If there is a corresponding issue, include the issue #.
## Checklist
Put an `x` in the boxes that apply.
- [ ] I have read the [CONTRIBUTING](https://github.com/ml-explore/mlx/blob/main/CONTRIBUTING.md) document
- [ ] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have updated the necessary documentation (if needed)

20
.github/workflows/pull_request.yml vendored Normal file
View File

@@ -0,0 +1,20 @@
on:
pull_request:
branches:
- main
jobs:
check_lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.8
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pre-commit black isort clang-format
- name: Run lint
run: |
pre-commit run --all-files

13
.gitignore vendored
View File

@@ -6,10 +6,16 @@ __pycache__/
# C extensions
*.so
# tensor files
*.safe
*.safetensors
# Metal libraries
*.metallib
venv/
# Distribution / packaging
python/mlx/core
python/mlx/share
python/mlx/include
.Python
@@ -30,6 +36,7 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST
uv.lock
# vim
*.swp
@@ -70,6 +77,12 @@ build/
*.out
*.app
# Debug symbols
*.pdb
# VSCode
.vscode/
.DS_Store
# Jetbrains
.cache

View File

@@ -1,9 +1,21 @@
repos:
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v14.0.6
rev: v19.1.7
hooks:
- id: clang-format
- repo: https://github.com/psf/black
rev: 22.10.0
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 25.1.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 6.0.0
hooks:
- id: isort
args:
- --profile=black
- repo: https://github.com/cheshirekow/cmake-format-precommit
rev: v0.6.13
hooks:
- id: cmake-format

View File

@@ -1,3 +1,31 @@
# Individual Contributors
If you wish to be acknowledged for your contributions, please list your name
with a short description of your contribution(s) below. For example:
- Jane Smith: Added the `foo` and `bar` ops.
MLX was developed with contributions from the following individuals:
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`. Added `orthogonal` initializer.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
</a>
# Third-Party Software
MLX leverages several third-party software, listed here together with
their license copied verbatim.

24
CITATION.cff Normal file
View File

@@ -0,0 +1,24 @@
cff-version: 1.2.0
title: mlx
message: >-
If you use this software, please cite it using the
metadata from this file.
type: software
authors:
- given-names: Awni
family-names: Hannun
affiliation: Apple
- given-names: Jagrit
family-names: Digani
affiliation: Apple
- given-names: Angelos
family-names: Katharopoulos
affiliation: Apple
- given-names: Ronan
family-names: Collobert
affiliation: Apple
repository-code: 'https://github.com/ml-explore'
abstract: >-
MLX: efficient and flexible machine learning on Apple
silicon
license: MIT

View File

@@ -1,6 +1,24 @@
cmake_minimum_required(VERSION 3.24)
cmake_minimum_required(VERSION 3.25)
project(mlx LANGUAGES CXX)
if(NOT MLX_VERSION)
file(STRINGS "mlx/version.h" _mlx_h_version REGEX "^#define MLX_VERSION_.*$")
string(REGEX MATCH "#define MLX_VERSION_MAJOR ([0-9]+)" _ "${_mlx_h_version}")
set(_major ${CMAKE_MATCH_1})
string(REGEX MATCH "#define MLX_VERSION_MINOR ([0-9]+)" _ "${_mlx_h_version}")
set(_minor ${CMAKE_MATCH_1})
string(REGEX MATCH "#define MLX_VERSION_PATCH ([0-9]+)" _ "${_mlx_h_version}")
set(_patch ${CMAKE_MATCH_1})
set(MLX_PROJECT_VERSION "${_major}.${_minor}.${_patch}")
set(MLX_VERSION ${MLX_PROJECT_VERSION})
else()
string(REGEX REPLACE "^([0-9]+\.[0-9]+\.[0-9]+).*" "\\1" MLX_PROJECT_VERSION
${MLX_VERSION})
endif()
project(
mlx
LANGUAGES C CXX
VERSION ${MLX_PROJECT_VERSION})
# ----------------------------- Setup -----------------------------
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
@@ -15,10 +33,41 @@ option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.0.1)
# --------------------- Processor tests -------------------------
message(
STATUS
"Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
)
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
if(NOT MLX_ENABLE_X64_MAC)
message(
FATAL_ERROR
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, check the build"
" documentation for possible fixes: "
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
)
else()
set(MLX_BUILD_METAL OFF)
message(WARNING "Building for x86_64 arch is not officially supported.")
endif()
endif()
else()
set(MLX_BUILD_METAL OFF)
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
endif()
# ----------------------------- Lib -----------------------------
@@ -29,99 +78,194 @@ cmake_policy(SET CMP0135 NEW)
add_library(mlx)
if (MLX_BUILD_METAL)
find_library(METAL_LIB Metal)
find_library(FOUNDATION_LIB Foundation)
find_library(QUARTZ_LIB QuartzCore)
if(MLX_BUILD_METAL)
set(METAL_LIB "-framework Metal")
set(FOUNDATION_LIB "-framework Foundation")
set(QUARTZ_LIB "-framework QuartzCore")
endif()
if (MLX_BUILD_METAL AND NOT METAL_LIB)
if(MLX_BUILD_CUDA)
enable_language(CUDA)
endif()
if(MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU")
elseif (MLX_BUILD_METAL)
set(MLX_BUILD_METAL OFF)
set(MLX_METAL_DEBUG OFF)
elseif(MLX_BUILD_METAL)
message(STATUS "Building METAL sources")
add_compile_definitions(_METAL_)
execute_process(COMMAND zsh "-c" "/usr/bin/sw_vers | cut -f2- -d: | sed -n 2p | grep -Eo '[0-9]+.[0-9]+'"
OUTPUT_VARIABLE MACOS_VERSION)
message(STATUS "Detected macOS version ${MACOS_VERSION}")
if (${MACOS_VERSION} GREATER_EQUAL 14.0)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
elseif (${MACOS_VERSION} GREATER_EQUAL 13.3)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13.3_iOS16.4.zip)
else()
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13_iOS16.zip)
if(MLX_METAL_DEBUG)
add_compile_definitions(MLX_METAL_DEBUG)
endif()
FetchContent_Declare(
metal_cpp
URL ${METAL_CPP_URL}
)
# Throw an error if xcrun not found
execute_process(
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
if(${MACOS_SDK_VERSION} LESS 14.0)
message(
FATAL_ERROR
"MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
endif()
message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
set(METAL_CPP_URL
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18.zip)
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
execute_process(
COMMAND
zsh "-c"
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
FetchContent_MakeAvailable(metal_cpp)
target_include_directories(
mlx PUBLIC
$<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
$<INSTALL_INTERFACE:include/metal_cpp>
)
target_link_libraries(
mlx
${METAL_LIB}
${FOUNDATION_LIB}
${QUARTZ_LIB})
mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
$<INSTALL_INTERFACE:include/metal_cpp>)
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
endif()
find_library(ACCELERATE_LIBRARY Accelerate)
if (ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
target_link_libraries(mlx ${ACCELERATE_LIBRARY})
add_compile_definitions(ACCELERATE_NEW_LAPACK)
else()
message(STATUS "Accelerate not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
#set(BLA_VENDOR Generic)
find_package(BLAS REQUIRED)
if (NOT BLAS_FOUND)
message(FATAL_ERROR "Must have BLAS installed")
if(WIN32)
if(MSVC)
# GGUF does not build with MSVC.
set(MLX_BUILD_GGUF OFF)
# There is no prebuilt OpenBLAS distribution for MSVC.
set(MLX_BUILD_BLAS_FROM_SOURCE ON)
endif()
# TODO find a cleaner way to do this
find_path(BLAS_INCLUDE_DIRS cblas.h
/usr/include
/usr/local/include
$ENV{BLAS_HOME}/include)
message(STATUS ${BLAS_LIBRARIES})
message(STATUS ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx ${BLAS_LIBRARIES})
# Windows implementation of dlfcn.h APIs.
FetchContent_Declare(
dlfcn-win32
GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
GIT_TAG v1.4.1
EXCLUDE_FROM_ALL)
block()
set(BUILD_SHARED_LIBS OFF)
FetchContent_MakeAvailable(dlfcn-win32)
endblock()
target_include_directories(mlx PRIVATE "${dlfcn-win32_SOURCE_DIR}/src")
target_link_libraries(mlx PRIVATE dl)
endif()
if(MLX_BUILD_CPU)
find_library(ACCELERATE_LIBRARY Accelerate)
if(ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
else()
message(STATUS "Accelerate or arm neon not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
endif()
if(MLX_BUILD_ACCELERATE)
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
add_compile_definitions(MLX_USE_ACCELERATE)
add_compile_definitions(ACCELERATE_NEW_LAPACK)
elseif(MLX_BUILD_BLAS_FROM_SOURCE)
# Download and build OpenBLAS from source code.
FetchContent_Declare(
openblas
GIT_REPOSITORY https://github.com/OpenMathLib/OpenBLAS.git
GIT_TAG v0.3.28
EXCLUDE_FROM_ALL)
set(BUILD_STATIC_LIBS ON) # link statically
set(NOFORTRAN ON) # msvc has no fortran compiler
FetchContent_MakeAvailable(openblas)
target_link_libraries(mlx PRIVATE openblas)
target_include_directories(
mlx PRIVATE "${openblas_SOURCE_DIR}/lapack-netlib/LAPACKE/include"
"${CMAKE_BINARY_DIR}/generated" "${CMAKE_BINARY_DIR}")
else()
if(${CMAKE_HOST_APPLE})
# The blas shipped in macOS SDK is not supported, search homebrew for
# openblas instead.
set(BLA_VENDOR OpenBLAS)
set(LAPACK_ROOT
"${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
endif()
# Search and link with lapack.
find_package(LAPACK REQUIRED)
if(NOT LAPACK_FOUND)
message(FATAL_ERROR "Must have LAPACK installed")
endif()
find_path(LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include
/usr/local/opt/openblas/include)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES})
# List blas after lapack otherwise we may accidentally incldue an old
# version of lapack.h from the include dirs of blas.
find_package(BLAS REQUIRED)
if(NOT BLAS_FOUND)
message(FATAL_ERROR "Must have BLAS installed")
endif()
# TODO find a cleaner way to do this
find_path(BLAS_INCLUDE_DIRS cblas.h /usr/include /usr/local/include
$ENV{BLAS_HOME}/include)
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx PRIVATE ${BLAS_LIBRARIES})
endif()
else()
set(MLX_BUILD_ACCELERATE OFF)
endif()
message(STATUS "Downloading json")
FetchContent_Declare(
json
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
FetchContent_MakeAvailable(json)
target_include_directories(
mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
target_include_directories(
mlx
PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>
)
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>)
if (MLX_BUILD_PYTHON_BINDINGS)
# Do not add mlx_EXPORTS define for shared library.
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt)
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
if(MLX_BUILD_PYTHON_BINDINGS)
message(STATUS "Building Python bindings.")
find_package(Python COMPONENTS Interpreter Development)
find_package(pybind11 CONFIG REQUIRED)
find_package(
Python 3.8
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE nanobind_ROOT)
find_package(nanobind CONFIG REQUIRED)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
endif()
if (MLX_BUILD_TESTS)
if(MLX_BUILD_TESTS)
include(CTest)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests)
endif()
if (MLX_BUILD_EXAMPLES)
if(MLX_BUILD_EXAMPLES)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp)
endif()
if (MLX_BUILD_BENCHMARKS)
if(MLX_BUILD_BENCHMARKS)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
endif()
@@ -130,32 +274,31 @@ include(GNUInstallDirs)
# Install library
install(
TARGETS mlx
EXPORT MLXTargets
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
)
TARGETS mlx
EXPORT MLXTargets
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
INCLUDES
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
# Install headers
install(
DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
COMPONENT headers
FILES_MATCHING PATTERN "*.h"
)
DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
COMPONENT headers
FILES_MATCHING
PATTERN "*.h"
PATTERN "backend/metal/kernels.h" EXCLUDE)
# Install metal dependencies
if (MLX_BUILD_METAL)
if(MLX_BUILD_METAL)
# Install metal cpp
install(
DIRECTORY ${metal_cpp_SOURCE_DIR}/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
COMPONENT metal_cpp_source
)
DIRECTORY ${metal_cpp_SOURCE_DIR}/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
COMPONENT metal_cpp_source)
endif()
@@ -167,31 +310,24 @@ set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX)
install(
EXPORT MLXTargets
FILE MLXTargets.cmake
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
)
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
include(CMakePackageConfigHelpers)
write_basic_package_version_file(
${MLX_CMAKE_BUILD_VERSION_CONFIG}
COMPATIBILITY SameMajorVersion
VERSION ${MLX_VERSION}
)
VERSION ${MLX_VERSION})
configure_package_config_file(
${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in
${MLX_CMAKE_BUILD_CONFIG}
${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in ${MLX_CMAKE_BUILD_CONFIG}
INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
NO_CHECK_REQUIRED_COMPONENTS_MACRO
PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR MLX_CMAKE_INSTALL_MODULE_DIR
)
PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR
MLX_CMAKE_INSTALL_MODULE_DIR)
install(
FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
)
install(FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
install(
DIRECTORY ${CMAKE_MODULE_PATH}/
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
)
install(DIRECTORY ${CMAKE_MODULE_PATH}/
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})

View File

@@ -5,26 +5,26 @@ possible.
## Pull Requests
1. Fork and submit pull requests to the repo.
1. Fork and submit pull requests to the repo.
2. If you've added code that should be tested, add tests.
3. If a change is likely to impact efficiency, run some of the benchmarks before
and after the change. Examples of benchmarks can be found in `benchmarks/python/`.
4. If you've changed APIs, update the documentation.
5. Every PR should have passing tests and at least one review.
5. Every PR should have passing tests and at least one review.
6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
This should install hooks for running `black` and `clang-format` to ensure
consistent style for C++ and python code.
You can also run the formatters manually as follows:
```
clang-format -i file.cpp
```
```
black file.py
```
```shell
clang-format -i file.cpp
```
```shell
black file.py
```
or run `pre-commit run --all-files` to check all files in the repo.
## Issues

View File

@@ -1,3 +1,6 @@
include CMakeLists.txt
include mlx.pc.in
recursive-include mlx/ *
include cmake/*
include python/src/*
include python/mlx/py.typed # support type hinting as in PEP-561

View File

@@ -2,38 +2,43 @@
[**Quickstart**](#quickstart) | [**Installation**](#installation) |
[**Documentation**](https://ml-explore.github.io/mlx/build/html/index.html) |
[**Examples**](#examples)
[**Examples**](#examples)
MLX is an array framework for machine learning on Apple silicon.
[![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx)
MLX is an array framework for machine learning on Apple silicon,
brought to you by Apple machine learning research.
Some key features of MLX include:
- **Familiar APIs**: MLX has a Python API which closely follows NumPy.
MLX also has a fully featured C++ API which closely mirrors the Python API.
MLX has higher level packages like `mlx.nn` and `mlx.optimizers` with APIs
that closely follow PyTorch to simplify building more complex models.
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
the Python API. MLX has higher-level packages like `mlx.nn` and
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building
more complex models.
- **Composable function transformations**: MLX has composable function
- **Composable function transformations**: MLX supports composable function
transformations for automatic differentiation, automatic vectorization,
and computation graph optimization.
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
materialized when needed.
- **Dynamic graph construction**: Computation graphs in MLX are built
- **Dynamic graph construction**: Computation graphs in MLX are constructed
dynamically. Changing the shapes of function arguments does not trigger
slow compilations, and debugging is simple and intuitive.
- **Multi-device**: Operations can run on any of the supported devices
(currently the CPU and GPU).
(currently the CPU and the GPU).
- **Unified memory**: A noteable difference from MLX and other frameworks
- **Unified memory**: A notable difference from MLX and other frameworks
is the *unified memory model*. Arrays in MLX live in shared memory.
Operations on MLX arrays can be performed on any of the supported
device types without moving data.
device types without transferring data.
MLX is designed by machine learning researchers for machine learning
researchers. The framework is intended to be user friendly, but still efficient
researchers. The framework is intended to be user-friendly, but still efficient
to train and deploy models. The design of the framework itself is also
conceptually simple. We intend to make it easy for researchers to extend and
improve MLX with the goal of quickly exploring new ideas.
@@ -46,11 +51,11 @@ The design of MLX is inspired by frameworks like
## Examples
The [MLX examples repo](https://github.com/ml-explore/mlx-examples) has a
variety of examples including:
variety of examples, including:
- [Transformer language model](https://github.com/ml-explore/mlx-examples/tree/main/transformer_lm) training.
- Large scale text generation with
[LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llama) and
- Large-scale text generation with
[LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llms/llama) and
finetuning with [LoRA](https://github.com/ml-explore/mlx-examples/tree/main/lora).
- Generating images with [Stable Diffusion](https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion).
- Speech recognition with [OpenAI's Whisper](https://github.com/ml-explore/mlx-examples/tree/main/whisper).
@@ -58,22 +63,54 @@ variety of examples including:
## Quickstart
See the [quick start
guide](https://ml-explore.github.io/mlx/build/html/quick_start.html)
guide](https://ml-explore.github.io/mlx/build/html/usage/quick_start.html)
in the documentation.
## Installation
MLX is available on [PyPi](https://pypi.org/project/mlx/). To install the Python API run:
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
**With `pip`**:
```
pip install mlx
```
**With `conda`**:
```
conda install -c conda-forge mlx
```
Checkout the
[documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
for more information on building the C++ and Python APIs from source.
## Contributing
Check out the [contribution guidelines](CONTRIBUTING.md) for more information
on contributing to MLX.
Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
on contributing to MLX. See the
[docs](https://ml-explore.github.io/mlx/build/html/install.html) for more
information on building from source, and running tests.
We are grateful for all of [our
contributors](https://github.com/ml-explore/mlx/tree/main/ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
to MLX and wish to be acknowledged, please add your name to the list in your
pull request.
## Citing MLX
The MLX software suite was initially developed with equal contribution by Awni
Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
MLX useful in your research and wish to cite it, please use the following
BibTex entry:
```
@software{mlx2023,
author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
url = {https://github.com/ml-explore},
version = {0.0},
year = {2023},
}
```

View File

@@ -5,35 +5,35 @@
#include "mlx/mlx.h"
#include "time_utils.h"
using namespace mlx::core;
namespace mx = mlx::core;
void time_value_and_grad() {
auto x = ones({200, 1000});
eval(x);
auto fn = [](array x) {
auto x = mx::ones({200, 1000});
mx::eval(x);
auto fn = [](mx::array x) {
for (int i = 0; i < 20; ++i) {
x = log(exp(x));
x = mx::log(mx::exp(x));
}
return sum(x);
return mx::sum(x);
};
auto grad_fn = grad(fn);
auto grad_fn = mx::grad(fn);
auto independent_value_and_grad = [&]() {
auto value = fn(x);
auto dfdx = grad_fn(x);
return std::vector<array>{value, dfdx};
return std::vector<mx::array>{value, dfdx};
};
TIME(independent_value_and_grad);
auto value_and_grad_fn = value_and_grad(fn);
auto value_and_grad_fn = mx::value_and_grad(fn);
auto combined_value_and_grad = [&]() {
auto [value, dfdx] = value_and_grad_fn(x);
return std::vector<array>{value, dfdx};
return std::vector<mx::array>{value, dfdx};
};
TIME(combined_value_and_grad);
}
int main() {
std::cout << "Benchmarks for " << default_device() << std::endl;
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
time_value_and_grad();
}

View File

@@ -4,21 +4,21 @@
#include "mlx/mlx.h"
#include "time_utils.h"
using namespace mlx::core;
namespace mx = mlx::core;
void time_add_op() {
std::vector<int> sizes(1, 1);
for (int i = 0; i < 9; ++i) {
sizes.push_back(10 * sizes.back());
}
set_default_device(Device::cpu);
set_default_device(mx::Device::cpu);
for (auto size : sizes) {
auto a = random::uniform({size});
auto b = random::uniform({size});
eval(a, b);
auto a = mx::random::uniform({size});
auto b = mx::random::uniform({size});
mx::eval(a, b);
std::cout << "Size " << size << std::endl;
TIMEM("cpu", add, a, b, Device::cpu);
TIMEM("gpu", add, a, b, Device::gpu);
TIMEM("cpu", mx::add, a, b, mx::Device::cpu);
TIMEM("gpu", mx::add, a, b, mx::Device::gpu);
}
}

View File

@@ -1,110 +1,111 @@
// Copyright © 2023 Apple Inc.
#include <cstring>
#include <iostream>
#include <sstream>
#include "mlx/mlx.h"
#include "time_utils.h"
using namespace mlx::core;
namespace mx = mlx::core;
void time_irregular_binary_ops_1D() {
auto device = default_device();
auto device = mx::default_device();
int size = 1000000;
int step = 2;
auto a = random::uniform({size});
auto b = random::uniform({size});
eval(a, b);
auto a = mx::random::uniform({size});
auto b = mx::random::uniform({size});
mx::eval(a, b);
a = slice(a, {0}, {size}, {step});
b = slice(b, {0}, {size}, {step});
TIMEM("1D strided", add, a, b, device);
TIMEM("1D strided", mx::add, a, b, device);
}
void time_irregular_binary_ops_2D() {
auto device = default_device();
auto device = mx::default_device();
int size = 2048;
auto a = random::uniform({size, size});
auto b = random::uniform({size, size});
eval(a, b);
TIMEM("2D regular", add, a, b, device);
auto a = mx::random::uniform({size, size});
auto b = mx::random::uniform({size, size});
mx::eval(a, b);
TIMEM("2D regular", mx::add, a, b, device);
b = transpose(b);
eval(b);
TIMEM("2D transpose", add, a, b, device);
b = mx::transpose(b);
mx::eval(b);
TIMEM("2D mx::transpose", mx::add, a, b, device);
b = random::uniform({size});
eval(b);
TIMEM("2D broadcast dim 0", add, a, b, device);
b = mx::random::uniform({size});
mx::eval(b);
TIMEM("2D broadcast dim 0", mx::add, a, b, device);
b = reshape(b, {size, 1});
eval(b);
TIMEM("2D broadcast dim 1", add, a, b, device);
b = mx::reshape(b, {size, 1});
mx::eval(b);
TIMEM("2D broadcast dim 1", mx::add, a, b, device);
}
void time_irregular_binary_ops_3D() {
auto device = default_device();
auto device = mx::default_device();
int d0 = 32;
int d1 = 512;
int d2 = 512;
auto a = random::uniform({d0, d1, d2});
auto b = random::uniform({d0, d1, d2});
TIMEM("3D regular", add, a, b, device);
auto a = mx::random::uniform({d0, d1, d2});
auto b = mx::random::uniform({d0, d1, d2});
TIMEM("3D regular", mx::add, a, b, device);
b = transpose(b, {0, 2, 1});
TIMEM("3D transpose", add, a, b, device);
b = mx::transpose(b, {0, 2, 1});
TIMEM("3D mx::transpose", mx::add, a, b, device);
b = random::uniform({d1, d2});
TIMEM("3D broadcast dim 0", add, a, b, device);
b = mx::random::uniform({d1, d2});
TIMEM("3D broadcast dim 0", mx::add, a, b, device);
b = random::uniform({d0, 1, d2});
TIMEM("3D broadcast dim 1", add, a, b, device);
b = mx::random::uniform({d0, 1, d2});
TIMEM("3D broadcast dim 1", mx::add, a, b, device);
b = random::uniform({d0, d1, 1});
TIMEM("3D broadcast dim 2", add, a, b, device);
b = mx::random::uniform({d0, d1, 1});
TIMEM("3D broadcast dim 2", mx::add, a, b, device);
b = random::uniform({d2});
TIMEM("3D broadcast dims 0, 1", add, a, b, device);
b = mx::random::uniform({d2});
TIMEM("3D broadcast dims 0, 1", mx::add, a, b, device);
b = random::uniform({d1, 1});
TIMEM("3D broadcast dims 0, 2", add, a, b, device);
b = mx::random::uniform({d1, 1});
TIMEM("3D broadcast dims 0, 2", mx::add, a, b, device);
b = random::uniform({d0, 1, 1});
TIMEM("3D broadcast dims 1, 2", add, a, b, device);
b = mx::random::uniform({d0, 1, 1});
TIMEM("3D broadcast dims 1, 2", mx::add, a, b, device);
}
void time_irregular_binary_ops_4D() {
auto device = default_device();
auto device = mx::default_device();
std::vector<int> shape = {8, 8, 512, 512};
auto a = random::uniform(shape);
auto b = random::uniform(shape);
auto a = mx::random::uniform(shape);
auto b = mx::random::uniform(shape);
TIMEM("4D regular", add, a, b, device);
TIMEM("4D regular", mx::add, a, b, device);
b = transpose(b, {0, 1, 3, 2});
TIMEM("4D transpose", add, a, b, device);
b = mx::transpose(b, {0, 1, 3, 2});
TIMEM("4D mx::transpose", mx::add, a, b, device);
std::string om = "4D broadcast dims ";
for (int i = 0; i < shape.size(); ++i) {
shape[i] = 1;
b = random::uniform(shape);
b = mx::random::uniform(shape);
std::ostringstream msg;
msg << om << i;
TIMEM(msg.str(), add, a, b, device);
TIMEM(msg.str(), mx::add, a, b, device);
for (int j = i + 1; j < shape.size(); ++j) {
shape[j] = 1;
std::ostringstream msg;
msg << om << i << ", " << j;
b = random::uniform(shape);
TIMEM(msg.str(), add, a, b, device);
b = mx::random::uniform(shape);
TIMEM(msg.str(), mx::add, a, b, device);
shape[j] = a.shape(j);
for (int k = j + 1; k < shape.size(); ++k) {
shape[k] = 1;
std::ostringstream msg;
msg << om << i << ", " << j << ", " << k;
b = random::uniform(shape);
TIMEM(msg.str(), add, a, b, device);
b = mx::random::uniform(shape);
TIMEM(msg.str(), mx::add, a, b, device);
shape[k] = a.shape(k);
}
}
@@ -113,83 +114,83 @@ void time_irregular_binary_ops_4D() {
}
void time_irregular_reshape() {
auto device = default_device();
auto device = mx::default_device();
std::vector<int> shape;
auto reshape_fn = [&shape, device](const array& a) {
return reshape(a, shape, device);
auto reshape_fn = [&shape, device](const mx::array& a) {
return mx::reshape(a, shape, device);
};
int size = 64;
int d = 2 * size;
auto a = random::uniform({d, d, d});
auto a = mx::random::uniform({d, d, d});
shape = {8 * size, size, size};
TIMEM("3D contiguous", reshape_fn, a);
a = transpose(a);
a = mx::transpose(a);
shape = {8 * size, size, size};
TIMEM("3D transpose", reshape_fn, a);
TIMEM("3D mx::transpose", reshape_fn, a);
a = transpose(a, {1, 2, 0});
a = mx::transpose(a, {1, 2, 0});
shape = {8 * size, size, size};
TIMEM("3D transpose dims 1 2", reshape_fn, a);
TIMEM("3D mx::transpose dims 1 2", reshape_fn, a);
a = broadcast_to(random::uniform({d, d}), {d, d, d});
a = mx::broadcast_to(mx::random::uniform({d, d}), {d, d, d});
TIMEM("3D broadcast dim 0", reshape_fn, a);
a = broadcast_to(random::uniform({d, 1, d}), {d, d, d});
a = mx::broadcast_to(mx::random::uniform({d, 1, d}), {d, d, d});
TIMEM("3D broadcast dim 1", reshape_fn, a);
a = broadcast_to(random::uniform({d, d, 1}), {d, d, d});
a = mx::broadcast_to(mx::random::uniform({d, d, 1}), {d, d, d});
TIMEM("3D broadcast dim 2", reshape_fn, a);
a = broadcast_to(random::uniform({d}), {d, d, d});
a = mx::broadcast_to(mx::random::uniform({d}), {d, d, d});
TIMEM("3D broadcast dims 0, 1", reshape_fn, a);
a = broadcast_to(random::uniform({d, 1}), {d, d, d});
a = mx::broadcast_to(mx::random::uniform({d, 1}), {d, d, d});
TIMEM("3D broadcast dims 0, 2", reshape_fn, a);
a = broadcast_to(random::uniform({d, 1, 1}), {d, d, d});
a = mx::broadcast_to(mx::random::uniform({d, 1, 1}), {d, d, d});
TIMEM("3D broadcast dims 1, 2", reshape_fn, a);
a = broadcast_to(random::uniform({1, 1, 1}), {d, d, d});
a = mx::broadcast_to(mx::random::uniform({1, 1, 1}), {d, d, d});
TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
}
void time_irregular_astype_1D() {
auto device = default_device();
auto device = mx::default_device();
int size = 1000000;
int step = 2;
auto a = random::uniform({size});
auto a = mx::random::uniform({size});
a = slice(a, {0}, {size}, {step});
TIMEM("1D strided", astype, a, int32, device);
TIMEM("1D strided", mx::astype, a, mx::int32, device);
}
void time_irregular_astype_2D() {
auto device = default_device();
auto device = mx::default_device();
int size = 2048;
std::vector<int> shape = {size, size};
auto a = random::uniform(shape);
TIMEM("2D regular", astype, a, int32, device);
auto a = mx::random::uniform(shape);
TIMEM("2D regular", mx::astype, a, mx::int32, device);
a = transpose(a);
TIMEM("2D transpose", astype, a, int32, device);
a = mx::transpose(a);
TIMEM("2D mx::transpose", mx::astype, a, mx::int32, device);
a = broadcast_to(random::uniform({size}), shape);
TIMEM("2D broadcast dim 0", astype, a, int32, device);
a = mx::broadcast_to(mx::random::uniform({size}), shape);
TIMEM("2D broadcast dim 0", mx::astype, a, mx::int32, device);
a = broadcast_to(random::uniform({size, 1}), shape);
TIMEM("2D broadcast dim 1", astype, a, int32, device);
a = mx::broadcast_to(mx::random::uniform({size, 1}), shape);
TIMEM("2D broadcast dim 1", mx::astype, a, mx::int32, device);
}
int main(int argc, char** argv) {
if (argc > 1) {
bool use_gpu = !strcmp(argv[1], "gpu");
set_default_device(use_gpu ? Device::gpu : Device::cpu);
set_default_device(use_gpu ? mx::Device::gpu : mx::Device::cpu);
}
std::cout << "Benchmarks for " << default_device() << std::endl;
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
time_irregular_binary_ops_1D();
time_irregular_binary_ops_2D();
time_irregular_binary_ops_3D();

View File

@@ -3,20 +3,20 @@
#include "mlx/mlx.h"
#include "time_utils.h"
using namespace mlx::core;
namespace mx = mlx::core;
void time_creation_ops() {
int M = 2000;
int N = 500;
auto shape = {M, N};
auto full_fp32 = [&]() { return full(shape, 3.3f); };
auto full_fp32 = [&]() { return mx::full(shape, 3.3f); };
TIME(full_fp32);
auto zeros_fp32 = [&]() { return zeros(shape, float32); };
auto zeros_fp32 = [&]() { return mx::zeros(shape, mx::float32); };
TIME(zeros_fp32);
auto ones_fp32 = [&]() { return ones(shape, float32); };
auto ones_fp32 = [&]() { return mx::ones(shape, mx::float32); };
TIME(ones_fp32);
auto arange_fp32 = [&]() { return arange(0.0, 10.0, 1e-4); };
auto arange_fp32 = [&]() { return mx::arange(0.0, 10.0, 1e-4); };
TIME(arange_fp32);
}
@@ -24,188 +24,196 @@ void time_type_conversions() {
int M = 2000;
int N = 500;
auto shape = {M, N};
auto device = default_device();
auto device = mx::default_device();
auto a = zeros(shape, float32);
eval(a);
TIMEM("float32 to int32", astype, a, int32, device);
TIMEM("float32 to uint32", astype, a, uint32, device);
auto a = mx::zeros(shape, mx::float32);
mx::eval(a);
TIMEM("mx::float32 to mx::int32", mx::astype, a, mx::int32, device);
TIMEM("mx::float32 to mx::uint32", mx::astype, a, mx::uint32, device);
a = zeros(shape, int32);
eval(a);
TIMEM("int32 to float32", astype, a, float32, device);
a = mx::zeros(shape, mx::int32);
mx::eval(a);
TIMEM("mx::int32 to mx::float32", mx::astype, a, mx::float32, device);
a = zeros(shape, bool_);
eval(a);
TIMEM("bool to float32", astype, a, float32, device);
TIMEM("bool to int32", astype, a, int32, device);
TIMEM("bool to uint32", astype, a, uint32, device);
a = mx::zeros(shape, mx::bool_);
mx::eval(a);
TIMEM("bool to mx::float32", mx::astype, a, mx::float32, device);
TIMEM("bool to mx::int32", mx::astype, a, mx::int32, device);
TIMEM("bool to mx::uint32", mx::astype, a, mx::uint32, device);
}
void time_random_generation() {
int M = 2000;
int N = 500;
auto uniform = [&]() { return random::uniform({M, N}, float32); };
auto uniform = [&]() { return mx::random::uniform({M, N}, mx::float32); };
TIME(uniform);
auto normal = [&]() { return random::normal({M, N}, float32); };
auto normal = [&]() { return mx::random::normal({M, N}, mx::float32); };
TIME(normal);
}
void time_unary_ops() {
int M = 2000;
int N = 500;
auto device = default_device();
auto device = mx::default_device();
auto a = random::normal({M, N});
eval(a);
auto a = mx::random::normal({M, N});
mx::eval(a);
TIME(mlx::core::abs, a, device);
TIME(negative, a, device);
TIME(sign, a, device);
TIME(square, a, device);
TIME(mx::negative, a, device);
TIME(mx::sign, a, device);
TIME(mx::square, a, device);
TIME(mlx::core::sqrt, a, device);
TIME(rsqrt, a, device);
TIME(mx::rsqrt, a, device);
TIME(mlx::core::exp, a, device);
a = random::uniform({M, N});
a = mx::random::uniform({M, N});
TIME(mlx::core::log, a, device);
}
void time_binary_ops() {
int M = 1000, N = 100, K = 10;
auto a = random::uniform({M, N, K});
auto b = random::uniform({M, N, K});
auto device = default_device();
eval(a, b);
auto condition = mx::random::randint(0, 2, {M, N, K});
auto a = mx::random::uniform({M, N, K});
auto b = mx::random::uniform({M, N, K});
auto device = mx::default_device();
mx::eval(a, b);
TIME(add, a, b, device);
TIME(subtract, a, b, device);
TIME(multiply, a, b, device);
TIME(divide, a, b, device);
TIME(maximum, a, b, device);
TIME(minimum, a, b, device);
TIME(mx::add, a, b, device);
TIME(mx::subtract, a, b, device);
TIME(mx::multiply, a, b, device);
TIME(mx::divide, a, b, device);
TIME(mx::maximum, a, b, device);
TIME(mx::minimum, a, b, device);
TIME(mx::where, condition, a, b, device);
b = random::uniform({1});
eval(b);
TIMEM("scalar", add, a, b, device);
TIMEM("vector-scalar", subtract, a, b, device);
TIMEM("scalar-vector", subtract, b, a, device);
TIMEM("scalar", multiply, a, b, device);
TIMEM("vector-scalar", divide, a, b, device);
TIMEM("scalar-vector", divide, b, a, device);
condition = mx::array({true});
b = mx::random::uniform({1});
mx::eval(b);
TIMEM("scalar", mx::add, a, b, device);
TIMEM("vector-scalar", mx::subtract, a, b, device);
TIMEM("scalar-vector", mx::subtract, b, a, device);
TIMEM("scalar", mx::multiply, a, b, device);
TIMEM("vector-scalar", mx::divide, a, b, device);
TIMEM("scalar-vector", mx::divide, b, a, device);
TIMEM("scalar-vector", mx::where, condition, a, b, device);
a = broadcast_to(random::uniform({1}), {1000, 100});
b = broadcast_to(random::uniform({1}), {1000, 100});
eval(a, b);
TIMEM("scalar-scalar broadcast", add, a, b, device);
TIMEM("scalar-scalar broadcast", subtract, a, b, device);
TIMEM("scalar-scalar broadcast", multiply, a, b, device);
TIMEM("scalar-scalar broadcast", divide, a, b, device);
condition = mx::broadcast_to(mx::array({true}), {1000, 100});
a = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
b = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
mx::eval(a, b);
TIMEM("scalar-scalar broadcast", mx::add, a, b, device);
TIMEM("scalar-scalar broadcast", mx::subtract, a, b, device);
TIMEM("scalar-scalar broadcast", mx::multiply, a, b, device);
TIMEM("scalar-scalar broadcast", mx::divide, a, b, device);
TIMEM("scalar-scalar broadcast", mx::where, condition, a, b, device);
}
void time_strided_ops() {
int M = 50, N = 50, O = 50, P = 50;
auto a = random::uniform({M, N, O, P});
auto b = random::uniform({M, N, O, P});
auto device = default_device();
eval(a, b);
TIMEM("non-strided", add, a, b, device);
a = transpose(a, {1, 0, 2, 3});
b = transpose(b, {3, 2, 0, 1});
eval(a, b);
TIMEM("strided", add, a, b, device);
auto a = mx::random::uniform({M, N, O, P});
auto b = mx::random::uniform({M, N, O, P});
auto device = mx::default_device();
mx::eval(a, b);
TIMEM("non-strided", mx::add, a, b, device);
a = mx::transpose(a, {1, 0, 2, 3});
b = mx::transpose(b, {3, 2, 0, 1});
mx::eval(a, b);
TIMEM("strided", mx::add, a, b, device);
}
void time_comparisons() {
int M = 1000, N = 100, K = 10;
auto a = random::uniform({M, N, K});
auto b = random::uniform({M, N, K});
auto device = default_device();
eval(a, b);
TIME(equal, a, b, device);
TIME(greater, a, b, device);
TIME(greater_equal, a, b, device);
TIME(less, a, b, device);
TIME(less_equal, a, b, device);
auto a = mx::random::uniform({M, N, K});
auto b = mx::random::uniform({M, N, K});
auto device = mx::default_device();
mx::eval(a, b);
TIME(mx::equal, a, b, device);
TIME(mx::greater, a, b, device);
TIME(mx::greater_equal, a, b, device);
TIME(mx::less, a, b, device);
TIME(mx::less_equal, a, b, device);
}
void time_matvec() {
int M = 2000, N = 200;
auto a = random::uniform({M, N});
auto b = random::uniform({N});
auto c = random::uniform({M});
eval(a, b, c);
auto matvec = [&]() { return matmul(a, b); };
auto a = mx::random::uniform({M, N});
auto b = mx::random::uniform({N});
auto c = mx::random::uniform({M});
mx::eval(a, b, c);
auto matvec = [&]() { return mx::matmul(a, b); };
TIME(matvec);
auto matvec_transpose = [&]() { return matmul(transpose(a), c); };
auto matvec_transpose = [&]() { return mx::matmul(mx::transpose(a), c); };
TIME(matvec_transpose);
}
void time_matmul() {
int M = 1000, N = 1000, K = 1000;
auto a = random::uniform({M, K});
auto b = random::uniform({K, N});
auto device = default_device();
eval(a, b);
TIME(matmul, a, b, device);
auto a = mx::random::uniform({M, K});
auto b = mx::random::uniform({K, N});
auto device = mx::default_device();
mx::eval(a, b);
TIME(mx::matmul, a, b, device);
auto transpose_matmul = [&]() { return matmul(transpose(a), b); };
auto transpose_matmul = [&]() { return mx::matmul(mx::transpose(a), b); };
TIME(transpose_matmul);
}
void time_reductions() {
auto a = random::normal({10000, 1000});
eval(a);
auto sum_all = [&a]() { return sum(a, false); };
auto a = mx::random::normal({10000, 1000});
mx::eval(a);
auto sum_all = [&a]() { return mx::sum(a, false); };
TIME(sum_all);
auto sum_along_0 = [&a]() { return sum(a, 0, false); };
auto sum_along_0 = [&a]() { return mx::sum(a, 0, false); };
TIME(sum_along_0);
auto sum_along_1 = [&a]() { return sum(a, 1, false); };
auto sum_along_1 = [&a]() { return mx::sum(a, 1, false); };
TIME(sum_along_1);
auto prod_all = [&a]() { return prod(a, false); };
auto prod_all = [&a]() { return mx::prod(a, false); };
TIME(prod_all);
auto all_true = [&a]() { return all(a, false); };
auto all_true = [&a]() { return mx::all(a, false); };
TIME(all_true);
auto all_along_0 = [&a]() { return all(a, 0, false); };
auto all_along_0 = [&a]() { return mx::all(a, 0, false); };
TIME(all_along_0);
auto all_along_1 = [&a]() { return all(a, 1, false); };
auto all_along_1 = [&a]() { return mx::all(a, 1, false); };
TIME(all_along_1);
auto any_true = [&a]() { return any(a, false); };
auto any_true = [&a]() { return mx::any(a, false); };
TIME(any_true);
auto argmin_along_0 = [&a]() { return argmin(a, 0, false); };
auto argmin_along_0 = [&a]() { return mx::argmin(a, 0, false); };
TIME(argmin_along_0);
auto argmin_along_1 = [&a]() { return argmin(a, 1, false); };
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
TIME(argmin_along_1);
}
void time_gather_scatter() {
auto a = random::normal({1000, 768});
eval(a);
auto indices = random::randint(0, 1000, {256});
eval(indices);
auto a = mx::random::normal({1000, 768});
mx::eval(a);
auto indices = mx::random::randint(0, 1000, {256});
mx::eval(indices);
auto embedding_lookup = [&a, &indices]() { return take(a, indices, 0); };
auto embedding_lookup = [&a, &indices]() { return mx::take(a, indices, 0); };
TIME(embedding_lookup);
indices = random::randint(0, 768 * 1000, {256 * 768});
eval(indices);
indices = mx::random::randint(0, 768 * 1000, {256 * 768});
mx::eval(indices);
auto single_element_lookup = [&a, &indices]() { return take(a, indices); };
auto single_element_lookup = [&a, &indices]() {
return mx::take(a, indices);
};
TIME(single_element_lookup);
indices = random::randint(0, 1000, {256});
auto updates = random::normal({256, 1, 768});
eval(indices, updates);
indices = mx::random::randint(0, 1000, {256});
auto updates = mx::random::normal({256, 1, 768});
mx::eval(indices, updates);
auto embedding_update = [&a, &indices, &updates]() {
return scatter(a, indices, updates, 0);
@@ -217,10 +225,10 @@ void time_gather_scatter() {
};
TIME(embedding_add);
a = reshape(a, {-1});
indices = random::randint(0, 768 * 1000, {768 * 256});
updates = random::normal({256 * 768, 1});
eval(a, indices, updates);
a = mx::reshape(a, {-1});
indices = mx::random::randint(0, 768 * 1000, {768 * 256});
updates = mx::random::normal({256 * 768, 1});
mx::eval(a, indices, updates);
auto single_element_update = [&a, &indices, &updates]() {
return scatter(a, indices, updates, 0);
@@ -233,8 +241,22 @@ void time_gather_scatter() {
TIME(single_element_add);
}
void time_divmod() {
auto a = mx::random::normal({1000});
auto b = mx::random::normal({1000});
mx::eval({a, b});
auto divmod_fused = [&a, &b]() { return mx::divmod(a, b); };
TIME(divmod_fused);
auto divmod_separate = [&a, &b]() {
return std::vector<mx::array>{mx::floor_divide(a, b), mx::remainder(a, b)};
};
TIME(divmod_separate);
}
int main() {
std::cout << "Benchmarks for " << default_device() << std::endl;
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
time_creation_ops();
time_type_conversions();
time_unary_ops();
@@ -246,4 +268,5 @@ int main() {
time_matmul();
time_reductions();
time_gather_scatter();
time_divmod();
}

View File

@@ -17,14 +17,13 @@
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
<< std::endl;
#define TIMEM(MSG, FUNC, ...) \
std::cout << "Timing " \
<< "(" << MSG << ") " << #FUNC << " ... " << std::flush \
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
<< std::endl;
#define TIMEM(MSG, FUNC, ...) \
std::cout << "Timing " << "(" << MSG << ") " << #FUNC << " ... " \
<< std::flush << std::setprecision(5) \
<< time_fn(FUNC, ##__VA_ARGS__) << " msec" << std::endl;
template <typename F, typename... Args>
double time_fn(F fn, Args... args) {
double time_fn(F fn, Args&&... args) {
// warmup
for (int i = 0; i < 5; ++i) {
eval(fn(std::forward<Args>(args)...));

View File

@@ -1,7 +1,6 @@
# Copyright © 2023 Apple Inc.
import numpy as np
from time_utils import time_fn

View File

@@ -1,8 +1,8 @@
# Copyright © 2023 Apple Inc.
import argparse
import mlx.core as mx
import mlx.core as mx
from time_utils import time_fn
B = 8
@@ -30,7 +30,7 @@ def time_batch_matmul():
time_fn(batch_vjp_second)
def time_unbatch_matmul(key):
def time_unbatch_matmul():
mx.random.seed(3)
a = mx.random.uniform(shape=(B * T, D))
b = mx.random.uniform(shape=(D, D))

View File

@@ -1,13 +1,14 @@
# Copyright © 2023 Apple Inc.
import numpy as np
import argparse
import mlx.core as mx
import time
import torch
import os
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
@@ -165,13 +166,13 @@ if __name__ == "__main__":
dtypes = ("float32", "float16")
transposes = ("nn", "nt", "tn")
shapes = (
(16, 234, 768, 3072),
(1, 64, 64, 25344),
(16, 1024, 1024, 1024),
(1, 1024, 1024, 2048),
(4, 1024, 1024, 4096),
(4, 1024, 4096, 1024),
(1, 4096, 4096, 4096),
(15, 1023, 1023, 1023),
(17, 1025, 1025, 1025),
)
for dtype in dtypes:

View File

@@ -1,14 +1,14 @@
# Copyright © 2023 Apple Inc.
import matplotlib.pyplot as plt
import numpy as np
import argparse
import mlx.core as mx
import time
import torch
import os
import subprocess
import time
import matplotlib.pyplot as plt
import mlx.core as mx
import numpy as np
import torch
results_dir = "./results"
@@ -133,7 +133,7 @@ def get_gbyte_size(in_vec_len, out_vec_len, np_dtype):
return float(N_iter_bench * N_iter_func * n_elem * item_size) / float(1024**3)
def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, tranpose):
def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, transpose):
np_dtype = getattr(np, dtype)
mlx_gb_s = []
mlx_gflops = []
@@ -164,7 +164,7 @@ def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, tranpose):
ax.legend()
def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, tranpose):
def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
np_dtype = getattr(np, dtype)
mlx_gb_s = []
mlx_gflops = []

View File

@@ -4,8 +4,10 @@ import argparse
import math
import os
import time
from functools import partial
import mlx.core as mx
import mlx.nn as nn
def int_or_list(x):
@@ -22,6 +24,16 @@ def none_or_list(x):
return [int(xi) for xi in x.split(",")]
def dtype_from_str(x):
if x == "":
return mx.float32
else:
dt = getattr(mx, x)
if not isinstance(dt, mx.Dtype):
raise ValueError(f"{x} is not an mlx dtype")
return dt
def bench(f, *args):
for i in range(10):
f(*args)
@@ -48,6 +60,63 @@ def matmul(x, y):
mx.eval(ys)
def _quant_matmul(x, w, s, b, transpose, group_size, bits):
ys = []
for i in range(10):
ys.append(
mx.quantized_matmul(
x, w, s, b, transpose=transpose, group_size=group_size, bits=bits
)
)
mx.eval(ys)
quant_matmul = {
"quant_matmul_32_2": partial(_quant_matmul, transpose=False, group_size=32, bits=2),
"quant_matmul_32_4": partial(_quant_matmul, transpose=False, group_size=32, bits=4),
"quant_matmul_32_8": partial(_quant_matmul, transpose=False, group_size=32, bits=8),
"quant_matmul_64_2": partial(_quant_matmul, transpose=False, group_size=64, bits=2),
"quant_matmul_64_4": partial(_quant_matmul, transpose=False, group_size=64, bits=4),
"quant_matmul_64_8": partial(_quant_matmul, transpose=False, group_size=64, bits=8),
"quant_matmul_128_2": partial(
_quant_matmul, transpose=False, group_size=128, bits=2
),
"quant_matmul_128_4": partial(
_quant_matmul, transpose=False, group_size=128, bits=4
),
"quant_matmul_128_8": partial(
_quant_matmul, transpose=False, group_size=128, bits=8
),
"quant_matmul_t_32_2": partial(
_quant_matmul, transpose=True, group_size=32, bits=2
),
"quant_matmul_t_32_4": partial(
_quant_matmul, transpose=True, group_size=32, bits=4
),
"quant_matmul_t_32_8": partial(
_quant_matmul, transpose=True, group_size=32, bits=8
),
"quant_matmul_t_64_2": partial(
_quant_matmul, transpose=True, group_size=64, bits=2
),
"quant_matmul_t_64_4": partial(
_quant_matmul, transpose=True, group_size=64, bits=4
),
"quant_matmul_t_64_8": partial(
_quant_matmul, transpose=True, group_size=64, bits=8
),
"quant_matmul_t_128_2": partial(
_quant_matmul, transpose=True, group_size=128, bits=2
),
"quant_matmul_t_128_4": partial(
_quant_matmul, transpose=True, group_size=128, bits=4
),
"quant_matmul_t_128_8": partial(
_quant_matmul, transpose=True, group_size=128, bits=8
),
}
def conv1d(x, y):
ys = []
for i in range(10):
@@ -75,6 +144,13 @@ def reduction(op, axis, x):
mx.eval(ys)
def sum_and_add(axis, x, y):
z = x.sum(axis=axis, keepdims=True)
for i in range(50):
z = (z + y).sum(axis=axis, keepdims=True)
mx.eval(z)
def softmax(axis, x):
ys = []
for i in range(100):
@@ -95,7 +171,77 @@ def softmax_fused(axis, x):
def relu(x):
y = x
for i in range(100):
y = mx.maximum(y, 0)
y = nn.relu(y)
mx.eval(y)
def leaky_relu(x: mx.array):
y = x
for i in range(100):
y = nn.leaky_relu(y)
mx.eval(y)
def prelu(x: mx.array):
y = x
for i in range(100):
y = nn.prelu(y, mx.ones(1))
mx.eval(y)
def softplus(x: mx.array):
y = x
for i in range(100):
y = nn.softplus(y)
mx.eval(y)
def mish(x: mx.array):
y = x
for i in range(100):
y = nn.mish(y)
mx.eval(y)
def leaky_relu(x):
y = x
for i in range(100):
y = nn.leaky_relu(y)
mx.eval(y)
def elu(x):
y = x
for i in range(100):
y = nn.elu(y)
mx.eval(y)
def relu6(x):
y = x
for i in range(100):
y = nn.relu6(y)
mx.eval(y)
def softplus(x):
y = x
for i in range(100):
y = nn.softplus(y)
mx.eval(y)
def celu(x):
y = x
for i in range(100):
y = nn.celu(y)
mx.eval(y)
def log_sigmoid(x):
y = x
for i in range(100):
y = nn.log_sigmoid(y)
mx.eval(y)
@@ -130,6 +276,13 @@ def linear(w, b, x):
mx.eval(ys)
def linear_fused(w, b, x):
ys = []
for i in range(10):
ys.append(mx.addmm(b, x, mx.transpose(w, (1, 0))))
mx.eval(ys)
def rope(x):
*_, N, D = x.shape
ys = []
@@ -180,6 +333,20 @@ def topk(axis, x):
mx.eval(ys)
def step_function(x):
y = x
for i in range(100):
y = nn.step(x)
mx.eval(y)
def selu(x):
y = x
for i in range(100):
y = nn.selu(x)
mx.eval(y)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("benchmark", help="Choose the benchmark to run")
@@ -211,9 +378,7 @@ if __name__ == "__main__":
parser.add_argument(
"--fused", action="store_true", help="Use fused functions where possible"
)
parser.add_argument(
"--dtype", choices=["float32", "float16", "bfloat16"], default="float32"
)
parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append")
args = parser.parse_args()
@@ -222,19 +387,19 @@ if __name__ == "__main__":
if len(args.axis) > 1:
args.axis.pop(0)
if args.print_pid:
print(os.getpid())
input("Press enter to run")
if args.cpu:
mx.set_default_device(mx.cpu)
else:
mx.set_default_device(mx.gpu)
dtype = dict(float32=mx.float32, float16=mx.float16, bfloat16=mx.bfloat16)[
args.dtype
]
types = args.dtype
if not types:
types = [mx.float32]
if len(types) < len(args.size):
types = types + [types[0]] * (len(args.size) - len(types))
xs = []
for size in args.size:
for size, dtype in zip(args.size, types):
xs.append(mx.random.normal(size).astype(dtype))
for i, t in enumerate(args.transpose):
if t is None:
@@ -244,14 +409,24 @@ if __name__ == "__main__":
x = xs[0]
axis = args.axis[0]
if args.print_pid:
print(os.getpid())
input("Press enter to run")
if args.benchmark == "matmul_square":
print(bench(matmul_square, x))
elif args.benchmark == "matmul":
print(bench(matmul, *xs))
elif args.benchmark.startswith("quant_matmul"):
print(bench(quant_matmul[args.benchmark], *xs))
elif args.benchmark == "linear":
print(bench(linear, *xs))
if args.fused:
print(bench(linear_fused, *xs))
else:
print(bench(linear, *xs))
elif args.benchmark == "sum_axis":
print(bench(reduction, "sum", axis, x))
@@ -277,6 +452,26 @@ if __name__ == "__main__":
elif args.benchmark == "relu":
print(bench(relu, x))
elif args.benchmark == "elu":
print(bench(elu, x))
elif args.benchmark == "relu6":
print(bench(relu6, x))
elif args.benchmark == "celu":
print(bench(celu, x))
elif args.benchmark == "log_sigmoid":
print(bench(log_sigmoid, x))
elif args.benchmark == "leaky_relu":
print(bench(leaky_relu, x))
elif args.benchmark == "prelu":
print(bench(prelu, x))
elif args.benchmark == "softplus":
print(bench(softplus, x))
elif args.benchmark == "mish":
print(bench(mish, x))
elif args.benchmark == "scalar_mul":
print(bench(scalar_mult, x))
@@ -311,5 +506,14 @@ if __name__ == "__main__":
elif args.benchmark == "topk":
print(bench(topk, axis, x))
elif args.benchmark == "step":
print(bench(step_function, x))
elif args.benchmark == "selu":
print(bench(selu, x))
elif args.benchmark == "sum_and_add":
print(bench(sum_and_add, axis, *xs))
else:
raise ValueError("Unknown benchmark")

View File

@@ -22,6 +22,16 @@ def none_or_list(x):
return [int(xi) for xi in x.split(",")]
def dtype_from_str(x):
if x == "":
return torch.float32
else:
dt = getattr(torch, x)
if not isinstance(dt, torch.dtype):
raise ValueError(f"{x} is not a torch dtype")
return dt
def bench(f, *args):
for i in range(10):
f(*args)
@@ -115,6 +125,70 @@ def relu(x):
sync_if_needed(x)
@torch.no_grad()
def leaky_relu(x):
y = x
for i in range(100):
y = torch.nn.functional.leaky_relu(y)
sync_if_needed(x)
@torch.no_grad()
def elu(x):
y = x
for i in range(100):
y = torch.nn.functional.elu(y)
sync_if_needed(x)
@torch.no_grad()
def celu(x):
y = x
for i in range(100):
y = torch.nn.functional.celu(y)
sync_if_needed(x)
@torch.no_grad()
def relu6(x):
y = x
for i in range(100):
y = torch.nn.functional.relu6(y)
sync_if_needed(x)
@torch.no_grad()
def softplus(x):
y = x
for i in range(100):
y = torch.nn.functional.softplus(y)
sync_if_needed(x)
@torch.no_grad()
def log_sigmoid(x):
y = x
for i in range(100):
y = torch.nn.functional.logsigmoid(y)
sync_if_needed(x)
@torch.no_grad()
def prelu(x: torch.Tensor) -> torch.Tensor:
y = x
for _ in range(100):
y = torch.nn.functional.prelu(y, torch.ones(1).to(y.device))
sync_if_needed(x)
@torch.no_grad()
def mish(x: torch.Tensor) -> torch.Tensor:
y = x
for _ in range(100):
y = torch.nn.functional.mish(y)
sync_if_needed(x)
@torch.no_grad()
def scalar_mult(x):
y = x
@@ -209,6 +283,22 @@ def topk(axis, x):
sync_if_needed(x)
@torch.no_grad()
def step_function(x):
y = x
for i in range(100):
y = torch.where(y < 0, 0, 1)
sync_if_needed(x)
@torch.no_grad()
def selu(x):
y = x
for i in range(100):
y = torch.nn.functional.selu(y)
sync_if_needed(x)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("benchmark", help="Choose the benchmark to run")
@@ -240,7 +330,7 @@ if __name__ == "__main__":
parser.add_argument(
"--fused", action="store_true", help="Use fused functions where possible"
)
parser.add_argument("--dtype", choices=["float32", "float16"], default="float32")
parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append")
args = parser.parse_args()
@@ -249,15 +339,17 @@ if __name__ == "__main__":
if len(args.axis) > 1:
args.axis.pop(0)
if args.print_pid:
print(os.getpid())
input("Press enter to run")
torch.set_num_threads(1)
device = "cpu" if args.cpu else "mps"
dtype = dict(float32=torch.float32, float16=torch.float16)[args.dtype]
types = args.dtype
if not types:
types = [torch.float32]
if len(types) < len(args.size):
types = types + [types[0]] * (len(args.size) - len(types))
xs = []
for size in args.size:
for size, dtype in zip(args.size, types):
xs.append(torch.randn(*size).to(device).to(dtype))
for i, t in enumerate(args.transpose):
if t is None:
@@ -266,6 +358,10 @@ if __name__ == "__main__":
x = xs[0]
axis = args.axis[0]
if args.print_pid:
print(os.getpid())
input("Press enter to run")
if args.benchmark == "matmul_square":
print(bench(matmul_square, x))
@@ -302,6 +398,28 @@ if __name__ == "__main__":
elif args.benchmark == "relu":
print(bench(relu, x))
elif args.benchmark == "leaky_relu":
print(bench(leaky_relu, x))
elif args.benchmark == "elu":
print(bench(elu, x))
elif args.benchmark == "relu6":
print(bench(relu6, x))
elif args.benchmark == "softplus":
print(bench(softplus, x))
elif args.benchmark == "celu":
print(bench(celu, x))
elif args.benchmark == "log_sigmoid":
print(bench(log_sigmoid, x))
elif args.benchmark == "prelu":
print(bench(prelu, x))
elif args.benchmark == "mish":
print(bench(mish, x))
elif args.benchmark == "scalar_mul":
print(bench(scalar_mult, x))
@@ -336,5 +454,11 @@ if __name__ == "__main__":
elif args.benchmark == "topk":
print(bench(topk, axis, x))
elif args.benchmark == "step":
print(bench(step_function, x))
elif args.benchmark == "selu":
print(bench(selu, x))
else:
raise ValueError("Unknown benchmark")
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")

View File

@@ -16,7 +16,9 @@ def run_or_raise(*args, **kwargs):
result = run(*args, capture_output=True, **kwargs)
return float(result.stdout)
except ValueError:
raise ValueError(f"stdout: {result.stdout}\nstderr: {result.stderr}")
raise ValueError(
f"stdout: {result.stdout.decode()}\nstderr: {result.stderr.decode()}"
)
def compare(args):
@@ -62,7 +64,7 @@ def make_predicate(positive_filter, negative_filter):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run comparisons agains PyTorch")
parser = argparse.ArgumentParser(description="Run comparisons against PyTorch")
parser.add_argument(
"--filter", "-f", help="Regex filter to select benchmarks", nargs="+"
)
@@ -80,10 +82,8 @@ if __name__ == "__main__":
_filter = make_predicate(args.filter, args.negative_filter)
if args.mlx_dtypes:
compare_filtered = (
lambda x: compare_mlx_dtypes(
x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1]
)
compare_filtered = lambda x: (
compare_mlx_dtypes(x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1])
if _filter(x)
else None
)
@@ -125,6 +125,14 @@ if __name__ == "__main__":
compare_filtered("sum_axis --size 16x128x1024 --axis 1")
compare_filtered("sum_axis --size 16x128x1024 --axis 0 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 0")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1")
compare_filtered("argmax --size 10x1024x128 --axis 1 --cpu")
compare_filtered("argmax --size 10x1024x128 --axis 1")
compare_filtered("argmax --size 10x1024x128 --axis 2 --cpu")
@@ -193,6 +201,27 @@ if __name__ == "__main__":
compare_filtered("softmax --size 2x1024x1024 --axis 1 --fused --cpu")
compare_filtered("relu --size 32x16x1024")
compare_filtered("relu --size 32x16x1024 --cpu")
compare_filtered("leaky_relu --size 32x16x1024")
compare_filtered("leaky_relu --size 32x16x1024 --cpu")
compare_filtered("elu --size 32x16x1024")
compare_filtered("elu --size 32x16x1024 --cpu")
compare_filtered("relu6 --size 32x16x1024")
compare_filtered("relu6 --size 32x16x1024 --cpu")
compare_filtered("softplus --size 32x16x1024")
compare_filtered("softplus --size 32x16x1024 --cpu")
compare_filtered("celu --size 32x16x1024")
compare_filtered("celu --size 32x16x1024 --cpu")
compare_filtered("log_sigmoid --size 32x16x1024")
compare_filtered("log_sigmoid --size 32x16x1024 --cpu")
compare_filtered("step --size 32x16x1024")
compare_filtered("step --size 32x16x1024 --cpu")
compare_filtered("selu --size 32x16x1024")
compare_filtered("selu --size 32x16x1024 --cpu")
# compare_filtered("mish --size 32x16x1024") NOTE: Torch does not implement Mish in MPS atm
compare_filtered("mish --size 32x16x1024 --cpu")
compare_filtered("prelu --size 32x16x1024")
compare_filtered("prelu --size 32x16x1024 --cpu")
compare_filtered("scalar_mul --size 32x16x1024")
compare_filtered("scalar_mul --size 32x16x1024 --cpu")
compare_filtered("cross_entropy --size 256x1024")

View File

@@ -0,0 +1,107 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import math
import random
import mlx.core as mx
from time_utils import time_fn
def bench_gelu():
def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
x = mx.random.uniform(shape=(1000, 1024))
def gen_fun(fun):
def bench_fun(x):
for _ in range(10):
x = fun(x)
return x
return bench_fun
time_fn(gen_fun(gelu), x, msg="fixed gelu")
time_fn(gen_fun(mx.compile(gelu)), x, msg="compiled fixed gelu")
def randint():
return random.randint(1, x.shape[0])
def gen_fun(fun):
def bench_fun(x, y):
x = x[: randint()]
for _ in range(10):
x = fun(x)
y = fun(y)
return x, y
return bench_fun
y = mx.random.uniform(shape=(1000, 1024))
time_fn(gen_fun(gelu), x, y, msg="variable gelu")
time_fn(gen_fun(mx.compile(gelu)), x, y, msg="compiled variable gelu")
time_fn(
gen_fun(mx.compile(gelu, shapeless=True)),
x,
y,
msg="shapeless variable gelu",
)
def bench_layernorm():
weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
mx.eval(weight, bias)
def layernorm(x):
x = x.astype(mx.float32)
means = mx.mean(x, axis=-1, keepdims=True)
var = mx.var(x, axis=-1, keepdims=True)
x = (x - means) * mx.rsqrt(var + 1e-4)
x = x.astype(mx.float16)
return weight * x + bias
x = mx.random.uniform(shape=(1000, 4096)).astype(mx.float16)
def gen_fun(fun):
def bench_fun(x):
for _ in range(10):
x = fun(x)
return x
return bench_fun
time_fn(gen_fun(layernorm), x, msg="fixed layernorm")
time_fn(gen_fun(mx.compile(layernorm)), x, msg="compiled fixed layernorm")
def randint():
return random.randint(1, x.shape[0])
def gen_fun(fun):
def bench_fun(x):
x = x[: randint()]
for _ in range(10):
x = fun(x)
return x
return bench_fun
random.seed(0)
time_fn(gen_fun(layernorm), x, msg="variable layernorm")
random.seed(0)
time_fn(gen_fun(mx.compile(layernorm)), x, msg="compiled variable layernorm")
random.seed(0)
time_fn(
gen_fun(mx.compile(layernorm, shapeless=True)),
x,
msg="shapeless variable layernorm",
)
if __name__ == "__main__":
parser = argparse.ArgumentParser("Compile benchmarks.")
args = parser.parse_args()
bench_gelu()
bench_layernorm()

View File

@@ -0,0 +1,123 @@
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_1D(strides=1, padding=0, groups=1):
def mx_conv_1D(a, b):
ys = []
for _ in range(N_iter_func):
y = mx.conv1d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_1D
def make_pt_conv_1D(strides=1, padding=0, groups=1):
@torch.no_grad()
def pt_conv_1D(a, b):
ys = []
for _ in range(N_iter_func):
y = torch.conv1d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_1D
def bench_shape(N, iH, C, wH, O, strides, padding, np_dtype, groups):
scale = 1.0 / math.sqrt(wH * C)
a_np = np.random.uniform(0, 0.5, (N, iH, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, wH, int(C / groups))).astype(np_dtype)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 2, 1))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((0, 2, 1))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_1D(strides, padding, groups)
f_pt = make_pt_conv_1D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv1d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv1d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, iH, C)}, {(O, wH, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 5, 32, 1, 2, 1),
(4, 32, 32, 5, 32, 1, 2, 2),
(4, 32, 32, 5, 32, 1, 2, 4),
(4, 32, 32, 5, 32, 1, 2, 8),
(4, 32, 32, 5, 32, 1, 2, 8),
(4, 32, 32, 5, 32, 1, 2, 16),
(4, 32, 32, 5, 32, 1, 2, 32),
(4, 32, 256, 5, 512, 1, 2, 2),
(4, 32, 256, 5, 512, 1, 2, 128),
(4, 32, 256, 5, 512, 1, 2, 256),
)
for dtype in dtypes:
print("(N, iH, C), (O, wH, C), dtype, stride, pads, groups, diff%")
for N, iH, C, wH, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, iH, C, wH, O, strides, padding, np_dtype, groups
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {iH:3d}, {C:3d}), ({O:3d}, {wH:2d}, {C:3d}), {dtype}, {strides:5d}, {padding:4d}, {groups:6d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -0,0 +1,127 @@
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
return ys
return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("cpu")
f_mx = make_mx_conv_2D(strides, padding, groups)
f_pt = make_pt_conv_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
# (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
# (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
# (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
)
for dtype in dtypes:
print(
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -0,0 +1,143 @@
import time
import mlx.core as mx
import mlx.nn
import mlx.optimizers as opt
import torch
def bench_mlx(steps: int = 20) -> float:
mx.set_default_device(mx.cpu)
class BenchNetMLX(mlx.nn.Module):
# simple encoder-decoder net
def __init__(self, in_channels, hidden_channels=32):
super().__init__()
self.net = mlx.nn.Sequential(
mlx.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
mlx.nn.ReLU(),
mlx.nn.Conv2d(
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
),
mlx.nn.ReLU(),
mlx.nn.ConvTranspose2d(
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
),
mlx.nn.ReLU(),
mlx.nn.ConvTranspose2d(
hidden_channels, in_channels, kernel_size=3, padding=1
),
)
def __call__(self, input):
return self.net(input)
benchNet = BenchNetMLX(3)
mx.eval(benchNet.parameters())
optim = opt.Adam(learning_rate=1e-3)
inputs = mx.random.normal([10, 256, 256, 3])
params = benchNet.parameters()
optim.init(params)
state = [benchNet.state, optim.state]
def loss_fn(params, image):
benchNet.update(params)
pred_image = benchNet(image)
return (pred_image - image).abs().mean()
def step(params, image):
loss, grads = mx.value_and_grad(loss_fn)(params, image)
optim.update(benchNet, grads)
return loss
total_time = 0.0
print("MLX:")
for i in range(steps):
start_time = time.perf_counter()
step(benchNet.parameters(), inputs)
mx.eval(state)
end_time = time.perf_counter()
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
total_time += (end_time - start_time) * 1000
return total_time
def bench_torch(steps: int = 20) -> float:
device = torch.device("cpu")
class BenchNetTorch(torch.nn.Module):
# simple encoder-decoder net
def __init__(self, in_channels, hidden_channels=32):
super().__init__()
self.net = torch.nn.Sequential(
torch.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
),
torch.nn.ReLU(),
torch.nn.ConvTranspose2d(
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
),
torch.nn.ReLU(),
torch.nn.ConvTranspose2d(
hidden_channels, in_channels, kernel_size=3, padding=1
),
)
def forward(self, input):
return self.net(input)
benchNet = BenchNetTorch(3).to(device)
optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)
inputs = torch.randn(10, 3, 256, 256, device=device)
def loss_fn(pred_image, image):
return (pred_image - image).abs().mean()
total_time = 0.0
print("PyTorch:")
for i in range(steps):
start_time = time.perf_counter()
optim.zero_grad()
pred_image = benchNet(inputs)
loss = loss_fn(pred_image, inputs)
loss.backward()
optim.step()
end_time = time.perf_counter()
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
total_time += (end_time - start_time) * 1000
return total_time
def main():
steps = 20
time_mlx = bench_mlx(steps)
time_torch = bench_torch(steps)
print(f"average time of MLX: {time_mlx/steps:9.2f} ms")
print(f"total time of MLX: {time_mlx:9.2f} ms")
print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
print(f"total time of PyTorch: {time_torch:9.2f} ms")
diff = time_torch / time_mlx - 1.0
print(f"torch/mlx diff: {100. * diff:+5.2f}%")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,129 @@
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_transpose_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv_transpose2d(
a, b, stride=strides, padding=padding, groups=groups, stream=mx.cpu
)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_transpose_2D
def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_transpose_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv_transpose2d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
return ys
return pt_conv_transpose_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (int(O / groups), kH, kW, C)).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("cpu")
f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
f_pt = make_pt_conv_transpose_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv_transpose2d(
a_mx, b_mx, stride=strides, padding=padding, groups=groups, stream=mx.cpu
)
out_pt = torch.conv_transpose2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
)
for dtype in dtypes:
print(
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -0,0 +1,110 @@
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_3D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv3d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_3D
def make_pt_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_3D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv3d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
return ys
return pt_conv_3D
def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kD * kH * kW * C)
a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
b_pt = torch.from_numpy(b_np.transpose((0, 4, 1, 2, 3))).to("cpu")
f_mx = make_mx_conv_3D(strides, padding, groups)
f_pt = make_pt_conv_3D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv3d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv3d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
(4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
)
for dtype in dtypes:
print(
"(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -0,0 +1,143 @@
import time
import mlx.core as mx
import mlx.nn
import mlx.optimizers as opt
import torch
def bench_mlx(steps: int = 20, shape=(10, 32, 32, 32, 3)) -> float:
mx.set_default_device(mx.cpu)
class BenchNetMLX(mlx.nn.Module):
# simple encoder-decoder net
def __init__(self, in_channels, hidden_channels=16):
super().__init__()
self.net = mlx.nn.Sequential(
mlx.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
mlx.nn.ReLU(),
mlx.nn.Conv3d(
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
),
mlx.nn.ReLU(),
mlx.nn.ConvTranspose3d(
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
),
mlx.nn.ReLU(),
mlx.nn.ConvTranspose3d(
hidden_channels, in_channels, kernel_size=3, padding=1
),
)
def __call__(self, input):
return self.net(input)
benchNet = BenchNetMLX(3)
mx.eval(benchNet.parameters())
optim = opt.Adam(learning_rate=1e-3)
inputs = mx.random.normal(shape)
params = benchNet.parameters()
optim.init(params)
state = [benchNet.state, optim.state]
def loss_fn(params, image):
benchNet.update(params)
pred_image = benchNet(image)
return (pred_image - image).abs().mean()
def step(params, image):
loss, grads = mx.value_and_grad(loss_fn)(params, image)
optim.update(benchNet, grads)
return loss
total_time = 0.0
print("MLX:")
for i in range(steps):
start_time = time.perf_counter()
step(benchNet.parameters(), inputs)
mx.eval(state)
end_time = time.perf_counter()
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
total_time += (end_time - start_time) * 1000
return total_time
def bench_torch(steps: int = 20, shape=(10, 3, 32, 32, 32)) -> float:
device = torch.device("cpu")
class BenchNetTorch(torch.nn.Module):
# simple encoder-decoder net
def __init__(self, in_channels, hidden_channels=16):
super().__init__()
self.net = torch.nn.Sequential(
torch.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
torch.nn.ReLU(),
torch.nn.Conv3d(
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
),
torch.nn.ReLU(),
torch.nn.ConvTranspose3d(
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
),
torch.nn.ReLU(),
torch.nn.ConvTranspose3d(
hidden_channels, in_channels, kernel_size=3, padding=1
),
)
def forward(self, input):
return self.net(input)
benchNet = BenchNetTorch(3).to(device)
optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)
inputs = torch.randn(*shape, device=device)
def loss_fn(pred_image, image):
return (pred_image - image).abs().mean()
total_time = 0.0
print("PyTorch:")
for i in range(steps):
start_time = time.perf_counter()
optim.zero_grad()
pred_image = benchNet(inputs)
loss = loss_fn(pred_image, inputs)
loss.backward()
optim.step()
end_time = time.perf_counter()
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
total_time += (end_time - start_time) * 1000
return total_time
def main():
steps = 10
time_mlx = bench_mlx(steps)
time_torch = bench_torch(steps)
print(f"average time of MLX: {time_mlx/steps:9.2f} ms")
print(f"total time of MLX: {time_mlx:9.2f} ms")
print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
print(f"total time of PyTorch: {time_torch:9.2f} ms")
diff = time_torch / time_mlx - 1.0
print(f"torch/mlx diff: {100. * diff:+5.2f}%")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,116 @@
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
def mx_conv_3D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv_transpose3d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_3D
def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
@torch.no_grad()
def pt_conv_3D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv_transpose3d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
return ys
return pt_conv_3D
def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kD * kH * kW * C)
a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
b_pt = torch.from_numpy(b_np.transpose((4, 0, 1, 2, 3))).to("cpu")
f_mx = make_mx_conv_3D(strides, padding, groups)
f_pt = make_pt_conv_3D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv_transpose3d(
a_mx, b_mx, stride=strides, padding=padding, groups=groups
)
out_pt = torch.conv_transpose3d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
(4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
)
for dtype in dtypes:
print(
"(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -0,0 +1,135 @@
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_2D(strides, padding, groups)
f_pt = make_pt_conv_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
)
for dtype in dtypes:
print(
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -0,0 +1,135 @@
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_transpose_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv_transpose2d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_transpose_2D
def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_transpose_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv_transpose2d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_transpose_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
f_pt = make_pt_conv_transpose_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv_transpose2d(
a_mx, b_mx, stride=strides, padding=padding, groups=groups
)
out_pt = torch.conv_transpose2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
)
for dtype in dtypes:
print(
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -0,0 +1,107 @@
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_2D(strides, padding, groups)
f_pt = make_pt_conv_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
dtype = "float32"
shapes = (
(4, 32, 32, 21, 3, 3, 128),
(4, 32, 32, 21, 3, 3, 37),
(4, 32, 32, 370, 3, 3, 370),
(4, 32, 32, 370, 7, 7, 128),
(2, 320, 640, 21, 7, 7, 21),
)
for N, H, W, C, kh, kw, O in shapes:
time_mlx, time_torch = bench_shape(
N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -0,0 +1,66 @@
# Copyright © 2024 Apple Inc.
"""
Run with:
mpirun -n 2 python /path/to/distributed_bench.py
"""
import time
import mlx.core as mx
def time_fn(fn, *args, **kwargs):
msg = kwargs.pop("msg", None)
world = mx.distributed.init()
if world.rank() == 0:
if msg:
print(f"Timing {msg} ...", end=" ")
else:
print(f"Timing {fn.__name__} ...", end=" ")
# warmup
for _ in range(5):
mx.eval(fn(*args, **kwargs))
num_iters = 100
tic = time.perf_counter()
for _ in range(num_iters):
x = mx.eval(fn(*args, **kwargs))
toc = time.perf_counter()
msec = 1e3 * (toc - tic) / num_iters
if world.rank() == 0:
print(f"{msec:.5f} msec")
def time_all_sum():
shape = (4096,)
x = mx.random.uniform(shape=shape)
mx.eval(x)
def sine(x):
for _ in range(20):
x = mx.sin(x)
return x
time_fn(sine, x)
def all_sum_plain(x):
for _ in range(20):
x = mx.distributed.all_sum(x)
return x
time_fn(all_sum_plain, x)
def all_sum_with_sine(x):
for _ in range(20):
x = mx.sin(x)
x = mx.distributed.all_sum(x)
return x
time_fn(all_sum_with_sine, x)
if __name__ == "__main__":
time_all_sum()

View File

@@ -0,0 +1,84 @@
# Copyright © 2024 Apple Inc.
import time
import mlx.core as mx
import numpy as np
def timeit(fn, its=100, args=[]):
for _ in range(5):
fn(*args)
tic = time.perf_counter()
for _ in range(its):
fn(*args)
toc = time.perf_counter()
return 1e3 * (toc - tic) / its
def time_little_einsum_path():
subscripts = "ik,kj->ij"
x = mx.ones((32, 32))
y = mx.ones((32, 32))
mx_time = timeit(mx.einsum_path, args=(subscripts, x, y))
x = np.array(x)
y = np.array(y)
np_time = timeit(np.einsum_path, args=(subscripts, x, y))
print("Timing little einsum path...")
print(f"MLX ... {mx_time:.3f} ms")
print(f"NumPy... {np_time:.3f} ms")
def time_big_einsum_path():
chars = list("abcdefgh")
char_to_dim = {c: v for v, c in enumerate(chars)}
num_inputs = 10
inputs = []
subscripts = []
for _ in range(num_inputs):
subscript = np.random.choice(chars, size=5, replace=False).tolist()
subscripts.append("".join(subscript))
inputs.append(np.ones(list(char_to_dim[c] for c in subscript)))
subscripts = ",".join(subscripts)
np_time = timeit(np.einsum_path, args=(subscripts, *inputs))
inputs = [mx.array(x) for x in inputs]
mx_time = timeit(mx.einsum_path, args=(subscripts, *inputs))
print("Timing big einsum path...")
print(f"MLX ... {mx_time:.3f} ms")
print(f"NumPy... {np_time:.3f} ms")
def time_attention():
def regular_attention(x):
# shape [batch, sequence, num_heads, head_dim]
queries, keys, values = x, x, x
scores = queries.transpose(0, 2, 1, 3) @ keys.transpose(0, 2, 3, 1)
scores = mx.softmax(scores, axis=-1)
output = (scores @ values.transpose(0, 2, 1, 3)).swapaxes(1, 2)
mx.eval(output)
def einsum_attention(x):
# shape [batch, sequence, num_heads, head_dim]
queries, keys, values = x, x, x
scores = mx.einsum("itjk,iujk->ijtu", queries, keys)
scores = mx.softmax(scores, axis=-1)
output = mx.einsum("ijtu,iujk->itjk", scores, values)
mx.eval(output)
x = mx.random.uniform(shape=(8, 512, 32, 128))
regular_time = timeit(regular_attention, args=(x,))
ein_time = timeit(einsum_attention, args=(x,))
print("Timing einsum attention...")
print(f"Regular ... {regular_time:.3f} ms")
print(f"Einsum ... {ein_time:.3f} ms")
if __name__ == "__main__":
time_little_einsum_path()
time_big_einsum_path()
time_attention()

View File

@@ -0,0 +1,118 @@
# Copyright © 2024 Apple Inc.
import matplotlib
import mlx.core as mx
import numpy as np
import sympy
import torch
from time_utils import measure_runtime
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def bandwidth_gb(runtime_ms, system_size):
bytes_per_fft = np.dtype(np.complex64).itemsize * 2
bytes_per_gb = 1e9
ms_per_s = 1e3
return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb
def run_bench(system_size, fft_sizes, backend="mlx", dim=1):
def fft_mlx(x):
if dim == 1:
out = mx.fft.fft(x)
elif dim == 2:
out = mx.fft.fft2(x)
mx.eval(out)
return out
def fft_mps(x):
if dim == 1:
out = torch.fft.fft(x)
elif dim == 2:
out = torch.fft.fft2(x)
torch.mps.synchronize()
return out
bandwidths = []
for n in fft_sizes:
batch_size = system_size // n**dim
shape = [batch_size] + [n for _ in range(dim)]
if backend == "mlx":
x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
x = mx.array(x_np)
mx.eval(x)
fft = fft_mlx
elif backend == "mps":
x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
x = torch.tensor(x_np, device="mps")
torch.mps.synchronize()
fft = fft_mps
else:
raise NotImplementedError()
runtime_ms = measure_runtime(fft, x=x)
bandwidth = bandwidth_gb(runtime_ms, np.prod(shape))
print(n, bandwidth)
bandwidths.append(bandwidth)
return np.array(bandwidths)
def time_fft():
x = np.array(range(2, 512))
system_size = int(2**26)
print("MLX GPU")
with mx.stream(mx.gpu):
gpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)
print("MPS GPU")
mps_bandwidths = run_bench(system_size=system_size, fft_sizes=x, backend="mps")
print("CPU")
system_size = int(2**20)
with mx.stream(mx.cpu):
cpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)
x = np.array(x)
all_indices = x - x[0]
radix_2to13 = (
np.array([i for i in x if all(p <= 13 for p in sympy.primefactors(i))]) - x[0]
)
bluesteins = (
np.array([i for i in x if any(p > 13 for p in sympy.primefactors(i))]) - x[0]
)
for indices, name in [
(all_indices, "All"),
(radix_2to13, "Radix 2-13"),
(bluesteins, "Bluestein's"),
]:
# plot bandwidths
print(name)
plt.scatter(x[indices], gpu_bandwidths[indices], color="green", label="GPU")
plt.scatter(x[indices], mps_bandwidths[indices], color="blue", label="MPS")
plt.scatter(x[indices], cpu_bandwidths[indices], color="red", label="CPU")
plt.title(f"MLX FFT Benchmark -- {name}")
plt.xlabel("N")
plt.ylabel("Bandwidth (GB/s)")
plt.legend()
plt.savefig(f"{name}.png")
plt.clf()
av_gpu_bandwidth = np.mean(gpu_bandwidths)
av_mps_bandwidth = np.mean(mps_bandwidths)
av_cpu_bandwidth = np.mean(cpu_bandwidths)
print("Average bandwidths:")
print("GPU:", av_gpu_bandwidth)
print("MPS:", av_mps_bandwidth)
print("CPU:", av_cpu_bandwidth)
portion_faster = len(np.where(gpu_bandwidths > mps_bandwidths)[0]) / len(x)
print("Percent MLX faster than MPS: ", portion_faster * 100)
if __name__ == "__main__":
time_fft()

View File

@@ -0,0 +1,52 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import mlx.core as mx
import torch
from time_utils import measure_runtime
def benchmark_gather_mlx(x_shape, idx_shape):
def gather(x, idx):
mx.eval(x[idx])
idx = mx.random.randint(0, x_shape[0] - 1, idx_shape)
x = mx.random.normal(x_shape).astype(mx.float32)
runtime = measure_runtime(gather, x=x, idx=idx)
print(f"MLX: {runtime:.3f}ms")
def benchmark_gather_torch(x_shape, idx_shape, device):
def gather(x, idx, device):
_ = x[idx]
if device == torch.device("mps"):
torch.mps.synchronize()
idx = torch.randint(0, x_shape[0] - 1, idx_shape).to(device)
x = torch.randn(x_shape, dtype=torch.float32).to(device)
runtime = measure_runtime(gather, x=x, idx=idx, device=device)
print(f"PyTorch: {runtime:.3f}ms")
if __name__ == "__main__":
parser = argparse.ArgumentParser("Gather benchmarks.")
parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
args = parser.parse_args()
if args.cpu:
mx.set_default_device(mx.cpu)
device = torch.device("cpu")
else:
device = torch.device("mps")
idx_shapes = [(1_000_000,), (100_000,), ()]
x_shapes = [(100, 64), (100, 1024), (4, 1_000_000)]
for x_shape, idx_shape in zip(x_shapes, idx_shapes):
print("=" * 20)
print(f"X {x_shape}, Indices {idx_shape}")
benchmark_gather_mlx(x_shape, idx_shape)
benchmark_gather_torch(x_shape, idx_shape, device=device)

View File

@@ -0,0 +1,74 @@
# Copyright © 2025 Apple Inc.
import mlx.core as mx
from time_utils import time_fn
N = 1024
D = 1024
M = 1024
E = 32
I = 4
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
def gather_mm_simulate(x, w, indices):
x, idx, inv_order = gather_sort(x, indices)
for i in range(2):
y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
x = y[:, None]
x = scatter_unsort(x, inv_order, indices.shape)
return x
def time_gather_mm():
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
w1 = mx.random.normal((E, M, D)) / 1024**0.5
w2 = mx.random.normal((E, D, M)) / 1024**0.5
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
mx.eval(x, w1, w2, indices, sorted_indices)
def gather_mm(x, w1, w2, indices, sort):
idx = indices
inv_order = None
if sort:
x, idx, inv_order = gather_sort(x, indices)
x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
if sort:
x = scatter_unsort(x, inv_order, indices.shape)
return x
time_fn(gather_mm, x, w1, w2, indices, False)
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
time_fn(gather_mm, x, w1, w2, indices, True)
x = mx.random.normal((N * I, D)) / 1024**0.5
w1 = mx.random.normal((M, D)) / 1024**0.5
w2 = mx.random.normal((D, M)) / 1024**0.5
mx.eval(x, w1, w2)
def equivalent_matmul(x, w1, w2):
x = x @ w1.T
x = x @ w2.T
return x
time_fn(equivalent_matmul, x, w1, w2)
if __name__ == "__main__":
time_gather_mm()

View File

@@ -0,0 +1,84 @@
# Copyright © 2025 Apple Inc.
import mlx.core as mx
from time_utils import time_fn
N = 1024
D = 1024
M = 1024
E = 32
I = 4
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
def gather_mm_simulate(x, w, indices):
x, idx, inv_order = gather_sort(x, indices)
for i in range(2):
y = mx.concatenate(
[
mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
for i, j in enumerate(idx.tolist())
],
axis=0,
)
x = y[:, None]
x = scatter_unsort(x, inv_order, indices.shape)
return x
def time_gather_qmm():
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
w1 = mx.random.normal((E, M, D)) / 1024**0.5
w2 = mx.random.normal((E, D, M)) / 1024**0.5
w1 = mx.quantize(w1)
w2 = mx.quantize(w2)
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
mx.eval(x, w1, w2, indices, sorted_indices)
def gather_mm(x, w1, w2, indices, sort):
idx = indices
inv_order = None
if sort:
x, idx, inv_order = gather_sort(x, indices)
x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
if sort:
x = scatter_unsort(x, inv_order, indices.shape)
return x
time_fn(gather_mm, x, w1, w2, indices, False)
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
time_fn(gather_mm, x, w1, w2, indices, True)
x = mx.random.normal((N * I, D)) / 1024**0.5
w1 = mx.random.normal((M, D)) / 1024**0.5
w2 = mx.random.normal((D, M)) / 1024**0.5
w1 = mx.quantize(w1)
w2 = mx.quantize(w2)
mx.eval(x, w1, w2)
def equivalent_matmul(x, w1, w2):
x = mx.quantized_matmul(x, *w1, transpose=True)
x = mx.quantized_matmul(x, *w2, transpose=True)
return x
time_fn(equivalent_matmul, x, w1, w2)
if __name__ == "__main__":
time_gather_qmm()

View File

@@ -0,0 +1,70 @@
import argparse
import matplotlib
import mlx.core as mx
import numpy as np
from time_utils import measure_runtime
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def had(x):
y = mx.hadamard_transform(x)
mx.eval(y)
def copy(x):
y = x + 1.0
mx.eval(y)
def run(dtype):
system_size = 2**26
outputs = {}
for test_fn in (had, copy):
for m in [1, 12, 20, 28]:
if test_fn == copy:
key = "copy"
elif m == 1:
key = "had_2^k"
else:
key = "had_m*2^k"
outputs.setdefault(key, {})
for k in range(7, 14):
n = m * 2**k
if n > 2**15:
continue
x_np = np.random.normal(size=(system_size // n, n)).astype(dtype)
x = mx.array(x_np)
runtime_ms = measure_runtime(test_fn, x=x)
bytes_per_gb = 1e9
ms_per_s = 1e3
bytes_per_had = np.dtype(x_np.dtype).itemsize * 2
bandwidth_gb = (
system_size * bytes_per_had / runtime_ms * ms_per_s / bytes_per_gb
)
print(n, bandwidth_gb)
outputs[key][n] = bandwidth_gb
colors = {
"copy": "black",
"had_2^k": "steelblue",
"had_m*2^k": "skyblue",
}
for key, output in outputs.items():
plt.scatter(output.keys(), output.values(), color=colors[key], label=key)
plt.title(f"MLX Hadamard Benchmark -- {dtype.__name__}")
plt.xlabel("N")
plt.ylabel("Bandwidth (GB/s)")
plt.legend()
plt.savefig(f"bench_{dtype.__name__}.png")
plt.clf()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--fp16", action="store_true")
args = parser.parse_args()
dtype = np.float16 if args.fp16 else np.float32
run(dtype)

View File

@@ -0,0 +1,82 @@
# Copyright © 2023-2024 Apple Inc.
from functools import partial
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
def layer_norm(x, w, b, eps):
ot = x.dtype
x = x.astype(mx.float32)
mu = mx.mean(x, -1, keepdims=True)
v = mx.var(x, -1, keepdims=True)
y = (x - mu) * mx.rsqrt(v + eps)
if w is not None:
y = y * w
if b is not None:
y = y + b
return y
def time_layer_norm(N, dt):
L = 1024
f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0, 1, 2))
g2 = mx.grad(f2, argnums=(0, 1, 2))
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
w = mx.random.uniform(shape=(N,)).astype(dt)
b = mx.random.uniform(shape=(N,)).astype(dt)
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
mx.eval(x, w, b, y)
def layer_norm_loop(f, x, w, b):
for _ in range(32):
x = f(x, w, b)
return x
time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b)
time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b)
def layer_norm_grad_loop(g, x, w, b):
gx, gw, gb = x, w, b
for _ in range(32):
gx, gw, gb = g(gx, gw, gb, y)
return gx, gw, gb
time_fn(layer_norm_grad_loop, g1, x, w, b)
time_fn(layer_norm_grad_loop, g2, x, w, b)
time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b)
time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b)
f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0,))
g2 = mx.grad(f2, argnums=(0,))
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
w = mx.random.uniform(shape=(N,)).astype(dt)
b = mx.random.uniform(shape=(N,)).astype(dt)
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
mx.eval(x, w, b, y)
def layer_norm_grad_x_loop(g, x):
gx = x
for _ in range(32):
gx = g(gx, y)
return gx
time_fn(layer_norm_grad_x_loop, g1, x)
time_fn(layer_norm_grad_x_loop, g2, x)
time_fn(layer_norm_grad_x_loop, mx.compile(g1), x)
time_fn(layer_norm_grad_x_loop, mx.compile(g2), x)
if __name__ == "__main__":
for dt in [mx.float32, mx.float16, mx.bfloat16]:
for n in [1024, 2048, 4096, 8192, 8192 + 1024]:
print(dt, n)
time_layer_norm(n, dt)

View File

@@ -1,198 +0,0 @@
# Copyright © 2023 Apple Inc.
import math
import time
import jax
import jax.numpy as jnp
from flax import linen as nn
class RoPE(nn.Module):
dims: int
traditional: bool = False
def _compute_rope(self, costheta, sintheta, x):
x1 = x[..., : self.dims // 2]
x2 = x[..., self.dims // 2 : self.dims]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
rx = jnp.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
else:
rx = jnp.concatenate([rx1, rx2], axis=-1)
return rx
def _compute_traditional_rope(self, costheta, sintheta, x):
x1 = x[..., ::2]
x2 = x[..., 1::2]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
raise NotImplementedError(
"RoPE doesn't implement partial traditional application"
)
rx = jnp.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
return rx
@staticmethod
def create_cos_sin_theta(
N: int,
D: int,
offset: int = 0,
base: float = 10000,
dtype=jnp.float32,
):
D = D // 2
positions = jnp.arange(offset, N, dtype=dtype)
freqs = jnp.exp(-jnp.arange(0, D, dtype=dtype) * (math.log(base) / D))
theta = positions.reshape((-1, 1)) * freqs.reshape((1, -1))
costheta = jnp.cos(theta)
sintheta = jnp.sin(theta)
return costheta, sintheta
@nn.compact
def __call__(self, x, offset: int = 0):
shape = x.shape
x = x.reshape((-1, shape[-2], shape[-1]))
N = x.shape[1] + offset
costheta, sintheta = RoPE.create_cos_sin_theta(
N, self.dims, offset=offset, dtype=x.dtype
)
rope = (
self._compute_traditional_rope if self.traditional else self._compute_rope
)
rx = rope(costheta, sintheta, x)
return rx.reshape(shape)
class LlamaAttention(nn.Module):
dims: int
num_heads: int
dtype: jnp.dtype
def setup(self):
num_heads = self.num_heads
dims = self.dims
self.rope = RoPE(dims // num_heads, True)
self.query_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
self.key_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
self.value_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
self.out_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
def __call__(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
queries = queries.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
keys = keys.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
values = values.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = jnp.concatenate([key_cache, keys], axis=2)
values = jnp.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.transpose((0, 1, 3, 2))
if mask is not None:
scores = scores + mask
scores = jax.nn.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose((0, 2, 1, 3)).reshape((B, L, -1))
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
dims: int
mlp_dims: int
num_heads: int
dtype: jnp.dtype
def setup(self):
dims = self.dims
mlp_dims = self.mlp_dims
num_heads = self.num_heads
self.attention = LlamaAttention(dims, num_heads, dtype)
self.norm1 = nn.RMSNorm(param_dtype=self.dtype)
self.norm2 = nn.RMSNorm(param_dtype=self.dtype)
self.linear1 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
self.linear2 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
self.linear3 = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
def __call__(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = jax.nn.silu(a) * b
y = self.linear3(y)
x = x + y
return x, cache
def measure(model, x, cache):
for i in range(5):
y, c = model(x, mask=None, cache=cache)
jax.block_until_ready((y, c))
start = time.time()
for i in range(5):
y, c = model(x, mask=None, cache=cache)
jax.block_until_ready((y, c))
end = time.time()
return (end - start) * 1000 / 5
if __name__ == "__main__":
H = 32
D = 4096
F = 43 * 256
C = 1000
dtype = jnp.float16
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k1, (1, 1, D), dtype)
cache = [
jax.random.normal(k2, [1, H, C, D // H], dtype),
jax.random.normal(k3, [1, H, C, D // H], dtype),
]
layer = LlamaEncoderLayer(D, F, H, dtype=dtype)
params = layer.init(k4, x, mask=None, cache=cache)["params"]
@jax.jit
def model_fn(x, mask, cache):
return layer.apply({"params": params}, x, mask=mask, cache=cache)
T = measure(model_fn, x, cache)
print("Time per layer per token:", T, "ms")
print("Lower bound total time per token:", T * 32, "ms")

View File

@@ -1,118 +0,0 @@
# Copyright © 2023 Apple Inc.
import math
import time
import mlx.core as mx
import mlx.nn as nn
import mlx.utils
class LlamaAttention(nn.Module):
def __init__(self, dims: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.rope = nn.RoPE(dims // num_heads, True)
self.query_proj = nn.Linear(dims, dims, False)
self.key_proj = nn.Linear(dims, dims, False)
self.value_proj = nn.Linear(dims, dims, False)
self.out_proj = nn.Linear(dims, dims, False)
def __call__(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
queries = mx.transpose(mx.reshape(queries, (B, L, num_heads, -1)), (0, 2, 1, 3))
keys = mx.transpose(mx.reshape(keys, (B, L, num_heads, -1)), (0, 2, 1, 3))
values = mx.transpose(mx.reshape(values, (B, L, num_heads, -1)), (0, 2, 1, 3))
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = mx.array(math.sqrt(1 / queries.shape[-1]), dtype=queries.dtype)
scores = (queries * scale) @ mx.transpose(keys, (0, 1, 3, 2))
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, axis=-1)
values_hat = mx.reshape(mx.transpose(scores @ values, (0, 2, 1, 3)), (B, L, -1))
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
super().__init__()
self.attention = LlamaAttention(dims, num_heads)
self.norm1 = nn.RMSNorm(dims)
self.norm2 = nn.RMSNorm(dims)
self.linear1 = nn.Linear(dims, mlp_dims, False)
self.linear2 = nn.Linear(dims, mlp_dims, False)
self.linear3 = nn.Linear(mlp_dims, dims, False)
def __call__(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = a * mx.sigmoid(a) * b
y = self.linear3(y)
x = x + y
return x, cache
def measure(model, x, cache):
for i in range(5):
y, c = model(x, mask=None, cache=cache)
mx.eval(y, c)
start = time.time()
rs = []
for i in range(5):
y, c = model(x, mask=None, cache=cache)
rs.append((y, c))
mx.eval(rs)
end = time.time()
return (end - start) * 1000 / 5
if __name__ == "__main__":
H = 32
D = 4096
F = 43 * 256
C = 1000
mx.set_default_device(mx.gpu)
dtype = mx.float16
layer = LlamaEncoderLayer(D, F, H)
layer.update(mlx.utils.tree_map(lambda x: x.astype(dtype), layer.parameters()))
k1, k2, k3 = mx.random.split(mx.random.key(0), 3)
x = mx.random.normal([1, 1, D], dtype=dtype)
cache = [
mx.random.normal([1, H, C, D // H], dtype=dtype),
mx.random.normal([1, H, C, D // H], dtype=dtype),
]
mx.eval(x, cache)
T = measure(layer, x, cache)
print("Time per layer per token:", T, "ms")
print("Lower bound total time per token:", T * 32, "ms")

View File

@@ -1,199 +0,0 @@
# Copyright © 2023 Apple Inc.
import math
import time
import torch
import torch.nn as nn
import torch.mps
def sync_if_needed(x):
if x.device != torch.device("cpu"):
torch.mps.synchronize()
class RoPE(nn.Module):
def __init__(self, dims: int, traditional: bool = False):
super().__init__()
self.dims = dims
self.traditional = traditional
def _compute_rope(self, costheta, sintheta, x):
x1 = x[..., : self.dims // 2]
x2 = x[..., self.dims // 2 : self.dims]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
rx = torch.cat([rx1, rx2, x[..., self.dims :]], dim=-1)
else:
rx = torch.cat([rx1, rx2], dim=-1)
return rx
def _compute_traditional_rope(self, costheta, sintheta, x):
x1 = x[..., ::2]
x2 = x[..., 1::2]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
raise NotImplementedError(
"RoPE doesn't implement partial traditional application"
)
rx = torch.cat([rx1[..., None], rx2[..., None]], dim=-1)
return rx
def forward(self, x, offset: int = 0):
shape = x.shape
x = x.view(-1, shape[-2], shape[-1])
N = x.shape[1] + offset
costheta, sintheta = RoPE.create_cos_sin_theta(
N, self.dims, offset=offset, device=x.device, dtype=x.dtype
)
rope = (
self._compute_traditional_rope if self.traditional else self._compute_rope
)
rx = rope(costheta, sintheta, x)
return rx.view(*shape)
@staticmethod
def create_cos_sin_theta(
N: int,
D: int,
offset: int = 0,
base: float = 10000,
device="cpu",
dtype=torch.float32,
):
D = D // 2
positions = torch.arange(offset, N, dtype=dtype, device=device)
freqs = torch.exp(
-torch.arange(0, D, dtype=dtype, device=device) * (math.log(base) / D)
)
theta = positions.view(-1, 1) * freqs.view(1, -1)
costheta = torch.cos(theta)
sintheta = torch.sin(theta)
return costheta, sintheta
class RMSNorm(nn.Module):
def __init__(self, dims: int, epsilon: float = 1e-6):
super().__init__()
self.gamma = nn.Parameter(torch.ones((dims,)))
self.epsilon = epsilon
def forward(self, x):
n = torch.rsqrt(x.square().mean(dim=-1, keepdims=True) + self.epsilon)
return self.gamma * x * n
class LlamaAttention(nn.Module):
def __init__(self, dims: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.rope = RoPE(dims // num_heads, True)
self.query_proj = nn.Linear(dims, dims, bias=False)
self.key_proj = nn.Linear(dims, dims, bias=False)
self.value_proj = nn.Linear(dims, dims, bias=False)
self.out_proj = nn.Linear(dims, dims, bias=False)
def forward(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
queries = queries.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
keys = keys.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
values = values.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = torch.cat([key_cache, keys], dim=2)
values = torch.cat([value_cache, values], dim=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.permute(0, 1, 3, 2)
if mask is not None:
scores = scores + mask
scores = torch.softmax(scores, dim=-1)
values_hat = (scores @ values).permute(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
super().__init__()
self.attention = LlamaAttention(dims, num_heads)
self.norm1 = RMSNorm(dims)
self.norm2 = RMSNorm(dims)
self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
self.linear3 = nn.Linear(mlp_dims, dims, bias=False)
def forward(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = torch.nn.functional.silu(a) * b
y = self.linear3(y)
x = x + y
return x, cache
@torch.no_grad()
def measure(model, x, cache):
for i in range(5):
y, c = model(x, mask=None, cache=cache)
sync_if_needed(x)
start = time.time()
for i in range(5):
y, c = model(x, mask=None, cache=cache)
sync_if_needed(x)
end = time.time()
return (end - start) * 1000 / 5
if __name__ == "__main__":
H = 32
D = 4096
F = 43 * 256
C = 1000
device = torch.device("mps")
dtype = torch.float16
layer = LlamaEncoderLayer(D, F, H).to(device).to(dtype)
x = torch.randn(1, 1, D).to(device).to(dtype)
cache = [
torch.randn(1, H, C, D // H).to(device).to(dtype),
torch.randn(1, H, C, D // H).to(device).to(dtype),
]
T = measure(layer, x, cache)
print("Time per layer per token:", T, "ms")
print("Lower bound total time per token:", T * 32, "ms")

View File

@@ -0,0 +1,63 @@
# Copyright © 2023-2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
def rms_norm(x, w, eps):
ot = x.dtype
x = x.astype(mx.float32)
n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
y = (x * n).astype(ot)
if w is not None:
y = y * w
return y
def time_rms_norm():
f1 = lambda x, w, y: (rms_norm(x, w, 1e-5) * y).sum()
f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0, 1))
g2 = mx.grad(f2, argnums=(0, 1))
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
mx.eval(x, w, y)
def rms_norm_loop(g, x, w):
gx, gw = x, w
for _ in range(32):
gx, gw = g(gx, gw, y)
return gx, gw
time_fn(rms_norm_loop, g1, x, w)
time_fn(rms_norm_loop, g2, x, w)
time_fn(rms_norm_loop, mx.compile(g1), x, w)
time_fn(rms_norm_loop, mx.compile(g2), x, w)
f1 = lambda x, y: (rms_norm(x, None, 1e-5) * y).sum()
f2 = lambda x, y: (mx.fast.rms_norm(x, None, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0,))
g2 = mx.grad(f2, argnums=(0,))
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
mx.eval(x, w, y)
def rms_norm_loop(g, x):
gx = x
for _ in range(32):
gx = g(gx, y)
return gx
time_fn(rms_norm_loop, g1, x)
time_fn(rms_norm_loop, g2, x)
time_fn(rms_norm_loop, mx.compile(g1), x)
time_fn(rms_norm_loop, mx.compile(g2), x)
if __name__ == "__main__":
time_rms_norm()

View File

@@ -0,0 +1,35 @@
# Copyright © 2023-2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
def time_rope():
rope = nn.RoPE(64)
# vec
x = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.float16)
mx.eval(x)
def rope_vec(x):
for _ in range(32):
x = rope(x, offset=100)
return x
time_fn(rope_vec, x)
# matrix
x = mx.random.uniform(shape=(1, 32, 1024, 128)).astype(mx.float16)
mx.eval(x)
def rope_mat(x):
for _ in range(32):
x = rope(x)
return x
time_fn(rope_mat, x)
if __name__ == "__main__":
time_rope()

View File

@@ -0,0 +1,96 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import mlx.core as mx
import torch
from time_utils import measure_runtime
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
def scatter(dst, x, idx):
dst[tuple(idx)] = x
mx.eval(dst)
idx = []
for idx_shape in idx_shapes:
idx.append(mx.random.randint(0, dst_shape[0] - 1, idx_shape))
x = mx.random.normal(x_shape).astype(mx.float32)
dst = mx.random.normal(dst_shape).astype(mx.float32)
runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx)
print(f"MLX: {runtime:.3f}ms")
def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
def scatter(dst, x, idx, device):
dst[tuple(idx)] = x
if device == torch.device("mps"):
torch.mps.synchronize()
idx = []
for idx_shape in idx_shapes:
idx.append(torch.randint(0, dst_shape[0] - 1, idx_shape).to(device))
x = torch.randn(x_shape, dtype=torch.float32).to(device)
dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device)
print(f"PyTorch: {runtime:.3f}ms")
if __name__ == "__main__":
parser = argparse.ArgumentParser("Gather benchmarks.")
parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
args = parser.parse_args()
if args.cpu:
mx.set_default_device(mx.cpu)
device = torch.device("cpu")
else:
device = torch.device("mps")
dst_shapes = [
(10, 64),
(100_000, 64),
(1_000_000, 64),
(100_000,),
(200_000,),
(20_000_000,),
(10000, 64),
(100, 64),
(100, 10_000, 64),
(10, 100, 100, 21),
(1_000, 1_000, 10),
]
idx_shapes = [
[(1_000_000,)],
[(1_000_000,)],
[(100_000,)],
[(1_000_000,)],
[(20_000_000,)],
[(20_000_000,)],
[(1000000,)],
[(10000000,)],
[(1_000,)],
[(10_000,)],
[(1_000,), (1_000,)],
]
x_shapes = [
(1_000_000, 64),
(1_000_000, 64),
(100_000, 64),
(1_000_000,),
(20_000_000,),
(20_000_000,),
(1000000, 64),
(10000000, 64),
(1_000, 10_000, 64),
(10_000, 100, 100, 21),
(1_000, 10),
]
for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
print("=" * 20)
print(f"Dst: {dst_shape}, X {x_shape}, Indices {idx_shape}")
benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)

View File

@@ -0,0 +1,223 @@
# Copyright © 2024 Apple Inc.
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
N_warmup = 5
N_iter_bench = 40
N_iter_func = 8
def bench(f, *args):
for i in range(N_warmup):
f(*args)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(*args)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
np_dtype = getattr(np, dtype)
shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)
shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)
scale = 1.0 / math.sqrt(D)
q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
q_mx = mx.array(q_np)
k_mx = mx.array(k_np)
v_mx = mx.array(v_np)
if mask is not None:
if mask == "additive":
mask_np = np.random.normal(0.0, 1.0, (B, qH, qL, kL)).astype(np_dtype)
mask = mx.array(mask_np)
elif mask == "bool":
mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5
mask = mx.array(mask_np)
return q_mx, k_mx, v_mx, scale, mask
def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
q_dtype = q.dtype
q = q * mx.array(scale, q_dtype)
n_q_heads = q.shape[-3]
n_kv_heads = k.shape[-3]
n_repeats = n_q_heads // n_kv_heads
B = q.shape[0]
L = q.shape[2]
kL = k.shape[2]
if n_repeats > 1:
q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
k = mx.expand_dims(k, 2)
v = mx.expand_dims(v, 2)
scores = q @ mx.swapaxes(k, -1, -2)
if mask is not None:
if mask == "causal":
q_offset = max(0, kL - L)
q_indices = mx.arange(q_offset, q_offset + L)
k_indices = mx.arange(kL)
mask = q_indices[:, None] >= k_indices[None]
if n_repeats > 1 and mask.ndim >= 3:
if mask.shape[-3] == 1:
mask = mx.expand_dims(mask, -3)
else:
mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))
if mask.dtype == mx.bool_:
scores = mx.where(mask, scores, -np.float32(np.inf))
else:
scores += mask
scores = mx.softmax(scores, axis=-1, precise=True)
out = scores @ v
if n_repeats > 1:
out = mx.reshape(out, [B, n_q_heads, L, -1])
return out
def mlx_fused_attn(q, k, v, scale, mask):
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
def do_attention(f, q, k, v, scale, mask=None, transpose=False):
if transpose:
q_t = mx.transpose(q, (0, 2, 1, 3))
k_t = mx.transpose(k, (0, 2, 1, 3))
v_t = mx.transpose(v, (0, 2, 1, 3))
o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)
return mx.transpose(o_t, (0, 2, 1, 3))
else:
return f(q, k, v, scale=scale, mask=mask)
def do_attention_bench(f, q, k, v, scale, mask=None, transpose=False):
q_out = q
for i in range(N_iter_func):
q_out = do_attention(f, q_out, k, v, scale, mask=mask, transpose=transpose)
mx.eval(q_out)
return q_out
def bench_shape(
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, dtype, transpose=True, mask_in=None
):
q_mx, k_mx, v_mx, scale, mask = prepare_inputs(
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, mask_in, transpose, dtype
)
time_mlx_unfused = bench(
do_attention_bench, mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose
)
time_mlx_fused = bench(
do_attention_bench, mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
)
o_mlx_fused = do_attention(mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose)
o_mlx_unfused = do_attention(
mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
)
atol = 1e-5 if dtype == "float32" else 2e-4
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol, rtol=atol):
print(
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}, mask: {mask_in}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
)
return time_mlx_fused, time_mlx_unfused
def get_gflop_count(B, M, N, K):
return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
dtypes = ("float16", "float32")[:1]
transposes = (False,)
# fmt: off
shapes_64 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 32, 32, 64, 32, 32),
( 1, 64, 64, 64, 32, 32),
( 1, 128, 128, 64, 32, 32),
( 1, 256, 256, 64, 32, 32),
( 1, 512, 512, 64, 32, 32),
( 1, 1024, 1024, 64, 32, 8),
( 1, 2048, 2048, 64, 32, 8),
( 1, 4096, 4096, 64, 32, 8),
)
shapes_80 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 80, 32, 8),
( 1, 2048, 2048, 80, 32, 8),
( 1, 4096, 4096, 80, 32, 8),
)
shapes_128 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 128, 32, 8),
( 1, 2048, 2048, 128, 32, 8),
( 1, 4096, 4096, 128, 32, 8),
)
# fmt: on
shapes = shapes_64 + shapes_80 + shapes_128
masks = [None, "bool", "causal"]
print(
" B, qsl, ksl, hdim, n_qh, n_kvh, t, dtype, mask, t_unfs, t_fuse, diff%"
)
for dtype in dtypes:
for transpose in transposes:
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
for mask_in in masks:
time_mlx_fused, time_mlx_unfused = bench_shape(
B,
qsl,
ksl,
head_dim,
n_q_heads,
n_kv_heads,
dtype,
transpose,
mask_in,
)
diff = time_mlx_unfused / time_mlx_fused - 1.0
t_str = 1 if transpose else 0
print(
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:1d}, {dtype}, {str(mask_in):>8}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
)

View File

@@ -0,0 +1,95 @@
import argparse
import math
import mlx.core as mx
from time_utils import time_fn
L = 16384
H = 32
H_k = H // 4
D = 128
V = 128
dtype = mx.float16
loops = 10
def upproject(x, w):
if w is None:
return x
else:
return x @ w.T
def attention(q, k, v, mask=None, w=None):
def _sdpa(q, k, v):
B, Hq, L, D = q.shape
_, Hk, S, _ = k.shape
_, _, _, V = v.shape
q = q.reshape(B, Hk, Hq // Hk, L, D)
k = k[:, :, None, :, :]
v = v[:, :, None, :, :]
s = q @ k.transpose(0, 1, 2, 4, 3)
if mask is not None:
m = mx.broadcast_to(mask, (B, Hq, L, S)).reshape(B, Hk, Hq // Hk, L, S)
s = mx.where(m, s, mx.finfo(s.dtype).min)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ v
return o.reshape(B, Hq, L, V)
for i in range(loops):
q = _sdpa(q, k, v)
q = upproject(q, w)
return q
def sdpa(q, k, v, mask=None, w=None):
for i in range(loops):
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
q = upproject(q, w)
return q
def time_self_attention_primitives():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
mx.eval(q, k, v, w)
time_fn(attention, q, k, v, w=w)
def time_self_attention_sdpa():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
mx.eval(q, k, v, w)
time_fn(sdpa, q, k, v, w=w)
def time_self_attention_sdpa_with_mask():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
mask = mx.full((L,), True)
mask[L // 2 :] = False
mx.eval(q, k, v, mask, w)
def sdpa_mask(*args):
return sdpa(*args, mask=mask, w=w)
def attention_mask(*args):
return attention(*args, mask=mask, w=w)
time_fn(attention_mask, q, k, v)
time_fn(sdpa_mask, q, k, v)
if __name__ == "__main__":
time_self_attention_sdpa()
time_self_attention_primitives()
time_self_attention_sdpa_with_mask()

View File

@@ -1,8 +1,8 @@
# Copyright © 2023 Apple Inc.
import argparse
import mlx.core as mx
import mlx.core as mx
from time_utils import time_fn
@@ -44,6 +44,13 @@ def time_matmul():
time_fn(mx.matmul, a, b)
def time_maximum():
a = mx.random.uniform(shape=(32, 1024, 1024))
b = mx.random.uniform(shape=(32, 1024, 1024))
mx.eval(a, b)
time_fn(mx.maximum, a, b)
def time_negative():
a = mx.random.uniform(shape=(10000, 1000))
mx.eval(a)
@@ -101,6 +108,7 @@ if __name__ == "__main__":
time_add()
time_matmul()
time_maximum()
time_exp()
time_negative()
time_logsumexp()

View File

@@ -0,0 +1,55 @@
import time
import mlx.core as mx
rank = mx.distributed.init().rank()
def timeit(fn, a):
# warmup
for _ in range(5):
mx.eval(fn(a))
its = 10
tic = time.perf_counter()
for _ in range(its):
mx.eval(fn(a))
toc = time.perf_counter()
ms = 1000 * (toc - tic) / its
return ms
def all_reduce_benchmark():
a = mx.ones((5, 5), mx.int32)
its_per_eval = 100
def fn(x):
for _ in range(its_per_eval):
x = mx.distributed.all_sum(x)
x = x - 1
return x
ms = timeit(fn, a) / its_per_eval
if rank == 0:
print(f"All Reduce: time per iteration {ms:.6f} (ms)")
def all_gather_benchmark():
a = mx.ones((5, 5), mx.int32)
its_per_eval = 100
def fn(x):
for _ in range(its_per_eval):
x = mx.distributed.all_gather(x)[0]
return x
ms = timeit(fn, a) / its_per_eval
if rank == 0:
print(f"All gather: time per iteration {ms:.6f} (ms)")
if __name__ == "__main__":
all_reduce_benchmark()
all_gather_benchmark()

View File

@@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import time
@@ -6,7 +6,11 @@ import mlx.core as mx
def time_fn(fn, *args, **kwargs):
print(f"Timing {fn.__name__} ...", end=" ")
msg = kwargs.pop("msg", None)
if msg:
print(f"Timing {msg} ...", end=" ")
else:
print(f"Timing {fn.__name__} ...", end=" ")
# warmup
for _ in range(5):
@@ -20,3 +24,15 @@ def time_fn(fn, *args, **kwargs):
msec = 1e3 * (toc - tic) / num_iters
print(f"{msec:.5f} msec")
def measure_runtime(fn, **kwargs):
# Warmup
for _ in range(5):
fn(**kwargs)
tic = time.time()
iters = 100
for _ in range(iters):
fn(**kwargs)
return (time.time() - tic) * 1000 / iters

View File

@@ -1,56 +1,50 @@
include(CMakeParseArguments)
###############################################################################
# clang format off
#
# ##############################################################################
# Build metal library
#
# Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib
# from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS}
#
# Args:
# TARGET: Custom target to be added for the metal library
# TITLE: Name of the .metallib
# OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib
# SOURCES: List of source files
# INCLUDE_DIRS: List of include dirs
# DEPS: List of depedency files (like headers)
# Args: TARGET: Custom target to be added for the metal library TITLE: Name of
# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
# files (like headers) DEBUG: Boolean, if true, enables debug compile options
# for this specific library. If not provided, uses global MLX_METAL_DEBUG.
#
# clang format on
macro(mlx_build_metallib)
# Parse args
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG)
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
cmake_parse_arguments(
MTLLIB
""
"${oneValueArgs}"
"${multiValueArgs}"
${ARGN}
)
cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
# Set output
set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
# Collect compile options
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
# Collect compile options
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
if(MLX_METAL_DEBUG OR MTLLIB_DEBUG)
set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only
-frecord-sources)
endif()
# Prepare metllib build command
# Prepare metallib build command
add_custom_command(
OUTPUT ${MTLLIB_BUILD_TARGET}
COMMAND xcrun -sdk macosx metal
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
${MTLLIB_COMPILE_OPTIONS}
${MTLLIB_SOURCES}
-o ${MTLLIB_BUILD_TARGET}
COMMAND
xcrun -sdk macosx metal
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
COMMAND_EXPAND_LISTS
COMMENT "Building ${MTLLIB_TITLE}.metallib"
VERBATIM
)
VERBATIM)
# Add metallib custom target
add_custom_target(
${MTLLIB_TARGET}
DEPENDS
${MTLLIB_BUILD_TARGET}
)
add_custom_target(${MTLLIB_TARGET} DEPENDS ${MTLLIB_BUILD_TARGET})
endmacro(mlx_build_metallib)
endmacro(mlx_build_metallib)

2
docs/.gitignore vendored
View File

@@ -1 +1,3 @@
src/python/_autosummary*/
src/python/nn/_autosummary*/
src/python/optimizers/_autosummary*/

50
docs/Doxyfile Normal file
View File

@@ -0,0 +1,50 @@
################################################################################
# Primary project setup. #
################################################################################
PROJECT_NAME = "MLX"
OUTPUT_DIRECTORY = build
XML_OUTPUT = xml
HTML_OUTPUT = html
STRIP_FROM_PATH = ../
INPUT = ../mlx
FILE_PATTERNS = *.h
EXCLUDE_PATTERNS = */private/*
CREATE_SUBDIRS = NO
FULL_PATH_NAMES = YES
RECURSIVE = YES
GENERATE_HTML = NO
GENERATE_LATEX = NO
GENERATE_XML = YES
XML_PROGRAMLISTING = YES
################################################################################
# Doxygen preprocessor / parser control. #
################################################################################
ENABLE_PREPROCESSING = YES
MACRO_EXPANSION = YES
EXPAND_ONLY_PREDEF = NO
SKIP_FUNCTION_MACROS = NO
################################################################################
# Compound extraction control. #
################################################################################
EXTRACT_ALL = YES
EXTRACT_PACKAGE = YES
EXTRACT_STATIC = YES
CASE_SENSE_NAMES = NO
################################################################################
# Docstring control / customization. #
################################################################################
JAVADOC_AUTOBRIEF = YES
################################################################################
# Warning suppression. #
################################################################################
QUIET = YES
WARN_IF_UNDOCUMENTED = NO

View File

@@ -2,12 +2,16 @@
### Setup (do once)
Install [sphinx](https://www.sphinx-doc.org/en/master/usage/installation.html)
for example with `conda`:
Install Doxygen:
```
conda install sphinx
pip install sphinx-rtd-theme
brew install doxygen
```
Install Python packages:
```
pip install -r requirements.txt
```
### Build
@@ -15,7 +19,7 @@ pip install sphinx-rtd-theme
Build the docs from `mlx/docs/`
```
make html
doxygen && make html
```
View the docs by running a server in `mlx/docs/build/html/`:
@@ -26,7 +30,7 @@ python -m http.server <port>
and point your browser to `http://localhost:<port>`.
### Push to Github Pages
### Push to GitHub Pages
Check-out the `gh-pages` branch (`git switch gh-pages`) and build
the docs. Then force add the `build/html` directory:

4
docs/requirements.txt Normal file
View File

@@ -0,0 +1,4 @@
sphinx
breathe
sphinx-book-theme
mlx

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 746 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 76 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 48 KiB

View File

@@ -0,0 +1,33 @@
{{ fullname | escape | underline}}
.. currentmodule:: {{ module }}
.. add toctree option to make autodoc generate the pages
.. autoclass:: {{ objname }}
{% block attributes %}
{% if attributes %}
.. rubric:: Attributes
.. autosummary::
:toctree: .
{% for item in attributes %}
~{{ fullname }}.{{ item }}
{%- endfor %}
{% endif %}
{% endblock %}
{% block methods %}
{% if methods %}
.. rubric:: Methods
.. autosummary::
:toctree: .
{% for item in methods %}
{%- if item not in inherited_members and item != '__init__' %}
~{{ fullname }}.{{ item }}
{%- endif -%}
{%- endfor %}
{% endif %}
{% endblock %}

View File

@@ -4,16 +4,17 @@
.. autoclass:: {{ objname }}
{#{% block methods %}
{% block methods %}
{% if methods %}
.. rubric:: {{ _('Methods') }}
.. autosummary::
{% for item in methods %}
{%- if item not in inherited_members and item != '__init__' %}
{%- if item not in inherited_members and item != "__init__" %}
~{{ name }}.{{ item }}
{%- endif %}
{%- endfor %}
{% endif %}
{% endblock %}#}
{% endblock %}

View File

@@ -5,13 +5,15 @@
import os
import subprocess
import mlx.core as mx
# -- Project information -----------------------------------------------------
project = "MLX"
copyright = "2023, MLX Contributors"
copyright = "2023, Apple"
author = "MLX Contributors"
version = "0.0.0"
release = "0.0.0"
version = ".".join(mx.__version__.split(".")[:3])
release = version
# -- General configuration ---------------------------------------------------
@@ -20,27 +22,77 @@ extensions = [
"sphinx.ext.autosummary",
"sphinx.ext.intersphinx",
"sphinx.ext.napoleon",
"breathe",
]
python_use_unqualified_type_names = True
autosummary_generate = True
autosummary_filename_map = {"mlx.core.Stream": "stream_class"}
intersphinx_mapping = {
"https://docs.python.org/3": None,
"https://numpy.org/doc/stable/": None,
"python": ("https://docs.python.org/3", None),
"numpy": ("https://numpy.org/doc/stable/", None),
}
breathe_projects = {"mlx": "../build/xml"}
breathe_default_project = "mlx"
templates_path = ["_templates"]
html_static_path = ["_static"]
source_suffix = ".rst"
master_doc = "index"
main_doc = "index"
highlight_language = "python"
pygments_style = "sphinx"
add_module_names = False
# -- Options for HTML output -------------------------------------------------
html_theme = "sphinx_rtd_theme"
html_theme = "sphinx_book_theme"
html_theme_options = {
"show_toc_level": 2,
"repository_url": "https://github.com/ml-explore/mlx",
"use_repository_button": True,
"navigation_with_keys": False,
"logo": {
"image_light": "_static/mlx_logo.png",
"image_dark": "_static/mlx_logo_dark.png",
},
}
html_favicon = html_theme_options["logo"]["image_light"]
# -- Options for HTMLHelp output ---------------------------------------------
htmlhelp_basename = "mlx_doc"
def setup(app):
from sphinx.util import inspect
wrapped_isfunc = inspect.isfunction
def isfunc(obj):
type_name = str(type(obj))
if "nanobind.nb_method" in type_name or "nanobind.nb_func" in type_name:
return True
return wrapped_isfunc(obj)
inspect.isfunction = isfunc
# -- Options for LaTeX output ------------------------------------------------
latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")]
latex_elements = {
"preamble": r"""
\usepackage{enumitem}
\setlistdepth{5}
\setlist[itemize,1]{label=$\bullet$}
\setlist[itemize,2]{label=$\bullet$}
\setlist[itemize,3]{label=$\bullet$}
\setlist[itemize,4]{label=$\bullet$}
\setlist[itemize,5]{label=$\bullet$}
\renewlist{itemize}{itemize}{5}
""",
}

View File

@@ -3,4 +3,5 @@
Operations
==========
.. doxygengroup:: ops
:content-only:

View File

@@ -0,0 +1,445 @@
.. _custom_metal_kernels:
Custom Metal Kernels
====================
MLX supports writing custom Metal kernels through the Python and C++ APIs.
Simple Example
--------------
.. currentmodule:: mlx.core
Let's write a custom kernel that computes ``exp`` elementwise:
.. code-block:: python
source = """
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp",
input_names=["inp"],
output_names=["out"],
source=source,
)
def exp_elementwise(a: mx.array):
outputs = kernel(
inputs=[a],
template=[("T", mx.float32)],
grid=(a.size, 1, 1),
threadgroup=(256, 1, 1),
output_shapes=[a.shape],
output_dtypes=[a.dtype],
)
return outputs[0]
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
Every time you make a kernel, a new Metal library is created and possibly
JIT compiled. To reduce the overhead from that, build the kernel once with
:func:`fast.metal_kernel` and then use it many times.
.. note::
Only pass the body of the Metal kernel in ``source``. The function
signature is generated automatically.
The full function signature will be generated using:
* The shapes/dtypes of ``inputs``
In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp``
so we will add ``const device float16_t* inp`` to the signature.
``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present
in ``source``.
* The list of ``output_dtypes``
In the above, ``out`` is an ``mx.array`` of type ``mx.float16``
so we add ``device float16_t* out``.
* Template parameters passed using ``template``
In the above, ``template=[("T", mx.float32)]`` adds a template of ``template <typename T>`` to the function
and instantiates the template with ``custom_kernel_myexp_float<float>``.
Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``.
* Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]``
These will be added as function arguments.
All the attributes defined in Table 5.8 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ are supported.
Putting this all together, the generated function signature for ``myexp`` is as follows:
.. code-block:: cpp
template <typename T>
[[kernel]] void custom_kernel_myexp_float(
const device float16_t* inp [[buffer(0)]],
device float16_t* out [[buffer(1)]],
uint3 thread_position_in_grid [[thread_position_in_grid]]) {
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::exp(tmp);
}
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
``threadgroup`` size threadgroups. For optimal performance, each thread group
dimension should be less than or equal to the corresponding grid dimension.
Passing ``verbose=True`` to :func:`ast.metal_kernel.__call__` will print the
generated code for debugging purposes.
Using Shape/Strides
-------------------
:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
is ``True`` by default. This will copy the array inputs if needed
before the kernel is launched to ensure that the memory layout is row
contiguous. Generally this makes writing the kernel easier, since we don't
have to worry about gaps or the ordering of the dims when indexing.
If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
present in ``source``. We can then use MLX's built in indexing utils to fetch
the right elements for each thread.
Let's convert ``myexp`` above to support arbitrarily strided arrays without
relying on a copy from ``ensure_row_contiguous``:
.. code-block:: python
source = """
uint elem = thread_position_in_grid.x;
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
T tmp = inp[loc];
// Output arrays are always row contiguous
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp_strided",
input_names=["inp"],
output_names=["out"],
source=source
)
def exp_elementwise(a: mx.array):
outputs = kernel(
inputs=[a],
template=[("T", mx.float32)],
grid=(a.size, 1, 1),
threadgroup=(256, 1, 1),
output_shapes=[a.shape],
output_dtypes=[a.dtype],
ensure_row_contiguous=False,
)
return outputs[0]
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
# make non-contiguous
a = a[::2]
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
Complex Example
-----------------------------
Let's implement a more complex example: ``grid_sample`` in ``"bilinear"`` mode.
We'll start with the following MLX implementation using standard ops:
.. code-block:: python
def grid_sample_ref(x, grid):
N, H_in, W_in, _ = x.shape
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
ix_nw = mx.floor(ix).astype(mx.int32)
iy_nw = mx.floor(iy).astype(mx.int32)
ix_ne = ix_nw + 1
iy_ne = iy_nw
ix_sw = ix_nw
iy_sw = iy_nw + 1
ix_se = ix_nw + 1
iy_se = iy_nw + 1
nw = (ix_se - ix) * (iy_se - iy)
ne = (ix - ix_sw) * (iy_sw - iy)
sw = (ix_ne - ix) * (iy - iy_ne)
se = (ix - ix_nw) * (iy - iy_nw)
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
I_nw *= mask_nw[..., None]
I_ne *= mask_ne[..., None]
I_sw *= mask_sw[..., None]
I_se *= mask_se[..., None]
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
return output
Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
to write a fast GPU kernel for both the forward and backward passes.
First we'll implement the forward pass as a fused kernel:
.. code-block:: python
source = """
uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
int gH = grid_shape[1];
int gW = grid_shape[2];
int w_stride = C;
int h_stride = W * w_stride;
int b_stride = H * h_stride;
uint grid_idx = elem / C * 2;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int ix_nw = floor(ix);
int iy_nw = floor(iy);
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int batch_idx = elem / C / gH / gW * b_stride;
int channel_idx = elem % C;
int base_idx = batch_idx + channel_idx;
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
"""
kernel = mx.fast.metal_kernel(
name="grid_sample",
input_names=["x", "grid"],
output_names=["out"],
source=source,
)
@mx.custom_function
def grid_sample(x, grid):
assert x.ndim == 4, "`x` must be 4D."
assert grid.ndim == 4, "`grid` must be 4D."
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
out_shape = (B, gN, gM, C)
assert D == 2, "Last dim of `grid` must be size 2."
outputs = kernel(
inputs=[x, grid],
template=[("T", x.dtype)],
output_shapes=[out_shape],
output_dtypes=[x.dtype],
grid=(np.prod(out_shape), 1, 1),
threadgroup=(256, 1, 1),
)
return outputs[0]
For a reasonably sized input such as:
.. code-block:: python
x.shape = (8, 1024, 1024, 64)
grid.shape = (8, 256, 256, 2)
On an M1 Max, we see a big performance improvement:
``55.7ms -> 6.7ms => 8x speed up``
Grid Sample VJP
---------------
Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
define its custom vjp transform so MLX can differentiate it.
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
requires a few extra :func:`fast.metal_kernel` features:
* ``init_value=0``
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
* ``atomic_outputs=True``
Designate all of the kernel outputs as ``atomic`` in the function signature.
This means we can use Metal's ``atomic`` features to simultaneously update the ``x_grad`` and ``grid_grad`` arrays from multiple threadgroups.
See section 6.15 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ for more details.
We can then implement the backwards pass as follows:
.. code-block:: python
source = """
uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
// Pad C to the nearest larger simdgroup size multiple
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
int gH = grid_shape[1];
int gW = grid_shape[2];
int w_stride = C;
int h_stride = W * w_stride;
int b_stride = H * h_stride;
uint grid_idx = elem / C_padded * 2;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int ix_nw = floor(ix);
int iy_nw = floor(iy);
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int batch_idx = elem / C_padded / gH / gW * b_stride;
int channel_idx = elem % C_padded;
int base_idx = batch_idx + channel_idx;
T gix = T(0);
T giy = T(0);
if (channel_idx < C) {
int cot_index = elem / C_padded * C + channel_idx;
T cot = cotangent[cot_index];
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
T I_nw = x[offset];
gix -= I_nw * (iy_se - iy) * cot;
giy -= I_nw * (ix_se - ix) * cot;
}
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
T I_ne = x[offset];
gix += I_ne * (iy_sw - iy) * cot;
giy -= I_ne * (ix - ix_sw) * cot;
}
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
T I_sw = x[offset];
gix -= I_sw * (iy - iy_ne) * cot;
giy += I_sw * (ix_ne - ix) * cot;
}
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
T I_se = x[offset];
gix += I_se * (iy - iy_nw) * cot;
giy += I_se * (ix - ix_nw) * cot;
}
}
T gix_mult = W / 2;
T giy_mult = H / 2;
// Reduce across each simdgroup first.
// This is much faster than relying purely on atomics.
gix = simd_sum(gix);
giy = simd_sum(giy);
if (thread_index_in_simdgroup == 0) {
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
}
"""
kernel = mx.fast.metal_kernel(
name="grid_sample_grad",
input_names=["x", "grid", "cotangent"],
output_names=["x_grad", "grid_grad"],
source=source,
atomic_outputs=True,
)
@grid_sample.vjp
def grid_sample_vjp(primals, cotangent, _):
x, grid = primals
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
assert D == 2, "Last dim of `grid` must be size 2."
# pad the output channels to simd group size
# so that our `simd_sum`s don't overlap.
simdgroup_size = 32
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
grid_size = B * gN * gM * C_padded
outputs = kernel(
inputs=[x, grid, cotangent],
template=[("T", x.dtype)],
output_shapes=[x.shape, grid.shape],
output_dtypes=[x.dtype, x.dtype],
grid=(grid_size, 1, 1),
threadgroup=(256, 1, 1),
init_value=0,
)
return outputs[0], outputs[1]
There's an even larger speed up for the vjp:
``676.4ms -> 16.7ms => 40x speed up``

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,68 @@
Metal Debugger
==============
.. currentmodule:: mlx.core
Profiling is a key step for performance optimization. You can build MLX with
the ``MLX_METAL_DEBUG`` option to improve the Metal debugging and
optimization workflow. The ``MLX_METAL_DEBUG`` debug option:
* Records source during Metal compilation, for later inspection while
debugging.
* Labels Metal objects such as command queues, improving capture readability.
To build with debugging enabled in Python prepend
``CMAKE_ARGS="-DMLX_METAL_DEBUG=ON"`` to the build call.
The :func:`metal.start_capture` function initiates a capture of all MLX GPU
work.
.. note::
To capture a GPU trace you must run the application with
``MTL_CAPTURE_ENABLED=1``.
.. code-block:: python
import mlx.core as mx
a = mx.random.uniform(shape=(512, 512))
b = mx.random.uniform(shape=(512, 512))
mx.eval(a, b)
trace_file = "mlx_trace.gputrace"
# Make sure to run with MTL_CAPTURE_ENABLED=1 and
# that the path trace_file does not already exist.
mx.metal.start_capture(trace_file)
for _ in range(10):
mx.eval(mx.add(a, b))
mx.metal.stop_capture()
You can open and replay the GPU trace in Xcode. The ``Dependencies`` view
has a great overview of all operations. Checkout the `Metal debugger
documentation`_ for more information.
.. image:: ../_static/metal_debugger/capture.png
:class: dark-light
Xcode Workflow
--------------
You can skip saving to a path by running within Xcode. First, generate an
Xcode project using CMake.
.. code-block::
mkdir build && cd build
cmake .. -DMLX_METAL_DEBUG=ON -G Xcode
open mlx.xcodeproj
Select the ``metal_capture`` example schema and run.
.. image:: ../_static/metal_debugger/schema.png
:class: dark-light
.. _`Metal debugger documentation`: https://developer.apple.com/documentation/xcode/metal-debugger

121
docs/src/dev/mlx_in_cpp.rst Normal file
View File

@@ -0,0 +1,121 @@
.. _mlx_in_cpp:
Using MLX in C++
================
You can use MLX in a C++ project with CMake.
.. note::
This guide is based one the following `example using MLX in C++
<https://github.com/ml-explore/mlx/tree/main/examples/cmake_project>`_
First install MLX:
.. code-block:: bash
pip install -U mlx
You can also install the MLX Python package from source or just the C++
library. For more information see the :ref:`documentation on installing MLX
<build_and_install>`.
Next make an example program in ``example.cpp``:
.. code-block:: C++
#include <iostream>
#include "mlx/mlx.h"
namespace mx = mlx::core;
int main() {
auto x = mx::array({1, 2, 3});
auto y = mx::array({1, 2, 3});
std::cout << x + y << std::endl;
return 0;
}
The next step is to setup a CMake file in ``CMakeLists.txt``:
.. code-block:: cmake
cmake_minimum_required(VERSION 3.27)
project(example LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
Depending on how you installed MLX, you may need to tell CMake where to
find it.
If you installed MLX with Python, then add the following to the CMake file:
.. code-block:: cmake
find_package(
Python 3.9
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE MLX_ROOT)
If you installed the MLX C++ package to a system path, then CMake should be
able to find it. If you installed it to a non-standard location or CMake can't
find MLX then set ``MLX_ROOT`` to the location where MLX is installed:
.. code-block:: cmake
set(MLX_ROOT "/path/to/mlx/")
Next, instruct CMake to find MLX:
.. code-block:: cmake
find_package(MLX CONFIG REQUIRED)
Finally, add the ``example.cpp`` program as an executable and link MLX.
.. code-block:: cmake
add_executable(example example.cpp)
target_link_libraries(example PRIVATE mlx)
You can build the example with:
.. code-block:: bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
And run it with:
.. code-block:: bash
./build/example
Note ``find_package(MLX CONFIG REQUIRED)`` sets the following variables:
.. list-table:: Package Variables
:widths: 20 20
:header-rows: 1
* - Variable
- Description
* - MLX_FOUND
- ``True`` if MLX is found
* - MLX_INCLUDE_DIRS
- Include directory
* - MLX_LIBRARIES
- Libraries to link against
* - MLX_CXX_FLAGS
- Additional compiler flags
* - MLX_BUILD_ACCELERATE
- ``True`` if MLX was built with Accelerate
* - MLX_BUILD_METAL
- ``True`` if MLX was built with Metal

View File

@@ -15,7 +15,7 @@ module to concisely define the model architecture.
Attention layer
^^^^^^^^^^^^^^^^
We will start with the llama attention layer which notably uses the RoPE
We will start with the Llama attention layer which notably uses the RoPE
positional encoding. [1]_ In addition, our attention layer will optionally use a
key/value cache that will be concatenated with the provided keys and values to
support efficient inference.
@@ -321,7 +321,7 @@ which can then be used to update the model. Note that the method above incurs
several unnecessary copies from disk to numpy and then from numpy to MLX. It
will be replaced in the future with direct loading to MLX.
You can download the full example code in `mlx-examples <code>`_. Assuming, the
You can download the full example code in `mlx-examples`_. Assuming, the
existence of ``weights.pth`` and ``tokenizer.model`` in the current working
directory we can play around with our inference script as follows (the timings
are representative of an M1 Ultra and the 7B parameter Llama model):
@@ -369,9 +369,9 @@ Scripts
.. admonition:: Download the code
The full example code is available in `mlx-examples <code>`_.
The full example code is available in `mlx-examples`_.
.. code: `https://github.com/ml-explore/mlx-examples/tree/main/llama`_
.. _mlx-examples: https://github.com/ml-explore/mlx-examples/tree/main/llms/llama
.. [1] Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B. and Liu, Y., 2021.
Roformer: Enhanced transformer with rotary position embedding. arXiv

View File

@@ -61,7 +61,10 @@ set:
def eval_fn(model, X, y):
return mx.mean(mx.argmax(model(X), axis=1) == y)
Next, setup the problem parameters and load the data:
Next, setup the problem parameters and load the data. To load the data, you need our
`mnist data loader
<https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py>`_, which
we will import as ``mnist``.
.. code-block:: python
@@ -127,5 +130,5 @@ Finally, we put it all together by instantiating the model, the
This should not be confused with :func:`mlx.core.value_and_grad`.
The model should train to a decent accuracy (about 95%) after just a few passes
over the training set. The `full example <https://github.com/ml-explore/mlx-examples/tree/main/mlp>`_
over the training set. The `full example <https://github.com/ml-explore/mlx-examples/tree/main/mnist>`_
is available in the MLX GitHub repo.

View File

@@ -1,6 +1,30 @@
MLX
===
MLX is a NumPy-like array framework designed for efficient and flexible machine
learning on Apple silicon, brought to you by Apple machine learning research.
The Python API closely follows NumPy with a few exceptions. MLX also has a
fully featured C++ API which closely follows the Python API.
The main differences between MLX and NumPy are:
- **Composable function transformations**: MLX has composable function
transformations for automatic differentiation, automatic vectorization,
and computation graph optimization.
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
materialized when needed.
- **Multi-device**: Operations can run on any of the supported devices (CPU,
GPU, ...)
The design of MLX is inspired by frameworks like `PyTorch
<https://pytorch.org/>`_, `Jax <https://github.com/google/jax>`_, and
`ArrayFire <https://arrayfire.org/>`_. A notable difference from these
frameworks and MLX is the *unified memory model*. Arrays in MLX live in shared
memory. Operations on MLX arrays can be performed on any of the supported
device types without performing data copies. Currently supported device types
are the CPU and GPU.
.. toctree::
:caption: Install
:maxdepth: 1
@@ -11,8 +35,17 @@ MLX
:caption: Usage
:maxdepth: 1
quick_start
using_streams
usage/quick_start
usage/lazy_evaluation
usage/unified_memory
usage/indexing
usage/saving_and_loading
usage/function_transforms
usage/compile
usage/numpy
usage/distributed
usage/using_streams
usage/export
.. toctree::
:caption: Examples
@@ -27,13 +60,20 @@ MLX
:maxdepth: 1
python/array
python/data_types
python/devices_and_streams
python/export
python/ops
python/random
python/transforms
python/fast
python/fft
python/linalg
python/metal
python/memory_management
python/nn
python/optimizers
python/distributed
python/tree_utils
.. toctree::
@@ -47,3 +87,6 @@ MLX
:maxdepth: 1
dev/extensions
dev/metal_debugger
dev/custom_metal_kernels
dev/mlx_in_cpp

View File

@@ -1,8 +1,10 @@
.. _build_and_install:
Build and Install
=================
Install from PyPI
-----------------
Python Installation
-------------------
MLX is available on PyPI. All you have to do to use MLX with your own Apple
silicon computer is
@@ -11,6 +13,41 @@ silicon computer is
pip install mlx
To install from PyPI you must meet the following requirements:
- Using an M series chip (Apple silicon)
- Using a native Python >= 3.9
- macOS >= 13.5
.. note::
MLX is only available on devices running macOS >= 13.5
It is highly recommended to use macOS 14 (Sonoma)
MLX is also available on conda-forge. To install MLX with conda do:
.. code-block:: shell
conda install conda-forge::mlx
Troubleshooting
^^^^^^^^^^^^^^^
*My OS and Python versions are in the required range but pip still does not find
a matching distribution.*
Probably you are using a non-native Python. The output of
.. code-block:: shell
python -c "import platform; print(platform.processor())"
should be ``arm``. If it is ``i386`` (and you have M series machine) then you
are using a non-native Python. Switch your Python to a native Python. A good
way to do this is with `Conda <https://stackoverflow.com/q/65415996>`_.
Build from source
-----------------
@@ -18,8 +55,12 @@ Build Requirements
^^^^^^^^^^^^^^^^^^
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
- `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
- `cmake <https://cmake.org/>`_ -- version 3.25 or later, and ``make``
- Xcode >= 15.0 and macOS SDK >= 14.0
.. note::
Ensure your shell environment is native ``arm``, not ``x86`` via Rosetta. If
the output of ``uname -p`` is ``x86``, see the :ref:`troubleshooting section <build shell>` below.
Python API
^^^^^^^^^^
@@ -31,33 +72,38 @@ To build and install the MLX python library from source, first, clone MLX from
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
Make sure that you have `pybind11 <https://pybind11.readthedocs.io/en/stable/index.html>`_
installed. You can install ``pybind11`` with ``pip``, ``brew`` or ``conda`` as follows:
Then simply build and install MLX using pip:
.. code-block:: shell
pip install "pybind11[global]"
conda install pybind11
brew install pybind11
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .
Then simply build and install it using pip:
For developing, install the package with development dependencies, and use an
editable install:
.. code-block:: shell
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install .
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"
For developing use an editable install:
Once the development dependencies are installed, you can build faster with:
.. code-block:: shell
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e .
CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace
To make sure the install is working run the tests with:
Run the tests with:
.. code-block:: shell
python -m unittest discover python/tests
Optional: Install stubs to enable auto completions and type checking from your
IDE:
.. code-block:: shell
python setup.py generate_stubs
C++ API
^^^^^^^
@@ -76,7 +122,7 @@ Create a build directory and run CMake and make:
.. code-block:: shell
mkdir -p build && cd build
cmake .. && make -j
cmake .. && make -j
Run tests with:
@@ -95,7 +141,7 @@ directory as the executable statically linked to ``libmlx.a`` or the
preprocessor constant ``METAL_PATH`` should be defined at build time and it
should point to the path to the built metal library.
.. list-table:: Build Options
.. list-table:: Build Options
:widths: 25 8
:header-rows: 1
@@ -109,5 +155,115 @@ should point to the path to the built metal library.
- OFF
* - MLX_BUILD_METAL
- ON
* - MLX_BUILD_CPU
- ON
* - MLX_BUILD_PYTHON_BINDINGS
- OFF
* - MLX_METAL_DEBUG
- OFF
* - MLX_BUILD_SAFETENSORS
- ON
* - MLX_BUILD_GGUF
- ON
* - MLX_METAL_JIT
- OFF
.. note::
If you have multiple Xcode installations and wish to use
a specific one while building, you can do so by adding the
following environment variable before building
.. code-block:: shell
export DEVELOPER_DIR="/path/to/Xcode.app/Contents/Developer/"
Further, you can use the following command to find out which
macOS SDK will be used
.. code-block:: shell
xcrun -sdk macosx --show-sdk-version
Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~
To produce a smaller binary use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``
and ``BUILD_SHARED_LIBS=ON``.
The MLX CMake build has several additional options to make smaller binaries.
For example, if you don't need the CPU backend or support for safetensors and
GGUF, you can do:
.. code-block:: shell
cmake .. \
-DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
THE ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library which
contains pre-built GPU kernels. This substantially reduces the size of the
Metal library by run-time compiling kernels the first time they are used in MLX
on a given machine. Note run-time compilation incurs a cold-start cost which can
be anwywhere from a few hundred millisecond to a few seconds depending on the
application. Once a kernel is compiled, it will be cached by the system. The
Metal kernel cache persists across reboots.
Troubleshooting
^^^^^^^^^^^^^^^
Metal not found
~~~~~~~~~~~~~~~
You see the following error when you try to build:
.. code-block:: shell
error: unable to find utility "metal", not a developer tool or in PATH
To fix this, first make sure you have Xcode installed:
.. code-block:: shell
xcode-select --install
Then set the active developer directory:
.. code-block:: shell
sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer
x86 Shell
~~~~~~~~~
.. _build shell:
If the output of ``uname -p`` is ``x86`` then your shell is running as x86 via
Rosetta instead of natively.
To fix this, find the application in Finder (``/Applications`` for iTerm,
``/Applications/Utilities`` for Terminal), right-click, and click “Get Info”.
Uncheck “Open using Rosetta”, close the “Get Info” window, and restart your
terminal.
Verify the terminal is now running natively the following command:
.. code-block:: shell
$ uname -p
arm
Also check that cmake is using the correct architecture:
.. code-block:: shell
$ cmake --system-information | grep CMAKE_HOST_SYSTEM_PROCESSOR
CMAKE_HOST_SYSTEM_PROCESSOR "arm64"
If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"``
but the build errors out with "Building for x86_64 on macOS is not supported."
wipe your build cache with ``rm -rf build/`` and try again.

View File

@@ -10,36 +10,56 @@ Array
array
array.astype
array.at
array.item
array.tolist
array.dtype
array.itemsize
array.nbytes
array.ndim
array.shape
array.size
Dtype
array.real
array.imag
array.abs
array.all
array.any
array.argmax
array.argmin
array.conj
array.cos
array.dtype
array.cummax
array.cummin
array.cumprod
array.cumsum
array.diag
array.diagonal
array.exp
array.flatten
array.log
array.log10
array.log1p
array.log2
array.logcumsumexp
array.logsumexp
array.max
array.mean
array.min
array.moveaxis
array.prod
array.reciprocal
array.reshape
array.round
array.rsqrt
array.sin
array.split
array.sqrt
array.square
array.squeeze
array.std
array.sum
array.swapaxes
array.transpose
array.T
array.var
array.view

View File

@@ -1,7 +1,5 @@
.. _data_types:
:orphan:
Data Types
==========
@@ -29,9 +27,9 @@ The default floating point type is ``float32`` and the default integer type is
* - ``uint32``
- 4
- 32-bit unsigned integer
* - ``uint32``
* - ``uint64``
- 8
- 32-bit unsigned integer
- 64-bit unsigned integer
* - ``int8``
- 1
- 8-bit signed integer
@@ -44,9 +42,37 @@ The default floating point type is ``float32`` and the default integer type is
* - ``int64``
- 8
- 64-bit signed integer
* - ``bfloat16``
- 2
- 16-bit brain float (e8, m7)
* - ``float16``
- 2
- 16-bit float, only available with `ARM C language extensions <https://developer.arm.com/documentation/101028/0012/3--C-language-extensions?lang=en>`_
- 16-bit IEEE float (e5, m10)
* - ``float32``
- 4
- 32-bit float
* - ``float64``
- 4
- 64-bit double
* - ``complex64``
- 8
- 64-bit complex float
.. note::
Arrays with type ``float64`` only work with CPU operations. Using
``float64`` arrays on the GPU will result in an exception.
Data type are aranged in a hierarchy. See the :obj:`DtypeCategory` object
documentation for more information. Use :func:`issubdtype` to determine if one
``dtype`` (or category) is a subtype of another category.
.. autosummary::
:toctree: _autosummary
Dtype
DtypeCategory
issubdtype
finfo

View File

@@ -9,9 +9,11 @@ Devices and Streams
:toctree: _autosummary
Device
Stream
default_device
set_default_device
Stream
default_stream
new_stream
set_default_stream
stream
synchronize

View File

@@ -0,0 +1,22 @@
.. _distributed:
.. currentmodule:: mlx.core.distributed
Distributed Communication
==========================
MLX provides a distributed communication package using MPI. The MPI library is
loaded at runtime; if MPI is available then distributed communication is also
made available.
.. autosummary::
:toctree: _autosummary
Group
is_available
init
all_sum
all_gather
send
recv
recv_like

View File

@@ -0,0 +1,14 @@
.. _export:
Export Functions
================
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
export_function
import_function
exporter
export_to_dot

15
docs/src/python/fast.rst Normal file
View File

@@ -0,0 +1,15 @@
.. _fast:
Fast
====
.. currentmodule:: mlx.core.fast
.. autosummary::
:toctree: _autosummary
rms_norm
layer_norm
rope
scaled_dot_product_attention
metal_kernel

View File

@@ -20,3 +20,5 @@ FFT
irfft2
rfftn
irfftn
fftshift
ifftshift

View File

@@ -0,0 +1,27 @@
.. _linalg:
Linear Algebra
==============
.. currentmodule:: mlx.core.linalg
.. autosummary::
:toctree: _autosummary
inv
tri_inv
norm
cholesky
cholesky_inv
cross
qr
svd
eigvals
eig
eigvalsh
eigh
lu
lu_factor
pinv
solve
solve_triangular

View File

@@ -0,0 +1,16 @@
Memory Management
=================
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache

12
docs/src/python/metal.rst Normal file
View File

@@ -0,0 +1,12 @@
Metal
=====
.. currentmodule:: mlx.core.metal
.. autosummary::
:toctree: _autosummary
is_available
device_info
start_capture
stop_capture

View File

@@ -64,7 +64,6 @@ Quick Start with Neural Networks
# gradient with respect to `mlp.trainable_parameters()`
loss_and_grad = nn.value_and_grad(mlp, l2_loss)
.. _module_class:
The Module Class
@@ -86,20 +85,58 @@ name should not start with ``_``). It can be arbitrarily nested in other
:meth:`Module.parameters` can be used to extract a nested dictionary with all
the parameters of a module and its submodules.
A :class:`Module` can also keep track of "frozen" parameters.
:meth:`Module.trainable_parameters` returns only the subset of
:meth:`Module.parameters` that is not frozen. When using
:meth:`mlx.nn.value_and_grad` the gradients returned will be with respect to these
trainable parameters.
A :class:`Module` can also keep track of "frozen" parameters. See the
:meth:`Module.freeze` method for more details. :meth:`mlx.nn.value_and_grad`
the gradients returned will be with respect to these trainable parameters.
Updating the parameters
Updating the Parameters
^^^^^^^^^^^^^^^^^^^^^^^
MLX modules allow accessing and updating individual parameters. However, most
times we need to update large subsets of a module's parameters. This action is
performed by :meth:`Module.update`.
performed by :meth:`Module.update`.
Value and grad
Inspecting Modules
^^^^^^^^^^^^^^^^^^
The simplest way to see the model architecture is to print it. Following along with
the above example, you can print the ``MLP`` with:
.. code-block:: python
print(mlp)
This will display:
.. code-block:: shell
MLP(
(layers.0): Linear(input_dims=2, output_dims=128, bias=True)
(layers.1): Linear(input_dims=128, output_dims=128, bias=True)
(layers.2): Linear(input_dims=128, output_dims=10, bias=True)
)
To get more detailed information on the arrays in a :class:`Module` you can use
:func:`mlx.utils.tree_map` on the parameters. For example, to see the shapes of
all the parameters in a :class:`Module` do:
.. code-block:: python
from mlx.utils import tree_map
shapes = tree_map(lambda p: p.shape, mlp.parameters())
As another example, you can count the number of parameters in a :class:`Module`
with:
.. code-block:: python
from mlx.utils import tree_flatten
num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))
Value and Grad
--------------
Using a :class:`Module` does not preclude using MLX's high order function
@@ -136,37 +173,13 @@ In detail:
:toctree: _autosummary
value_and_grad
quantize
average_gradients
Neural Network Layers
---------------------
.. toctree::
.. autosummary::
:toctree: _autosummary
:template: nn-module-template.rst
Embedding
ReLU
GELU
SiLU
Linear
Conv1d
Conv2d
LayerNorm
RMSNorm
GroupNorm
RoPE
MultiHeadAttention
Sequential
Layers without parameters (e.g. activation functions) are also provided as
simple functions.
.. autosummary::
:toctree: _autosummary_functions
:template: nn-module-template.rst
gelu
gelu_approx
gelu_fast_approx
relu
silu
nn/module
nn/layers
nn/functions
nn/losses
nn/init

View File

@@ -0,0 +1,39 @@
.. _nn_functions:
.. currentmodule:: mlx.nn
Functions
---------
Layers without parameters (e.g. activation functions) are also provided as
simple functions.
.. autosummary::
:toctree: _autosummary_functions
:template: nn-module-template.rst
elu
celu
gelu
gelu_approx
gelu_fast_approx
glu
hard_shrink
hard_tanh
hardswish
leaky_relu
log_sigmoid
log_softmax
mish
prelu
relu
relu6
selu
sigmoid
silu
softmax
softmin
softplus
softshrink
step
tanh

View File

@@ -0,0 +1,45 @@
.. _init:
.. currentmodule:: mlx.nn.init
Initializers
------------
The ``mlx.nn.init`` package contains commonly used initializers for neural
network parameters. Initializers return a function which can be applied to any
input :obj:`mlx.core.array` to produce an initialized output.
For example:
.. code:: python
import mlx.core as mx
import mlx.nn as nn
init_fn = nn.init.uniform()
# Produces a [2, 2] uniform matrix
param = init_fn(mx.zeros((2, 2)))
To re-initialize all the parameter in an :obj:`mlx.nn.Module` from say a uniform
distribution, you can do:
.. code:: python
import mlx.nn as nn
model = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 5))
init_fn = nn.init.uniform(low=-0.1, high=0.1)
model.apply(init_fn)
.. autosummary::
:toctree: _autosummary
constant
normal
uniform
identity
glorot_normal
glorot_uniform
he_normal
he_uniform

View File

@@ -0,0 +1,69 @@
.. _layers:
.. currentmodule:: mlx.nn
Layers
------
.. autosummary::
:toctree: _autosummary
:template: nn-module-template.rst
ALiBi
AvgPool1d
AvgPool2d
AvgPool3d
BatchNorm
CELU
Conv1d
Conv2d
Conv3d
ConvTranspose1d
ConvTranspose2d
ConvTranspose3d
Dropout
Dropout2d
Dropout3d
Embedding
ELU
GELU
GLU
GroupNorm
GRU
HardShrink
HardTanh
Hardswish
InstanceNorm
LayerNorm
LeakyReLU
Linear
LogSigmoid
LogSoftmax
LSTM
MaxPool1d
MaxPool2d
MaxPool3d
Mish
MultiHeadAttention
PReLU
QuantizedEmbedding
QuantizedLinear
RMSNorm
ReLU
ReLU6
RNN
RoPE
SELU
Sequential
Sigmoid
SiLU
SinusoidalPositionalEncoding
Softmin
Softshrink
Softsign
Softmax
Softplus
Step
Tanh
Transformer
Upsample

View File

@@ -0,0 +1,25 @@
.. _losses:
.. currentmodule:: mlx.nn.losses
Loss Functions
--------------
.. autosummary::
:toctree: _autosummary_functions
:template: nn-module-template.rst
binary_cross_entropy
cosine_similarity_loss
cross_entropy
gaussian_nll_loss
hinge_loss
huber_loss
kl_div_loss
l1_loss
log_cosh_loss
margin_ranking_loss
mse_loss
nll_loss
smooth_l1_loss
triplet_loss

View File

@@ -1,7 +1,38 @@
mlx.nn.Module
=============
Module
======
.. currentmodule:: mlx.nn
.. autoclass:: Module
:members:
.. rubric:: Attributes
.. autosummary::
:toctree: _autosummary
Module.training
Module.state
.. rubric:: Methods
.. autosummary::
:toctree: _autosummary
Module.apply
Module.apply_to_modules
Module.children
Module.eval
Module.filter_and_map
Module.freeze
Module.leaf_modules
Module.load_weights
Module.modules
Module.named_modules
Module.parameters
Module.save_weights
Module.set_dtype
Module.train
Module.trainable_parameters
Module.unfreeze
Module.update
Module.update_modules

View File

@@ -5,13 +5,14 @@ Operations
.. currentmodule:: mlx.core
.. autosummary::
.. autosummary::
:toctree: _autosummary
abs
add
addmm
all
allclose
allclose
any
arange
arccos
@@ -19,76 +20,164 @@ Operations
arcsin
arcsinh
arctan
arctan2
arctanh
argmax
argmin
argpartition
argsort
array_equal
as_strided
atleast_1d
atleast_2d
atleast_3d
bitwise_and
bitwise_invert
bitwise_or
bitwise_xor
block_masked_mm
broadcast_arrays
broadcast_to
ceil
clip
concatenate
contiguous
conj
conjugate
convolve
conv1d
conv2d
conv3d
conv_transpose1d
conv_transpose2d
conv_transpose3d
conv_general
cos
cosh
cummax
cummin
cumprod
cumsum
degrees
dequantize
diag
diagonal
divide
divmod
einsum
einsum_path
equal
erf
erfinv
exp
expm1
expand_dims
eye
flatten
floor
floor_divide
full
gather_mm
gather_qmm
greater
greater_equal
hadamard_transform
identity
imag
inner
isfinite
isclose
isinf
isnan
isneginf
isposinf
issubdtype
kron
left_shift
less
less_equal
linspace
load
log
log2
log10
log1p
logaddexp
logcumsumexp
logical_not
logical_and
logical_or
logsumexp
matmul
max
maximum
mean
meshgrid
min
minimum
moveaxis
multiply
nan_to_num
negative
not_equal
ones
ones_like
outer
partition
pad
power
prod
put_along_axis
quantize
quantized_matmul
radians
real
reciprocal
remainder
repeat
reshape
right_shift
roll
round
rsqrt
save
savez
savez_compressed
save_gguf
save_safetensors
sigmoid
sign
sin
sinh
slice
slice_update
softmax
sort
split
sqrt
square
squeeze
stack
std
stop_gradient
subtract
sum
swapaxes
take
take_along_axis
tan
tanh
tensordot
tile
topk
trace
transpose
tri
tril
triu
unflatten
var
view
where
zeros
zeros_like

View File

@@ -1,5 +1,7 @@
.. _optimizers:
.. currentmodule:: mlx.optimizers
Optimizers
==========
@@ -29,13 +31,48 @@ model's parameters and the **optimizer state**.
# Compute the new parameters but also the optimizer state.
mx.eval(model.parameters(), optimizer.state)
.. currentmodule:: mlx.optimizers
Saving and Loading
------------------
To serialize an optimizer, save its state. To load an optimizer, load and set
the saved state. Here's a simple example:
.. code-block:: python
import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten
import mlx.optimizers as optim
optimizer = optim.Adam(learning_rate=1e-2)
# Perform some updates with the optimizer
model = {"w" : mx.zeros((5, 5))}
grads = {"w" : mx.ones((5, 5))}
optimizer.update(model, grads)
# Save the state
state = tree_flatten(optimizer.state)
mx.save_safetensors("optimizer.safetensors", dict(state))
# Later on, for example when loading from a checkpoint,
# recreate the optimizer and load the state
optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
optimizer.state = state
Note, not every optimizer configuation parameter is saved in the state. For
example, for Adam the learning rate is saved but the ``betas`` and ``eps``
parameters are not. A good rule of thumb is if the parameter can be scheduled
then it will be included in the optimizer state.
.. toctree::
optimizers/optimizer
optimizers/common_optimizers
optimizers/schedulers
.. autosummary::
:toctree: _autosummary
:template: optimizers-template.rst
OptimizerState
Optimizer
SGD
Adam
clip_grad_norm

View File

@@ -0,0 +1,21 @@
.. _common_optimizers:
Common Optimizers
=================
.. currentmodule:: mlx.optimizers
.. autosummary::
:toctree: _autosummary
:template: optimizers-template.rst
SGD
RMSprop
Adagrad
Adafactor
AdaDelta
Adam
AdamW
Adamax
Lion
MultiOptimizer

View File

@@ -0,0 +1,23 @@
Optimizer
=========
.. currentmodule:: mlx.optimizers
.. autoclass:: Optimizer
.. rubric:: Attributes
.. autosummary::
:toctree: _autosummary
Optimizer.state
.. rubric:: Methods
.. autosummary::
:toctree: _autosummary
Optimizer.apply_gradients
Optimizer.init
Optimizer.update

View File

@@ -0,0 +1,15 @@
.. _schedulers:
Schedulers
==========
.. currentmodule:: mlx.optimizers
.. autosummary::
:toctree: _autosummary
cosine_decay
exponential_decay
join_schedules
linear_schedule
step_decay

View File

@@ -33,13 +33,16 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
.. autosummary::
:toctree: _autosummary
seed
key
split
bernoulli
categorical
gumbel
key
normal
multivariate_normal
randint
uniform
seed
split
truncated_normal
uniform
laplace
permutation

View File

@@ -9,6 +9,11 @@ Transforms
:toctree: _autosummary
eval
async_eval
compile
custom_function
disable_compile
enable_compile
grad
value_and_grad
jvp

View File

@@ -19,3 +19,5 @@ return python trees will be using the default python ``dict``, ``list`` and
tree_flatten
tree_unflatten
tree_map
tree_map_with_path
tree_reduce

497
docs/src/usage/compile.rst Normal file
View File

@@ -0,0 +1,497 @@
.. _compile:
Compilation
===========
.. currentmodule:: mlx.core
MLX has a :func:`compile` function transformation which compiles computation
graphs. Function compilation results in smaller graphs by merging common work
and fusing certain operations. In many cases this can lead to big improvements
in run-time and memory use.
Getting started with :func:`compile` is simple, but there are some edge cases
that are good to be aware of for more complex graphs and advanced usage.
Basics of Compile
-----------------
Let's start with a simple example:
.. code-block:: python
def fun(x, y):
return mx.exp(-x) + y
x = mx.array(1.0)
y = mx.array(2.0)
# Regular call, no compilation
# Prints: array(2.36788, dtype=float32)
print(fun(x, y))
# Compile the function
compiled_fun = mx.compile(fun)
# Prints: array(2.36788, dtype=float32)
print(compiled_fun(x, y))
The output of both the regular function and the compiled function is the same
up to numerical precision.
The first time you call a compiled function, MLX will build the compute
graph, optimize it, and generate and compile code. This can be relatively
slow. However, MLX will cache compiled functions, so calling a compiled
function multiple times will not initiate a new compilation. This means you
should typically compile functions that you plan to use more than once.
.. code-block:: python
def fun(x, y):
return mx.exp(-x) + y
x = mx.array(1.0)
y = mx.array(2.0)
compiled_fun = mx.compile(fun)
# Compiled here
compiled_fun(x, y)
# Not compiled again
compiled_fun(x, y)
# Not compiled again
mx.compile(fun)(x, y)
There are some important cases to be aware of that can cause a function to
be recompiled:
* Changing the shape or number of dimensions
* Changing the type of any of the inputs
* Changing the number of inputs to the function
In certain cases only some of the compilation stack will be rerun (for
example when changing the shapes) and in other cases the full compilation
stack will be rerun (for example when changing the types). In general you
should avoid compiling functions too frequently.
Another idiom to watch out for is compiling functions which get created and
destroyed frequently. This can happen, for example, when compiling an anonymous
function in a loop:
.. code-block:: python
a = mx.array(1.0)
# Don't do this, compiles lambda at each iteration
for _ in range(5):
mx.compile(lambda x: mx.exp(mx.abs(x)))(a)
Example Speedup
---------------
The :func:`mlx.nn.gelu` is a nonlinear activation function commonly used with
Transformer-based models. The implementation involves several unary and binary
element-wise operations:
.. code-block:: python
def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
If you use this function with small arrays, it will be overhead bound. If you
use it with large arrays it will be memory bandwidth bound. However, all of
the operations in the ``gelu`` are fusible into a single kernel with
:func:`compile`. This can speedup both cases considerably.
Let's compare the runtime of the regular function versus the compiled
function. We'll use the following timing helper which does a warm up and
handles synchronization:
.. code-block:: python
import time
def timeit(fun, x):
# warm up
for _ in range(10):
mx.eval(fun(x))
tic = time.perf_counter()
for _ in range(100):
mx.eval(fun(x))
toc = time.perf_counter()
tpi = 1e3 * (toc - tic) / 100
print(f"Time per iteration {tpi:.3f} (ms)")
Now make an array, and benchmark both functions:
.. code-block:: python
x = mx.random.uniform(shape=(32, 1000, 4096))
timeit(nn.gelu, x)
timeit(mx.compile(nn.gelu), x)
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.
Debugging
---------
When a compiled function is first called, it is traced with placeholder
inputs. This means you can't evaluate arrays (for example to print their
contents) inside compiled functions.
.. code-block:: python
@mx.compile
def fun(x):
z = -x
print(z) # Crash
return mx.exp(z)
fun(mx.array(5.0))
For debugging, inspecting arrays can be helpful. One way to do that is to
globally disable compilation using the :func:`disable_compile` function or
``MLX_DISABLE_COMPILE`` flag. For example the following is okay even though
``fun`` is compiled:
.. code-block:: python
@mx.compile
def fun(x):
z = -x
print(z) # Okay
return mx.exp(z)
mx.disable_compile()
fun(mx.array(5.0))
Pure Functions
--------------
Compiled functions are intended to be *pure*; that is they should not have side
effects. For example:
.. code-block:: python
state = []
@mx.compile
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z)
fun(mx.array(1.0), mx.array(2.0))
# Crash!
print(state)
After the first call of ``fun``, the ``state`` list will hold a placeholder
array. The placeholder does not have any data; it is only used to build the
computation graph. Printing such an array results in a crash.
You have two options to deal with this. The first option is to simply return
``state`` as an output:
.. code-block:: python
state = []
@mx.compile
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z), state
_, state = fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)]
print(state)
In some cases returning updated state can be pretty inconvenient. Hence,
:func:`compile` has a parameter to capture implicit outputs:
.. code-block:: python
from functools import partial
state = []
# Tell compile to capture state as an output
@partial(mx.compile, outputs=state)
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z), state
fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)]
print(state)
This is particularly useful for compiling a function which includes an update
to a container of arrays, as is commonly done when training the parameters of a
:class:`mlx.nn.Module`.
Compiled functions will also treat any inputs not in the parameter list as
constants. For example:
.. code-block:: python
state = [mx.array(1.0)]
@mx.compile
def fun(x):
return x + state[0]
# Prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
# Update state
state[0] = mx.array(5.0)
# Still prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
In order to have the change of state reflected in the outputs of ``fun`` you
again have two options. The first option is to simply pass ``state`` as input
to the function. In some cases this can be pretty inconvenient. Hence,
:func:`compile` also has a parameter to capture implicit inputs:
.. code-block:: python
from functools import partial
state = [mx.array(1.0)]
# Tell compile to capture state as an input
@partial(mx.compile, inputs=state)
def fun(x):
return x + state[0]
# Prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
# Update state
state[0] = mx.array(5.0)
# Prints array(6, dtype=float32)
print(fun(mx.array(1.0)))
Compiling Training Graphs
-------------------------
This section will step through how to use :func:`compile` with a simple example
of a common setup: training a model with :obj:`mlx.nn.Module` using an
:obj:`mlx.optimizers.Optimizer` with state. We will show how to compile the
full forward, backward, and update with :func:`compile`.
To start, here is the simple example without any compilation:
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
# 4 examples with 10 features each
x = mx.random.uniform(shape=(4, 10))
# 0, 1 targets
y = mx.array([0, 1, 0, 1])
# Simple linear model
model = nn.Linear(10, 1)
# SGD with momentum
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
def loss_fn(model, x, y):
logits = model(x).squeeze()
return nn.losses.binary_cross_entropy(logits, y)
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
# Perform 10 steps of gradient descent
for it in range(10):
loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)
To compile the update we can put it all in a function and compile it with the
appropriate input and output captures. Here's the same example but compiled:
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from functools import partial
# 4 examples with 10 features each
x = mx.random.uniform(shape=(4, 10))
# 0, 1 targets
y = mx.array([0, 1, 0, 1])
# Simple linear model
model = nn.Linear(10, 1)
# SGD with momentum
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
def loss_fn(model, x, y):
logits = model(x).squeeze()
return nn.losses.binary_cross_entropy(logits, y)
# The state that will be captured as input and output
state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)
return loss
# Perform 10 steps of gradient descent
for it in range(10):
loss = step(x, y)
# Evaluate the model and optimizer state
mx.eval(state)
print(loss)
.. note::
If you are using a module which performs random sampling such as
:func:`mlx.nn.Dropout`, make sure you also include ``mx.random.state`` in the
``state`` captured by :func:`compile`, i.e. ``state = [model.state,
optimizer.state, mx.random.state]``.
.. note::
For more examples of compiling full training graphs checkout the `MLX
Examples <https://github.com/ml-explore/mlx-examples>`_ GitHub repo.
Transformations with Compile
----------------------------
In MLX function transformations are composable. You can apply any function
transformation to the output of any other function transformation. For more on
this, see the documentation on :ref:`function transforms
<function_transforms>`.
Compiling transformed functions works just as expected:
.. code-block:: python
grad_fn = mx.grad(mx.exp)
compiled_grad_fn = mx.compile(grad_fn)
# Prints: array(2.71828, dtype=float32)
print(grad_fn(mx.array(1.0)))
# Also prints: array(2.71828, dtype=float32)
print(compiled_grad_fn(mx.array(1.0)))
.. note::
In order to compile as much as possible, a transformation of a compiled
function will not by default be compiled. To compile the transformed
function simply pass it through :func:`compile`.
You can also compile functions which themselves call compiled functions. A
good practice is to compile the outer most function to give :func:`compile`
the most opportunity to optimize the computation graph:
.. code-block:: python
@mx.compile
def inner(x):
return mx.exp(-mx.abs(x))
def outer(x):
inner(inner(x))
# Compiling the outer function is good to do as it will likely
# be faster even though the inner functions are compiled
fun = mx.compile(outer)
.. _shapeless_compile:
Shapeless Compilation
---------------------
When the shape of an input to a compiled function changes, the function is
recompiled. You can compile a function once and run it on inputs with
variable shapes by specifying ``shapeless=True`` to :func:`compile`. In this
case changes to the shapes of the inputs do not cause the function to be
recompiled.
.. code-block:: python
def fun(x, y):
return mx.abs(x + y)
compiled_fun = mx.compile(fun, shapeless=True)
x = mx.array(1.0)
y = mx.array(-2.0)
# Firt call compiles the function
print(compiled_fun(x, y))
# Second call with different shapes
# does not recompile the function
x = mx.array([1.0, -6.0])
y = mx.array([-2.0, 3.0])
print(compiled_fun(x, y))
Use shapeless compilations carefully. Since compilation is not triggered when
shapes change, any graphs which are conditional on the input shapes will not
work as expected. Shape-dependent computations are common and sometimes subtle
to detect. For example:
.. code-block:: python
def fun(x):
return x.reshape(x.shape[0] * x.shape[1], -1)
compiled_fun = mx.compile(fun, shapeless=True)
x = mx.random.uniform(shape=(2, 3, 4))
out = compiled_fun(x)
x = mx.random.uniform(shape=(5, 5, 3))
# Error, can't reshape (5, 5, 3) to (6, -1)
out = compiled_fun(x)
The second call to the ``compiled_fun`` fails because of the call to
:func:`reshape` which uses the static shape of ``x`` in the first call. We can
fix this by using :func:`flatten` to avoid hardcoding the shape of ``x``:
.. code-block:: python
def fun(x):
return x.flatten(0, 1)
compiled_fun = mx.compile(fun, shapeless=True)
x = mx.random.uniform(shape=(2, 3, 4))
out = compiled_fun(x)
x = mx.random.uniform(shape=(5, 5, 3))
# Ok
out = compiled_fun(x)

Some files were not shown because too many files have changed in this diff Show More