Compare commits

...

254 Commits

Author SHA1 Message Date
CircleCI Docs
57a4334bbc rebase 2025-08-29 17:13:50 +00:00
CircleCI Docs
84d493c53c rebase 2025-08-29 17:12:22 +00:00
CircleCI Docs
4a729ea9ba rebase 2025-08-29 17:12:22 +00:00
CircleCI Docs
852cffda73 rebase 2025-08-29 17:12:22 +00:00
CircleCI Docs
0746bf174f rebase 2025-08-29 17:12:22 +00:00
CircleCI Docs
c1c6d69d53 rebase 2025-08-29 17:12:22 +00:00
CircleCI Docs
36ff32876e rebase 2025-08-29 17:12:21 +00:00
CircleCI Docs
cb9421d68a rebase 2025-08-29 17:12:21 +00:00
CircleCI Docs
03683cc507 rebase 2025-08-29 17:12:21 +00:00
CircleCI Docs
95c0bfb5ed rebase 2025-08-29 17:12:21 +00:00
CircleCI Docs
f84333aa33 rebase 2025-08-29 17:12:21 +00:00
CircleCI Docs
0f716377d1 rebase 2025-08-29 17:12:21 +00:00
CircleCI Docs
eac8a84521 rebase 2025-08-29 17:12:20 +00:00
CircleCI Docs
ea62c81b68 rebase 2025-08-29 17:12:20 +00:00
CircleCI Docs
a6c7193c6c rebase 2025-08-29 17:12:20 +00:00
CircleCI Docs
0b8425f93b rebase 2025-08-29 17:12:20 +00:00
CircleCI Docs
085ab02328 rebase 2025-08-29 17:12:20 +00:00
CircleCI Docs
b1385dbac5 rebase 2025-08-29 17:12:19 +00:00
CircleCI Docs
b7deadf44c rebase 2025-08-29 17:12:19 +00:00
CircleCI Docs
358d1cffdb rebase 2025-08-29 17:12:19 +00:00
CircleCI Docs
23fec194d8 rebase 2025-08-29 17:12:18 +00:00
CircleCI Docs
0ac71de969 rebase 2025-08-29 17:12:18 +00:00
CircleCI Docs
8a70e9e8cb rebase 2025-08-29 17:12:18 +00:00
CircleCI Docs
32195ee16f rebase 2025-08-29 17:12:18 +00:00
CircleCI Docs
abe1da8af4 rebase 2025-08-29 17:12:18 +00:00
CircleCI Docs
59c868fcbc rebase 2025-08-29 17:12:18 +00:00
CircleCI Docs
a196e0e669 rebase 2025-08-29 17:12:17 +00:00
CircleCI Docs
d1357760e4 rebase 2025-08-29 17:12:17 +00:00
CircleCI Docs
9ad625c70d rebase 2025-08-29 17:12:17 +00:00
CircleCI Docs
d1859a4f24 rebase 2025-08-29 17:12:17 +00:00
Awni Hannun
9d0d5648d9 rebase 2025-08-29 17:12:16 +00:00
Awni Hannun
eed934c954 docs update 2025-08-29 17:12:16 +00:00
Awni Hannun
1bcb45cee2 docs up 2025-08-29 17:12:16 +00:00
Awni Hannun
5f2d990100 docs update 2025-08-29 17:12:16 +00:00
Awni Hannun
3f1778f3a1 docs update 2025-08-29 17:12:16 +00:00
Awni Hannun
92badab745 docs update 2025-08-29 17:12:16 +00:00
Awni Hannun
e95b4c1f0e docs update 2025-08-29 17:12:15 +00:00
Awni Hannun
8c272db45b docs update 2025-08-29 17:12:15 +00:00
Awni Hannun
da6b809825 docs update 2025-08-29 17:12:15 +00:00
Awni Hannun
3aabcb2850 docs 2025-08-29 17:12:15 +00:00
Awni Hannun
62608cbd7e docs update 2025-08-29 17:12:14 +00:00
Awni Hannun
061d6d3979 docs update 2025-08-29 17:12:14 +00:00
Awni Hannun
e4a6ad6701 docs update 2025-08-29 17:12:14 +00:00
Awni Hannun
dcc04ea7f4 docs update 2025-08-29 17:12:14 +00:00
Awni Hannun
a10b1c4457 docs update 2025-08-29 17:12:14 +00:00
Awni Hannun
21c71fca27 use proper version 2025-08-29 17:12:14 +00:00
Awni Hannun
60e0eb6b5d docs update 2025-08-29 17:12:14 +00:00
Awni Hannun
ffc0be0bdf docs update 2025-08-29 17:12:13 +00:00
Awni Hannun
76f2d4a67f docs update 2025-08-29 17:12:13 +00:00
Awni Hannun
6ea4f86f7f docs update 2025-08-29 17:12:13 +00:00
Awni Hannun
023b343eff remove uneeded files in docs 2025-08-29 17:12:13 +00:00
Awni Hannun
fec431f299 update docs 2025-08-29 17:12:13 +00:00
Awni Hannun
8680b4a35e docs update 2025-08-29 17:12:13 +00:00
Awni Hannun
26e68230dc docs up 2025-08-29 17:12:13 +00:00
Awni Hannun
05bccb46b1 docs up 2025-08-29 17:12:13 +00:00
Awni Hannun
b1b0991896 docs update 2025-08-29 17:12:13 +00:00
Awni Hannun
71656654d6 docs 2025-08-29 17:12:13 +00:00
Awni Hannun
76c94fab2e docs 2025-08-29 17:12:13 +00:00
Awni Hannun
d25d7ea240 update docs 2025-08-29 17:12:13 +00:00
Awni Hannun
55a57ba1fb docs 2025-08-29 17:12:13 +00:00
Awni Hannun
0484e62fbd docs 2025-08-29 17:12:13 +00:00
Awni Hannun
83b7891525 docs 2025-08-29 17:12:13 +00:00
Awni Hannun
3128dc2560 docs 2025-08-29 17:12:13 +00:00
Awni Hannun
c6af74f07d docs 2025-08-29 17:12:12 +00:00
Awni Hannun
cf742646fc docs 2025-08-29 17:12:12 +00:00
Awni Hannun
a08ce1389c docs 2025-08-29 17:12:12 +00:00
Awni Hannun
beb994875b docs 2025-08-29 17:12:12 +00:00
Awni Hannun
cfff02c477 docs 2025-08-29 17:12:12 +00:00
Awni Hannun
8ce49cd39e fix quantized vjp for mxfp4 (#2555) 2025-08-29 10:06:15 -07:00
Awni Hannun
9c68b50853 version bump (#2554) 2025-08-29 06:54:17 -07:00
Awni Hannun
111f1e71af Faster contiguous gather for indices in the first axis (#2552)
* faster contiguous gather for indices in the first axis

* work per thread > 1

* angelos suggestion for scales / biases
2025-08-28 21:26:30 -07:00
Awni Hannun
827003d568 fix METAL quantization in JIT (#2553) 2025-08-28 18:26:25 -07:00
Awni Hannun
d363a76aa4 Bump xcode in circle (#2551)
* bump xcode in circle

* bump xcode in circle

* bump xcode in circle
2025-08-28 13:13:34 -07:00
Awni Hannun
70560b6bd5 Add mode parameter for quantization (#2499)
* add mode parameter for quantization

* mxfp4 quantize/dequantize + start of optional biases

* mxfp4 works

* speedup

* cpu mxfp4

* fix

* fix test tol

* fix

* refactor

* add quant mode enum
2025-08-28 06:45:26 -07:00
Awni Hannun
7ef8a6f2d5 [CUDA] fix sort (#2550)
* [CUDA] fix sort

* fix test
2025-08-27 19:48:43 -07:00
Cheng
31c6f6e33f [CUDA] Use ConcurrentContext in concatenate_gpu (#2549) 2025-08-28 09:30:08 +09:00
Awni Hannun
584d48458e link with nccl (#2546) 2025-08-27 10:01:07 -07:00
Cheng
5cf984ca87 Separate cpu compilation cache by versions (#2548) 2025-08-27 11:25:15 +09:00
Cheng
a9bac3d9e5 Run CPP tests for CUDA build in CI (#2544) 2025-08-27 08:06:46 +09:00
Awni Hannun
5458d43247 add load with path tests (#2543) 2025-08-26 14:24:47 -07:00
Awni Hannun
a4dba65220 Enable cuda graph toggle (#2545)
* enable cuda graph toggle

* increase cache size
2025-08-26 12:50:38 -07:00
Awni Hannun
3dcb286baf Remove stream from average grads so it uses default (#2532)
* Remove stream from average grads so it uses default

* comment
2025-08-25 15:56:29 -07:00
Cheng
4822c3dbe9 [CUDA] Implement DynamicSlice/DynamicSliceUpdate (#2533)
* Move DynamicSlice to gpu/primitives

* Implement compute_dynamic_offset in CUDA
2025-08-26 07:31:39 +09:00
Awni Hannun
2ca75bb529 Remove nccl install in release (#2542) 2025-08-25 15:20:18 -07:00
Awni Hannun
db14e29a0b allow pathlib.Path to save/load functions (#2541) 2025-08-25 14:58:49 -07:00
Awni Hannun
d2f540f4e0 Use nccl header only when nccl is not present (#2539)
* use nccl header only when nccl is not present

* larger machine for cuda build
2025-08-25 14:17:25 -07:00
Cheng
333ffea273 [CUDA] Remove thrust in arange (#2535) 2025-08-24 16:22:36 +09:00
Cheng
f55b6f1f2f Enable COMPILE_WARNING_AS_ERROR for linux builds in CI (#2534) 2025-08-24 15:33:08 +09:00
Awni Hannun
30561229c7 Fix allocation bug in NCCL (#2530) 2025-08-22 14:39:43 -07:00
Awni Hannun
068a4612e9 nccl default for backend=any (#2528)
* nccl default for backend=any

* check num gpus + ensure row contiguous for all reduce

* comment
2025-08-22 12:24:27 -07:00
Andrey Portnoy
5722c147de [CUDA] Update calls to cudaMemAdvise and cudaGraphAddDependencies for CUDA 13 (#2525)
* [CUDA] Update cudaMemAdvise and cudaGraphAddDependencies for CUDA 13

These functions' signatures changed in CUDA 13, so we differentiate
between CUDA 13 and preceding releases at compile time.

* Mention NVIDIA in ACKNOWLEDGMENTS.md
2025-08-21 19:57:20 -07:00
Cheng
f6819a1f26 Fix warning 186-D from nvcc (#2527) 2025-08-22 10:29:55 +09:00
Awni Hannun
f93f87c802 nccl dep + default for cuda (#2526) 2025-08-21 17:57:49 -07:00
Anastasiia Filippova
9392fc3f88 NCCL backend (#2476) 2025-08-21 11:56:15 -07:00
Awni Hannun
e843c4d8d5 fix power (#2523) 2025-08-21 06:46:01 -07:00
Angelos Katharopoulos
0c5fc63a36 Fix docs omission (#2524) 2025-08-20 17:56:06 -07:00
Angelos Katharopoulos
e397177f6e Custom cuda kernel (#2517) 2025-08-20 17:20:22 -07:00
Cheng
f4c8888cbe [CUDA] Fix stride of singleton dims before passing to cuDNN (#2521) 2025-08-21 08:55:26 +09:00
Angelos Katharopoulos
25c1e03205 Fix overflow in large filter small channels (#2520) 2025-08-20 08:03:29 -07:00
russellizadi
512281781c Remove state return from function example in compile documentation (#2518) 2025-08-20 00:45:05 -07:00
Cheng
ac85ddfdb7 [CUDA] Add GEMM-based fallback convolution kernels (#2511)
* Add gemm_conv

* Add gemm_grouped_conv
2025-08-20 10:06:22 +09:00
Cheng
65d0d40232 Split cuDNN helpers into a separate header (#2491)
* Add RAII managed CudaGraph class

* Implement forward rms_norm with cuDNN

* Revert back to old rms norm kernel
2025-08-20 09:29:28 +09:00
Awni Hannun
cea9369610 fix lapack svd (#2515) 2025-08-18 15:07:59 -07:00
Awni Hannun
e7c6e1db82 no segfault with uninitialized array.at (#2514) 2025-08-18 08:33:38 -07:00
Awni Hannun
c5fcd5b61b fix custom kernel test (#2510) 2025-08-18 06:45:59 -07:00
Angelos Katharopoulos
1df9887998 Ensure no oob read in gemv_masked (#2508) 2025-08-17 08:42:33 -07:00
Angelos Katharopoulos
73f22d6226 Ensure small sort doesn't use indices if not argsort (#2506) 2025-08-17 08:42:20 -07:00
Cheng
c422050ca7 Update cuDNN Frontend to v1.14 (#2505) 2025-08-17 19:13:01 +09:00
Cheng
1ba18ff7d9 [CUDA] Fix conv grads with groups (#2495)
* Put reshape utils in one file

* [CUDA] Fix conv grads with groups

* Put the reshape utils in gpu/copy.h
2025-08-16 10:09:18 +09:00
Cheng
37b440faa8 Clean up code handling both std::vector and SmallVector (#2493) 2025-08-16 09:01:10 +09:00
Cheng
888b13ed63 Remove the hack around SmallVector in cpu compile (#2494) 2025-08-16 08:17:24 +09:00
Cheng
4abb218d21 The naive_conv_2d is no longer used (#2496) 2025-08-16 07:57:30 +09:00
Awni Hannun
6441c21a94 Faster general unary op (#2472)
* faster general unary op

* faster general ops + reorg

* fix + comment

* binary two

* copy general
2025-08-15 15:04:12 -07:00
Cheng
dfb5022eab Rename cu::Matmul to CublasGemm (#2488) 2025-08-13 09:37:40 +09:00
Daniel Yeh
ac207ce7aa make code blocks copyable (#2480)
Co-authored-by: Chen-Chen Yeh <ge96noj@mytum.de>
2025-08-12 12:29:02 -07:00
Abe Leininger
fce53b61d6 Fix reduce sum/prod overflow (#2477) 2025-08-12 00:05:33 -07:00
Angelos Katharopoulos
8ae4a76308 Use CMake <4.1 to avoid the nvpl error (#2489) 2025-08-12 00:03:42 -07:00
Cheng
7fde1b6a1e Fix logsumexp/softmax not fused for some cases (#2474) 2025-08-08 14:07:17 -07:00
Cheng
aa7b47481a [CUDA] Optimize set_mm_device_pointers for small ndim (#2473) 2025-08-08 15:23:30 +09:00
Awni Hannun
56be773610 version (#2470) 2025-08-07 00:36:04 -07:00
Jagrit Digani
a9bdd67baa Add CUDA sdpa vector (#2468) 2025-08-06 21:40:26 -07:00
Angelos Katharopoulos
f2adb5638d Fix typo in metal command encoder (#2471) 2025-08-06 16:58:23 -07:00
Luca Vivona
728d4db582 Support destination arg in tree flatten/unflatten (#2450) 2025-08-06 15:34:59 -07:00
Awni Hannun
db5c7efcf6 revert default cuda install (#2465)
* revert default cuda install

* revert default cuda install
2025-08-06 06:19:12 -07:00
Awni Hannun
7bb96e4249 fix cublas on h100 (#2466) 2025-08-06 06:18:58 -07:00
Awni Hannun
fa89f0b150 faster gather qmm sorted test (#2463) 2025-08-05 06:27:40 -07:00
Awni Hannun
ca973d1e83 fix install tags (#2464) 2025-08-04 20:01:23 -07:00
Cheng
828c5f1137 Use SmallVector for shapes and strides (#2454)
* Use SmallVector for shapes and strides

* Convert SmallVector to tuple
2025-08-05 09:41:03 +09:00
Gaétan Lepage
7d86a5c108 Feat: add USE_SYSTEM_FMT CMake option (#2219) 2025-08-04 16:36:11 -07:00
Awni Hannun
0b807893a7 fix wraps compile (#2461) 2025-08-04 16:14:18 -07:00
Awni Hannun
6ad0889c8a default install cuda on linux (#2462) 2025-08-04 15:33:05 -07:00
Zamderax
737dd6d1ac Add missing <algorithm> header to jit_compiler.cpp (#2460)
Fixes compilation error on Linux where std::find_if is used on line 121
but the <algorithm> header was not included. While this might work on
some platforms due to transitive includes, it's not guaranteed by the
C++ standard.

Resolves issue #2459
2025-08-04 14:00:46 -07:00
Cheng
aaf78f4c6b Use LRU cache for cuda graph (#2448)
* Use LRU cache for cuda graph

* Remove unused destructor
2025-08-02 21:28:57 +09:00
Angelos Katharopoulos
8831064493 Fix arctan2 grads (#2453) 2025-08-01 21:06:04 -07:00
Angelos Katharopoulos
be9bc96da4 [CUDA] Matmul utils initial commit (#2441) 2025-08-01 14:22:25 -07:00
Angelos Katharopoulos
86258f292f [CUDA] Vectorize generated kernels (#2444) 2025-07-31 18:18:57 -07:00
Cheng
b26d88591c [CUDA] Save primitive inputs faster (#2449)
* Add more nvtx loggings

* [CUDA] Saving primitive inputs faster

* Remove unneeded check
2025-08-01 10:16:06 +09:00
Cheng
86c6a15571 [CUDA] Backward convolution (#2431) 2025-08-01 09:54:05 +09:00
junpeiz
8b25ce62d5 Add tests for export including control flow models and quantized models (#2430)
* Add tests for export, including control flow export and quantized model export.

* Skip quantization related test for CUDA backend.
2025-07-31 11:06:26 -07:00
Awni Hannun
da5912e4f2 fix custom metal extension (#2446) 2025-07-31 06:25:36 -07:00
Cheng
daafee676f Fix wrong graph key when using concurrent context (#2447) 2025-07-31 06:01:05 -07:00
Awni Hannun
d32519c8ee fix gemv regression (#2445) 2025-07-30 14:23:01 -07:00
Awni Hannun
b405591249 fix circular reference (#2443) 2025-07-30 09:37:44 -07:00
Angelos Katharopoulos
3bf81ed1bd [CUDA] Quantized refactoring (#2442) 2025-07-30 08:27:20 -07:00
Cheng
2204182bba Make CI faster (#2440) 2025-07-30 02:26:36 -07:00
Cheng
3628e5d497 Use load_vector in arg_reduce (#2439) 2025-07-30 17:40:26 +09:00
Cheng
a0ae49d397 Move arange to its own file (#2438) 2025-07-30 13:05:51 +09:00
Cheng
254476718b Remove the kernel arg from get_launch_args (#2437) 2025-07-30 11:43:02 +09:00
Awni Hannun
3adba92ebe Cuda faster softmax (#2435)
* faster softmax and logsumexp

* faster softmax and logsumexp

* format
2025-07-29 17:18:12 -07:00
Awni Hannun
ef631d63af faster rms norm (#2433) 2025-07-29 13:12:00 -07:00
Cheng
970dbe8e25 Use ccache in CI (#2414)
* Detect ccache

* Use ccache in CI

* Separate cache for different images

* Test both 12.2 and 12.9 for PRs
2025-07-29 08:43:22 +09:00
Awni Hannun
641be9463b Add more CUDA architectures for PyPi package (#2427)
* add cuda sm 90

* add more archs
2025-07-28 12:35:15 -07:00
Awni Hannun
ab0e608862 [CUDA] More sizes for gemv (#2429)
* route more to gemv

* route more sizes to custom gemv
2025-07-28 12:35:01 -07:00
Awni Hannun
1588659062 no occupancy query for launch params (#2426) 2025-07-28 09:09:41 -07:00
Awni Hannun
b9e88fb976 [CUDA] Fix segfault on exit (#2424)
* fix cuda segfault on exit

* comment
2025-07-27 08:08:13 -07:00
Awni Hannun
4ad53414dd fix cuda pypi package (#2423)
* fix cuda pypi package

* patch bump
2025-07-25 15:20:29 -07:00
Awni Hannun
d1165b215e version (#2420) 2025-07-25 13:29:28 -07:00
Awni Hannun
dcb8319f3d update install docs and requirements (#2419) 2025-07-25 12:13:19 -07:00
Awni Hannun
5597fa089c Fix qvm splitk (#2415) 2025-07-25 11:50:24 -07:00
Awni Hannun
9acec364c2 [CUDA] Always use batched matmul (#2404)
* cuda batched mm

* addmm as well

* comment
2025-07-24 20:46:02 -07:00
Skonor
7d9d6ef456 docs: fix adam and adamw eps placement (#2416)
Co-authored-by: Mikhail Gorbunov <m_gorbunov@apple.com>
2025-07-24 16:40:45 -07:00
Cheng
6f5874a2f2 [CUDA] Initial implementation of Convolution with cuDNN (#2385)
* Link with cuDNN

* Initial implementation

* Remove backend apis

* Fix recording cudnn conv

* More unused backend apis

* Fix C++ conv tests

* include cudnn as python dep

* Install libcudnn9-dev-cuda-12 in CI

* cudnn only accepts contiguous inputs

* Switch to backend apis

* Plan needs to be kept alive

* Turn off tf32

* Add cache

* Test the native cuda graph api

* Set cudnn stream before execution

* Make LRUCache more like a normal container

* Do error check for cublas handle

* Zero-initilizing array

* Use tf32 for conv

* Skip TestConv.test_torch_conv_2D test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-07-25 08:12:10 +09:00
Awni Hannun
70dc336785 Test on cuda 12.2 and 12.9 (#2413) 2025-07-24 06:06:15 -07:00
Awni Hannun
4e504039f5 [Metal] Release metal events (#2412)
* release metal events

* fix

* fix
2025-07-23 19:53:42 -07:00
Awni Hannun
d1f4d291e8 Fix uv install and add dev release (#2411)
* fix uv install and add dev release

* fix docstring

* pin cuda deps

* cuda release on cpu-only machine
2025-07-23 16:54:19 -07:00
Awni Hannun
e1840853ce full row mask in sdpa consistently gives nan (#2406) 2025-07-23 16:37:03 -07:00
Cheng
0f5ce173da [CUDA] --compress-mode requires CUDA 12.8 (#2407) 2025-07-23 06:11:11 -07:00
Cheng
588854195f Remove unused code in Convolution::vjp (#2408) 2025-07-23 06:11:00 -07:00
Fangjun Kuang
28d068bce6 Fix an error in the comment for mx.dequantize (#2409) 2025-07-23 06:10:50 -07:00
Awni Hannun
d107d8d495 add cuda gemv (#2400) 2025-07-22 08:24:13 -07:00
Awni Hannun
1e496ddb82 [CUDA] Simplify allocator (#2392)
* simplify allocator and fixe race with small pool

* Don't use shared event in worker

* use cuda buffer in small pool

* comment

* comment
2025-07-22 08:24:01 -07:00
Awni Hannun
74eccbf3fa use size option in binary (#2399) 2025-07-22 07:00:53 -07:00
Awni Hannun
08638223ca Fix including stubs in wheel (#2398)
* fix including stubs in wheel

* fix bool_
2025-07-22 06:30:17 -07:00
Cheng
56cc858af9 Add contiguous_copy_cpu util for copying array (#2397) 2025-07-21 07:30:35 -07:00
Cheng
f55c4ed1d6 Remove thrust iterators (#2396) 2025-07-21 07:30:27 -07:00
Awni Hannun
93d70419e7 [CUDA] speedup handling scalars (#2389)
* speedup scalars in cuda

* comment
2025-07-18 21:47:31 -07:00
Awni Hannun
63f663d9c6 fix cuda manylinux version to match others (#2388) 2025-07-18 21:02:16 -07:00
Awni Hannun
84b4d96efa fix release build + patch bump (#2387) 2025-07-18 14:47:37 -07:00
Awni Hannun
aec67f2fa6 patch bump (#2386) 2025-07-18 12:25:48 -07:00
Gökdeniz Gülmez
deee214a95 Adding support for the Muon Optimizer (#1914)
* initial commit with workong optmimizer

* update ACKNOWLEDGMENTS.md

* nits and adding it to test

* nits

* G.astype(mx.bfloat16) to G.astype(G.dtype)

* G.ndim >= 2 to assert G.ndim == 2

* remove coments

* replace with  mx.addmm

* remove comments

* format

* nits

* match muon

* fix addmm

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-07-18 12:25:28 -07:00
Cheng
45adec102c Add contiguous_copy_gpu util for copying array (#2379) 2025-07-18 06:44:25 -07:00
Cheng
31fc530c76 [CUDA] Add more ways finding CCCL headers in JIT (#2382) 2025-07-17 15:25:34 -07:00
Awni Hannun
fbb3f65a1a fix resource leaks in matmul and graph (#2383) 2025-07-17 06:50:15 -07:00
Angelos Katharopoulos
6b1b8ea91b [CUDA] Add work per thread to compile (#2368) 2025-07-17 06:47:52 -07:00
Awni Hannun
b2273733ea Test with CUDA 12.2 (#2375)
* Test with CUDA 12.0

* try older image

* fix cpu sort
2025-07-16 13:00:37 -07:00
Awni Hannun
f409b229a4 fix ring distributed test (#2380) 2025-07-16 11:25:24 -07:00
Cheng
30571e2326 Rename the copy util in cpu/copy.h to copy_cpu (#2378) 2025-07-16 07:34:24 -07:00
Awni Hannun
d7734edd9f fix complex reduce + nan propagation in min and max (#2377) 2025-07-15 18:19:47 -07:00
Awni Hannun
2ba69bc8fa lower memory uniform sampling (#2361)
* lower memory uniform

* use fp32

* fix
2025-07-15 14:22:07 -07:00
Cheng
cb349a291c [CUDA] Use cuda::std::complex in place of cuComplex (#2372) 2025-07-15 00:36:13 -07:00
Awni Hannun
f0a0b077a0 Install linux with mlx[cuda] and mlx[cpu] (#2356)
* install linux with mlx[cuda] and mlx[cpu]

* temp for testing

* cleanup circle, fix cuda repair

* update circle

* update circle

* decouple python bindings from core libraries
2025-07-14 17:17:33 -07:00
Awni Hannun
49114f28ab fix flaky test (#2371) 2025-07-14 17:16:18 -07:00
Awni Hannun
e7d2ebadd2 [CUDA] Affine quantize (#2354)
* affine quantize and dequantize kernels

* format

* fix

* format
2025-07-14 15:45:44 -07:00
Awni Hannun
e569803d7c update linux build (#2370) 2025-07-14 15:13:56 -07:00
Cheng
d34f887abc Add Primitive::name and remove Primitive::print (#2365) 2025-07-14 14:06:35 -07:00
Angelos Katharopoulos
5201df5030 Fix imag() vjp (#2367) 2025-07-14 13:11:16 -07:00
Cheng
2d3c26c565 [CUDA] Do not put kernels in annoymous namespace (#2362) 2025-07-12 14:24:45 -07:00
Cheng
6325f60d52 [CUDA] Bundle CCCL for JIT compilation (#2357)
* Ship CCCL for JIT compilation

* Remove cexpf
2025-07-11 18:45:37 -07:00
Awni Hannun
42cc9cfbc7 fix copy dispatch (#2360) 2025-07-11 10:59:35 -07:00
Cheng
8347575ba1 [CUDA] Implement Scan kernel (#2347)
* Contiguous scan

* Strided scan

* Enable tests

* Fix failing logaddexp test

* Use cexpf in Metal
2025-07-10 16:54:12 -07:00
Angelos Katharopoulos
b6eec20260 Fix edge check in qmm_n QuantizedLoader (#2355) 2025-07-10 16:28:50 -07:00
Angelos Katharopoulos
0eb035b4b1 Fix type promotion in Adam with bias correction (#2350) 2025-07-10 11:14:42 -07:00
Cheng
afb9817599 [CUDA] Put version in ptx cache dir path (#2352) 2025-07-10 07:24:21 -07:00
Cheng
8fb3e7a26c [CUDA] Set current device before cudaGraphLaunch (#2351) 2025-07-10 07:24:02 -07:00
jhavukainen
8c7bc30ce4 Align mlx::core::min op nan propagation with NumPy (#2346) 2025-07-10 06:20:43 -07:00
Cheng
85873cb162 [CUDA] Do vectorized store/load in contiguous elementwise ops (#2342)
* Do vectorized store/load in unary ops

* Do vectorized store/load in binary_two ops

* Do vectorized store/load in copy ops

* Do vectorized store/load in ternary ops

* Use int32_t for IdxT

* binary => binary_two in binary_two.cu

* Fix tests on large arrays

* Use uint as index type

* Contig uses uint as index and non-contig uses int
2025-07-09 18:48:43 -07:00
Awni Hannun
e14ee12491 add zero for argsort vjp (#2345) 2025-07-09 14:37:14 -07:00
jhavukainen
8b9a3f3cea Align mlx::core::max op nan propagation with NumPy (#2339)
* Make max op NaN propagation rules align with numpy

* Adding benchmarks and testing for max op nanpropagation

* Pre-commit formatting

* Fix max complex64 nan propagation and add test

* Improve the cpp unittest

* Only check nans on non-integral types in simd_reduce_impl.

* Cleanup using namespace alias

* Add cpu Max nanpropagation. Fix a small fib in cpu max dispatch data types for int8/int16.

* Make the max nanpropagation test more meaningful for integer types

* Remove tuple unpacking syntax to comply with earlier python versions. Add cuda skip to nanpropagation tests, fix cuda implementation in a separate PR.
2025-07-09 11:26:27 -07:00
Awni Hannun
fb4e8b896b patch bump (#2343) 2025-07-08 14:26:07 -07:00
Cheng
2ca533b279 Fix compilation with CUDA 11 (#2331) 2025-07-07 20:00:43 -07:00
Angelos Katharopoulos
4a9b29a875 MoE backward improvements (#2335) 2025-07-07 17:59:53 -07:00
Awni Hannun
a4fcc893cd auto build linux release (#2341) 2025-07-07 09:29:23 -07:00
Cheng
9d10239af7 [CUDA] Do vectorized store/load in binary ops (#2330) 2025-07-07 08:44:14 -07:00
Cheng
19facd4b20 Build with all cpu cores by default (#2336) 2025-07-07 06:06:45 -07:00
Angelos Katharopoulos
f5299f72cd Fix layernorm race condition (#2340) 2025-07-07 06:06:01 -07:00
Cheng
0e0d9ac522 [CUDA] Add MLX_CUDA_GRAPH_CACHE_SIZE env for setting graph cache size (#2329) 2025-07-05 08:33:29 -07:00
Awni Hannun
8917022deb fix graphs for older cuda (#2328) 2025-07-02 19:37:58 -07:00
Awni Hannun
ec0d5db67b [CUDA] Switch to CUDA graphs (#2317)
* cuda graph prototype

fix signal bug + start to add dependencies

capture more

capture more ops

remaining ops

fix reduce and rope deps

add concurrent context

try update, but not working

cosistent topology order

use node api

use node api directly to reduce overhead

fix bug

use kernels in unary

cache graph

format

fix synchronization

format

* comment
2025-07-02 15:59:13 -07:00
Cheng
e76e9b87f0 Fix compilation error from integral_constant (#2326) 2025-07-02 06:04:38 -07:00
Awni Hannun
cfb6a244ea allow parameters to be deleted (#2325) 2025-07-01 21:27:23 -07:00
Awni Hannun
58f3860306 patch bump (#2324) 2025-07-01 12:12:16 -07:00
Awni Hannun
dd4f53db63 use fp32 for testing, add more complex ops (#2322) 2025-07-01 07:30:00 -07:00
Angelos Katharopoulos
3d5e17e507 MLX_SWITCH macros to templates (#2320) 2025-07-01 01:33:44 -07:00
Awni Hannun
33bf1a244b Fix module update in strict mode (#2321)
* fix module update in strict mode

* allow GELU to be pickled
2025-06-29 11:12:29 -07:00
Angelos Katharopoulos
772f471ff2 [CUDA] Fix reductions (#2314) 2025-06-27 12:59:20 -07:00
Angelos Katharopoulos
2c11d10f8d Split broadcast so it is always fused in compile (#2318) 2025-06-26 22:08:18 -07:00
Angelos Katharopoulos
656ed7f780 Fix get 2d grid dims (#2316) 2025-06-25 13:03:09 -07:00
Awni Hannun
81bb9a2a9e Compile float64 functions on CPU (#2311) 2025-06-24 10:18:52 -07:00
Angelos Katharopoulos
5adf185f86 Fix update_modules() when providing a subset (#2308) 2025-06-20 17:19:46 -07:00
Awni Hannun
c9a9180584 Cuda perf tuning (#2307)
* perf tuning

* fix adding inputs arrays in matmul / srot

* format

* fix
2025-06-20 14:50:57 -07:00
Awni Hannun
76831ed83d Build CUDA release in Circle (#2306)
* cuda release

* add license
2025-06-19 15:26:36 -07:00
Angelos Katharopoulos
b3d7b85376 Make ptx cache settable by environment variable (#2304) 2025-06-17 23:55:56 -07:00
Awni Hannun
cad5c0241c [CUDA] synch properly waits for all tasks to finish and clear (#2303)
* cuda synch properly waits for all tasks to finish and clear

* fix copy
2025-06-17 12:03:25 -07:00
Awni Hannun
b8022c578a divmod, partition, sort fixes (#2302) 2025-06-16 18:49:32 -07:00
Awni Hannun
bc53f8293f Cuda bug fixes 2 (#2298)
* more bug fixes

* more bug fixes

* format
2025-06-16 13:14:46 -07:00
Awni Hannun
c552ff2451 [CUDA] Fix back-end bugs and enable corresponding tests (#2296)
* Fix some cuda back-end bugs and enable corresponding tests

* more fixes

* enable more tests

* format
2025-06-16 08:45:40 -07:00
Awni Hannun
4fda5fbdf9 add python testing for cuda with ability to skip list of tests (#2295) 2025-06-15 10:56:48 -07:00
Angelos Katharopoulos
580776559b RoPE for CUDA (#2293)
* First working CUDA rope

* Fix random
2025-06-15 06:08:07 -07:00
Awni Hannun
a14aaa7c9d Fix cuda arg reduce (#2291) 2025-06-14 17:54:00 -07:00
Awni Hannun
a6d780154f fix cuda gemm for bf16 (#2288) 2025-06-13 22:10:46 -07:00
Awni Hannun
6871e2eeb7 fix cuda jit (#2287) 2025-06-13 19:21:46 -07:00
Awni Hannun
8402a2acf4 Fix complex power and print (#2286)
* fix complex power and print

* fix complex matmul shape
2025-06-13 11:13:00 -07:00
Jagrit Digani
fddb6933e1 Collection of refactors (#2274)
* Refactor gemv into a function

* Refactor splitk step 1

* Refactor split k axpby

* Rearrange steel_gemm_regular

* Redirect steel_gemm_regular

* Add axpby routing to steel_matmul_regular

* Refactor AddMM step 1

* Redirect steel_gemm

* Update addmm

* Comments and format

* Some cleanup

* Add architecture gen to device

* Update no copy condition in normalization to account for axis size 1
2025-06-13 10:44:56 -07:00
Cheng
c8b4787e4e CUDA backend: indexing ops (#2277) 2025-06-12 21:44:19 -07:00
Awni Hannun
2188199ff8 [CUDA] ternary with select op (#2283)
* cuda ternary with select op

* comment + fix

* fix
2025-06-12 20:24:43 -07:00
Awni Hannun
aa07429bad Fix cuda build (#2284) 2025-06-12 17:48:05 -07:00
Awni Hannun
918761a25a [CUDA] RMSNorm and VJP (#2280)
* rms norm start

* nit
2025-06-12 17:09:49 -07:00
Cheng
a4fc671d3e CUDA backend: compile (#2276)
* CUDA backend: compile

* Rename kernels/ to device/
2025-06-12 17:08:39 -07:00
Awni Hannun
f5f65ef48c Make sliceUpdate general (#2282)
* Make sliceUpdate general

* fix
2025-06-12 16:48:54 -07:00
Cheng
c2dd81a8aa Fix warnings from latest CUDA toolkit (#2275) 2025-06-12 06:03:01 -07:00
Cheng
d7e680ffe4 CUDA backend: layernorm (#2271) 2025-06-11 15:48:32 -07:00
Cheng
c371baf53a CUDA backend: softmax (#2272) 2025-06-11 13:55:22 -07:00
Cheng
ccf78f566c CUDA backend: argreduce (#2270) 2025-06-11 13:26:17 -07:00
Cheng
c9fa68664a CUDA backend: reduce (#2269) 2025-06-11 11:22:25 -07:00
1589 changed files with 609666 additions and 5789 deletions


@@ -7,15 +7,9 @@ parameters:
nightly_build:
type: boolean
default: false
weekly_build:
type: boolean
default: false
test_release:
type: boolean
default: false
linux_release:
type: boolean
default: false
jobs:
build_documentation:
@@ -24,13 +18,14 @@ jobs:
type: boolean
default: false
macos:
xcode: "16.2.0"
resource_class: m2pro.medium
xcode: "26.0.0"
resource_class: m4pro.medium
steps:
- checkout
- run:
name: Install
command: |
xcodebuild -downloadComponent MetalToolchain
brew install python@3.9
brew install doxygen
python3.9 -m venv env
@@ -38,7 +33,7 @@ jobs:
pip install --upgrade pip
pip install --upgrade cmake
pip install -r docs/requirements.txt
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
pip install . -v
- when:
condition:
not: << parameters.upload-docs >>
@@ -70,9 +65,9 @@ jobs:
git push -f origin gh-pages
linux_build_and_test:
docker:
- image: cimg/python:3.9
machine:
image: ubuntu-2204:current
resource_class: large
steps:
- checkout
- run:
@@ -84,37 +79,37 @@ jobs:
- run:
name: Install dependencies
command: |
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install numpy
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
curl -LsSf https://astral.sh/uv/install.sh | sh
- run:
name: Install Python package
command: |
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py develop
uv venv
uv pip install cmake
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
uv pip install -e ".[dev]" -v
- run:
name: Generate package stubs
command: |
echo "stubs"
pip install typing_extensions
python setup.py generate_stubs
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
- run:
name: Run Python tests
command: |
python3 -m unittest discover python/tests -v
source .venv/bin/activate
python -m unittest discover python/tests -v
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- run:
name: Build CPP only
command: |
mkdir -p build && cd build
source .venv/bin/activate
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j `nproc`
- run:
@@ -125,7 +120,7 @@ jobs:
parameters:
xcode_version:
type: string
default: "16.2.0"
default: "26.0.0"
macosx_deployment_target:
type: string
default: ""
@@ -133,57 +128,56 @@ jobs:
xcode: << parameters.xcode_version >>
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
resource_class: m2pro.medium
resource_class: m4pro.medium
steps:
- checkout
- run:
name: Install dependencies
command: |
brew install python@3.9
brew install openmpi
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install numpy
pip install torch
pip install tensorflow
pip install unittest-xml-reporting
xcodebuild -downloadComponent MetalToolchain
HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
brew install openmpi uv
- run:
name: Install Python package
command: |
source env/bin/activate
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
pip install -e . -v
uv venv --python 3.9
uv pip install \
nanobind==2.4.0 \
cmake \
numpy \
torch \
tensorflow \
unittest-xml-reporting
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
uv pip install -e . -v
- run:
name: Generate package stubs
command: |
source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
- run:
name: Run Python tests
command: |
source env/bin/activate
source .venv/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- run:
name: Build example extension
command: |
source env/bin/activate
source .venv/bin/activate
cd examples/extensions
pip install -r requirements.txt
python setup.py build_ext -j8
uv pip install -r requirements.txt
uv run --no-project setup.py build_ext --inplace
uv run --no-project python test.py
- store_test_results:
path: test-results
- run:
name: Build CPP only
command: |
source env/bin/activate
source .venv/bin/activate
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
- run:
name: Run CPP tests
@@ -192,7 +186,7 @@ jobs:
- run:
name: Build small binary
command: |
source env/bin/activate
source .venv/bin/activate
cd build/
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
@@ -204,36 +198,74 @@ jobs:
- run:
name: Run Python tests with JIT
command: |
source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
pip install -e . -v
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
uv pip install -e . -v
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
METAL_DEBUG_ERROR_MODE=0 \
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
uv run --no-project python -m xmlrunner discover \
-v python/tests \
-o test-results/gpu_jit
cuda_build_and_test:
parameters:
image_date:
type: string
default: "2023.11.1"
machine:
image: linux-cuda-12:default
image: "linux-cuda-12:<< parameters.image_date >>"
resource_class: gpu.nvidia.small.gen2
steps:
- checkout
- restore_cache:
keys:
- cuda-<< parameters.image_date >>-{{ arch }}-
- run:
name: Install dependencies
command: |
sudo apt-get update
sudo apt-get install libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install libnccl2 libnccl-dev
curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
rm -rf ccache-4.11.3-linux-x86_64
curl -LsSf https://astral.sh/uv/install.sh | sh
- run:
name: Install Python package
command: |
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
python -m venv env
source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
pip install -e ".[dev]"
uv venv
uv pip install cmake
DEBUG=1 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
uv pip install -e ".[dev]" -v
- run:
name: Run Python tests
command: |
source env/bin/activate
source .venv/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
- run:
name: Build CPP only
command: |
source .venv/bin/activate
cmake . -B build \
-DMLX_BUILD_CUDA=ON \
-DCMAKE_CUDA_COMPILER=`which nvcc` \
-DCMAKE_BUILD_TYPE=DEBUG
cmake --build build -j `nproc`
- run:
name: Run CPP tests
command: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
- run:
name: CCache report
command: |
ccache --show-stats
ccache --zero-stats
ccache --max-size 400MB
ccache --cleanup
- save_cache:
key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
paths:
- /home/circleci/.cache/ccache
build_release:
parameters:
@@ -242,7 +274,7 @@ jobs:
default: "3.9"
xcode_version:
type: string
default: "16.2.0"
default: "26.0.0"
build_env:
type: string
default: ""
@@ -251,7 +283,7 @@ jobs:
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: m2pro.medium
resource_class: m4pro.medium
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
steps:
@@ -259,11 +291,15 @@ jobs:
- run:
name: Install dependencies
command: |
brew install python@<< parameters.python_version >>
brew install openmpi
python<< parameters.python_version >> -m venv env
source env/bin/activate
pip install --upgrade pip
xcodebuild -downloadComponent MetalToolchain
mkdir -p ~/miniconda3
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm ~/miniconda3/miniconda.sh
source ~/miniconda3/bin/activate
conda init --all
conda create -n env python=<< parameters.python_version >> -y
conda activate env
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install --upgrade setuptools
@@ -273,30 +309,38 @@ jobs:
- run:
name: Install Python package
command: |
source env/bin/activate
conda activate env
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
pip install . -v
- run:
name: Generate package stubs
command: |
source env/bin/activate
conda activate env
pip install typing_extensions
python setup.py generate_stubs
python setup.py generate_stubs
- run:
name: Build Python package
command: |
source env/bin/activate
<< parameters.build_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
python -m build -w
conda activate env
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
- when:
condition:
equal: ["3.9", << parameters.python_version >>]
steps:
- run:
name: Build common package
command: |
conda activate env
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload package
command: |
source env/bin/activate
conda activate env
twine upload dist/*
- store_artifacts:
path: dist/
@@ -306,52 +350,100 @@ jobs:
python_version:
type: string
default: "3.9"
extra_env:
build_env:
type: string
default: "DEV_RELEASE=1"
docker:
- image: ubuntu:20.04
default: ""
machine:
image: ubuntu-2204:current
resource_class: large
steps:
- checkout
- run:
name: Build wheel
command: |
PYTHON=python<< parameters.python_version >>
apt-get update
apt-get upgrade -y
DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
apt-get install -y apt-utils
apt-get install -y software-properties-common
add-apt-repository -y ppa:deadsnakes/ppa
apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
apt-get install -y libblas-dev liblapack-dev liblapacke-dev
apt-get install -y build-essential git
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
sudo apt-get update
TZ=Etc/UTC sudo apt-get -y install tzdata
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
$PYTHON -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install --upgrade setuptools
pip install numpy
pip install auditwheel
pip install patchelf
pip install build
pip install twine
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
pip install . -v
<< parameters.build_env >> pip install ".[dev]" -v
pip install typing_extensions
python setup.py generate_stubs
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python -m build --wheel
auditwheel show dist/*
auditwheel repair dist/* --plat manylinux_2_31_x86_64
python setup.py generate_stubs
python setup.py clean --all
MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
bash python/scripts/repair_linux.sh
- when:
condition:
equal: ["3.9", << parameters.python_version >>]
steps:
- run:
name: Build common package
command: |
source env/bin/activate
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
python -m build -w
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload packages
command: |
source env/bin/activate
twine upload wheelhouse/*.whl
- store_artifacts:
path: wheelhouse/
build_cuda_release:
parameters:
build_env:
type: string
default: ""
machine:
image: ubuntu-2204:current
resource_class: xlarge
steps:
- checkout
- run:
name: Upload package
name: Build wheel
command: |
source env/bin/activate
twine upload wheelhouse/*
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install zip
pip install auditwheel
pip install patchelf
pip install build
pip install twine
export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
python -m build -w
bash python/scripts/repair_cuda.sh
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload package
command: |
twine upload wheelhouse/*.whl
- store_artifacts:
path: wheelhouse/
@@ -363,22 +455,23 @@ workflows:
pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >>
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- mac_build_and_test:
matrix:
parameters:
macosx_deployment_target: ["13.5", "14.0"]
macosx_deployment_target: ["13.5", "15.0"]
- linux_build_and_test
- cuda_build_and_test
- cuda_build_and_test:
matrix:
parameters:
image_date: ["2023.11.1", "2025.05.1"]
- build_documentation
build_pypi_release:
when:
and:
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- build_release:
@@ -392,68 +485,7 @@ workflows:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["PYPI_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
xcode_version: ["26.0.0"]
- build_documentation:
filters:
tags:
@@ -461,6 +493,25 @@ workflows:
branches:
ignore: /.*/
upload-docs: true
- build_linux_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
build_env: ["PYPI_RELEASE=1"]
- build_cuda_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
build_env: ["PYPI_RELEASE=1"]
prb:
when:
@@ -476,11 +527,14 @@ workflows:
requires: [ hold ]
matrix:
parameters:
macosx_deployment_target: ["13.5", "14.0"]
macosx_deployment_target: ["13.5", "15.0"]
- linux_build_and_test:
requires: [ hold ]
- cuda_build_and_test:
requires: [ hold ]
matrix:
parameters:
image_date: ["2023.11.1", "2025.05.1"]
nightly_build:
when:
and:
@@ -492,58 +546,18 @@ workflows:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
weekly_build:
xcode_version: ["26.0.0"]
- build_linux_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
- build_cuda_release
build_dev_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.weekly_build >>
- << pipeline.parameters.test_release >>
jobs:
- build_release:
matrix:
@@ -551,76 +565,13 @@ workflows:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
linux_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.linux_release >>
jobs:
xcode_version: ["26.0.0"]
- build_linux_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
extra_env: ["PYPI_RELEASE=1"]
build_env: ["DEV_RELEASE=1"]
- build_cuda_release:
matrix:
parameters:
build_env: ["DEV_RELEASE=1"]

View File

@@ -19,11 +19,17 @@ MLX was developed with contributions from the following individuals:
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
</a>
# Organizations
MLX has received contributions from the following companies:
- NVIDIA Corporation & Affiliates
# Third-Party Software
MLX leverages several third-party software, listed here together with


@@ -41,7 +41,9 @@ option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
# --------------------- Processor tests -------------------------
message(
@@ -64,10 +66,17 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
message(WARNING "Building for x86_64 arch is not officially supported.")
endif()
endif()
else()
set(MLX_BUILD_METAL OFF)
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
endif()
if(MLX_USE_CCACHE)
find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
endif()
endif()
# ----------------------------- Lib -----------------------------
@@ -131,6 +140,12 @@ elseif(MLX_BUILD_METAL)
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
endif()
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
# With newer clang/gcc versions following libs are implicitly linked, but when
# building on old distributions they need to be explicitly listed.
target_link_libraries(mlx PRIVATE dl pthread)
endif()
if(WIN32)
if(MSVC)
# GGUF does not build with MSVC.
@@ -234,12 +249,16 @@ target_include_directories(
# Do not add mlx_EXPORTS define for shared library.
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt)
if(USE_SYSTEM_FMT)
find_package(fmt REQUIRED)
else()
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt)
endif()
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
if(MLX_BUILD_PYTHON_BINDINGS)


@@ -11,10 +11,10 @@ brought to you by Apple machine learning research.
Some key features of MLX include:
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
the Python API. MLX has higher-level packages like `mlx.nn` and
the Python API. MLX has higher-level packages like `mlx.nn` and
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building
more complex models.
@@ -68,18 +68,23 @@ in the documentation.
## Installation
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on
macOS, run:
**With `pip`**:
```
```bash
pip install mlx
```
**With `conda`**:
To install the CUDA backend on Linux, run:
```bash
pip install mlx[cuda]
```
conda install -c conda-forge mlx
To install a CPU-only Linux package, run:
```bash
pip install mlx[cpu]
```
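To confirm the install works, a minimal sanity check (an illustrative sketch, not part of the README diff) is to import the package and evaluate a small op:

```python
# Minimal check that MLX imports and runs on the default device.
import mlx.core as mx

a = mx.array([1.0, 2.0, 3.0])
b = mx.sum(a * 2)
mx.eval(b)  # force the lazy computation to execute
print(b, mx.default_device())
```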
Checkout the


@@ -192,6 +192,22 @@ void time_reductions() {
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
TIME(argmin_along_1);
auto indices = mx::array({1});
auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});
std::vector<int> axes{0};
auto b = scatter(a, {indices}, updates, axes);
mx::eval(b);
auto max_along_0 = [&b]() { return mx::max(b, 0, false); };
TIME(max_along_0);
auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
TIME(max_along_1);
auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
TIME(min_along_0);
auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
TIME(min_along_1);
}
void time_gather_scatter() {


@@ -5,6 +5,7 @@ import os
import time
import torch
import torch.cuda
import torch.mps
@@ -44,8 +45,10 @@ def bench(f, *args):
def sync_if_needed(x):
if x.device != torch.device("cpu"):
if x.device == torch.device("mps"):
torch.mps.synchronize()
elif x.device == torch.device("cuda"):
torch.cuda.synchronize()
@torch.no_grad()
@@ -99,6 +102,14 @@ def reduction(op, axis, x):
sync_if_needed(x)
@torch.no_grad()
def sum_and_add(axis, x, y):
z = x.sum(axis=axis, keepdims=True)
for i in range(50):
z = (z + y).sum(axis=axis, keepdims=True)
sync_if_needed(x)
@torch.no_grad()
def softmax(axis, x):
ys = []
@@ -340,7 +351,11 @@ if __name__ == "__main__":
args.axis.pop(0)
torch.set_num_threads(1)
device = "cpu" if args.cpu else "mps"
device = "mps"
if torch.cuda.is_available():
device = "cuda"
if args.cpu:
device = "cpu"
types = args.dtype
if not types:
@@ -460,5 +475,8 @@ if __name__ == "__main__":
elif args.benchmark == "selu":
print(bench(selu, x))
elif args.benchmark == "sum_and_add":
print(bench(sum_and_add, axis, *xs))
else:
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")


@@ -51,6 +51,20 @@ def time_maximum():
time_fn(mx.maximum, a, b)
def time_max():
a = mx.random.uniform(shape=(32, 1024, 1024))
a[1, 1] = mx.nan
mx.eval(a)
time_fn(mx.max, a, 0)
def time_min():
a = mx.random.uniform(shape=(32, 1024, 1024))
a[1, 1] = mx.nan
mx.eval(a)
time_fn(mx.min, a, 0)
def time_negative():
a = mx.random.uniform(shape=(10000, 1000))
mx.eval(a)
@@ -108,6 +122,8 @@ if __name__ == "__main__":
time_add()
time_matmul()
time_min()
time_max()
time_maximum()
time_exp()
time_negative()

cmake/FindNCCL.cmake Normal file

@@ -0,0 +1,54 @@
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
# directories.
set(NCCL_ROOT_DIR
$ENV{NCCL_ROOT_DIR}
CACHE PATH "Folder contains NVIDIA NCCL")
find_path(
NCCL_INCLUDE_DIRS
NAMES nccl.h
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
${CUDA_TOOLKIT_ROOT_DIR}/include)
if($ENV{USE_STATIC_NCCL})
message(
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
set(NCCL_LIBNAME "libnccl_static.a")
else()
set(NCCL_LIBNAME "nccl")
endif()
find_library(
NCCL_LIBRARIES
NAMES ${NCCL_LIBNAME}
HINTS ${NCCL_LIB_DIR}
${NCCL_ROOT_DIR}
${NCCL_ROOT_DIR}/lib
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
${NCCL_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
NCCL_LIBRARIES)
if(NCCL_FOUND)
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
message(
STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
file(
STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
LIMIT_COUNT 1)
if(NCCL_MAJOR_VERSION_DEFINED)
string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
endif()
message(
STATUS
"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
endif()

4
docs/build/html/.buildinfo vendored Normal file
View File

@@ -0,0 +1,4 @@
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
config: 6e9fcd3fd9a477c32d79521f0d5d7188
tags: 645f666f9bcd5a90fca523b33c5a78b7

BIN
docs/build/html/_images/capture.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 MiB

BIN
docs/build/html/_images/schema.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 746 KiB

7
docs/build/html/_sources/cpp/ops.rst vendored Normal file
View File

@@ -0,0 +1,7 @@
.. _cpp_ops:
Operations
==========
.. doxygengroup:: ops
:content-only:

View File

@@ -0,0 +1,445 @@
.. _custom_metal_kernels:
Custom Metal Kernels
====================
MLX supports writing custom Metal kernels through the Python and C++ APIs.
Simple Example
--------------
.. currentmodule:: mlx.core
Let's write a custom kernel that computes ``exp`` elementwise:
.. code-block:: python
source = """
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp",
input_names=["inp"],
output_names=["out"],
source=source,
)
def exp_elementwise(a: mx.array):
outputs = kernel(
inputs=[a],
template=[("T", mx.float32)],
grid=(a.size, 1, 1),
threadgroup=(256, 1, 1),
output_shapes=[a.shape],
output_dtypes=[a.dtype],
)
return outputs[0]
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
Every time you make a kernel, a new Metal library is created and possibly
JIT compiled. To reduce the overhead from that, build the kernel once with
:func:`fast.metal_kernel` and then use it many times.
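For example, a minimal sketch of this reuse pattern (using the ``kernel`` and
``exp_elementwise`` defined above) looks like:
.. code-block:: python

    # `kernel` was created once with `mx.fast.metal_kernel` above. Calling
    # `exp_elementwise` repeatedly reuses it rather than creating (and possibly
    # JIT compiling) a new Metal library on every call.
    for _ in range(100):
        b = exp_elementwise(a)
    mx.eval(b)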
.. note::
Only pass the body of the Metal kernel in ``source``. The function
signature is generated automatically.
The full function signature will be generated using:
* The shapes/dtypes of ``inputs``
In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp``
so we will add ``const device float16_t* inp`` to the signature.
``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present
in ``source``.
* The list of ``output_dtypes``
In the above, ``out`` is an ``mx.array`` of type ``mx.float16``
so we add ``device float16_t* out``.
* Template parameters passed using ``template``
In the above, ``template=[("T", mx.float32)]`` adds a template of ``template <typename T>`` to the function
and instantiates the template with ``custom_kernel_myexp_float<float>``.
Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``.
* Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]``
These will be added as function arguments.
All the attributes defined in Table 5.8 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ are supported.
Putting this all together, the generated function signature for ``myexp`` is as follows:
.. code-block:: cpp
template <typename T>
[[kernel]] void custom_kernel_myexp_float(
const device float16_t* inp [[buffer(0)]],
device float16_t* out [[buffer(1)]],
uint3 thread_position_in_grid [[thread_position_in_grid]]) {
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::exp(tmp);
}
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
``threadgroup`` size threadgroups. For optimal performance, each thread group
dimension should be less than or equal to the corresponding grid dimension.
Passing ``verbose=True`` to :func:`fast.metal_kernel.__call__` will print the
generated code for debugging purposes.
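As a sketch (reusing ``kernel`` and ``a`` from the simple example above), such
a call would look like:
.. code-block:: python

    # Passing verbose=True prints the generated Metal source, including the
    # auto-generated signature shown above, when the kernel is invoked.
    outputs = kernel(
        inputs=[a],
        template=[("T", mx.float32)],
        grid=(a.size, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
        verbose=True,
    )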
Using Shape/Strides
-------------------
:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
is ``True`` by default. This will copy the array inputs if needed
before the kernel is launched to ensure that the memory layout is row
contiguous. Generally this makes writing the kernel easier, since we don't
have to worry about gaps or the ordering of the dims when indexing.
If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
present in ``source``. We can then use MLX's built in indexing utils to fetch
the right elements for each thread.
Let's convert ``myexp`` above to support arbitrarily strided arrays without
relying on a copy from ``ensure_row_contiguous``:
.. code-block:: python
source = """
uint elem = thread_position_in_grid.x;
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
T tmp = inp[loc];
// Output arrays are always row contiguous
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp_strided",
input_names=["inp"],
output_names=["out"],
source=source,
ensure_row_contiguous=False,
)
def exp_elementwise(a: mx.array):
outputs = kernel(
inputs=[a],
template=[("T", mx.float32)],
grid=(a.size, 1, 1),
threadgroup=(256, 1, 1),
output_shapes=[a.shape],
output_dtypes=[a.dtype],
)
return outputs[0]
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
# make non-contiguous
a = a[::2]
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
Complex Example
-----------------------------
Let's implement a more complex example: ``grid_sample`` in ``"bilinear"`` mode.
We'll start with the following MLX implementation using standard ops:
.. code-block:: python
def grid_sample_ref(x, grid):
N, H_in, W_in, _ = x.shape
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
ix_nw = mx.floor(ix).astype(mx.int32)
iy_nw = mx.floor(iy).astype(mx.int32)
ix_ne = ix_nw + 1
iy_ne = iy_nw
ix_sw = ix_nw
iy_sw = iy_nw + 1
ix_se = ix_nw + 1
iy_se = iy_nw + 1
nw = (ix_se - ix) * (iy_se - iy)
ne = (ix - ix_sw) * (iy_sw - iy)
sw = (ix_ne - ix) * (iy - iy_ne)
se = (ix - ix_nw) * (iy - iy_nw)
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
I_nw *= mask_nw[..., None]
I_ne *= mask_ne[..., None]
I_sw *= mask_sw[..., None]
I_se *= mask_se[..., None]
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
return output
Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
to write a fast GPU kernel for both the forward and backward passes.
First we'll implement the forward pass as a fused kernel:
.. code-block:: python
source = """
uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
int gH = grid_shape[1];
int gW = grid_shape[2];
int w_stride = C;
int h_stride = W * w_stride;
int b_stride = H * h_stride;
uint grid_idx = elem / C * 2;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int ix_nw = floor(ix);
int iy_nw = floor(iy);
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int batch_idx = elem / C / gH / gW * b_stride;
int channel_idx = elem % C;
int base_idx = batch_idx + channel_idx;
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
"""
kernel = mx.fast.metal_kernel(
name="grid_sample",
input_names=["x", "grid"],
output_names=["out"],
source=source,
)
@mx.custom_function
def grid_sample(x, grid):
assert x.ndim == 4, "`x` must be 4D."
assert grid.ndim == 4, "`grid` must be 4D."
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
out_shape = (B, gN, gM, C)
assert D == 2, "Last dim of `grid` must be size 2."
outputs = kernel(
inputs=[x, grid],
template=[("T", x.dtype)],
output_shapes=[out_shape],
output_dtypes=[x.dtype],
grid=(np.prod(out_shape), 1, 1),
threadgroup=(256, 1, 1),
)
return outputs[0]
For a reasonably sized input such as:
.. code-block:: python
x.shape = (8, 1024, 1024, 64)
grid.shape = (8, 256, 256, 2)
On an M1 Max, we see a big performance improvement:
``55.7ms -> 6.7ms => 8x speed up``
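The exact benchmark script is not shown here; a rough timing sketch along these
lines (reusing ``grid_sample_ref`` and ``grid_sample`` from above, with a
hypothetical inline ``timeit`` helper) gives an idea of how such numbers can be
measured:
.. code-block:: python

    import time

    x = mx.random.uniform(shape=(8, 1024, 1024, 64))
    grid = mx.random.uniform(low=-1, high=1, shape=(8, 256, 256, 2))
    mx.eval(x, grid)

    def timeit(fn, *args, iters=10):
        mx.eval(fn(*args))  # warm up (and JIT compile the kernel)
        tic = time.perf_counter()
        for _ in range(iters):
            mx.eval(fn(*args))
        return 1e3 * (time.perf_counter() - tic) / iters

    print(f"reference: {timeit(grid_sample_ref, x, grid):.1f} ms")
    print(f"custom:    {timeit(grid_sample, x, grid):.1f} ms")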
Grid Sample VJP
---------------
Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
define its custom vjp transform so MLX can differentiate it.
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
requires a few extra :func:`fast.metal_kernel` features:
* ``init_value=0``
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
* ``atomic_outputs=True``
Designate all of the kernel outputs as ``atomic`` in the function signature.
This means we can use Metal's ``atomic`` features to simultaneously update the ``x_grad`` and ``grid_grad`` arrays from multiple threadgroups.
See section 6.15 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ for more details.
We can then implement the backwards pass as follows:
.. code-block:: python
source = """
uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
// Pad C to the nearest larger simdgroup size multiple
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
int gH = grid_shape[1];
int gW = grid_shape[2];
int w_stride = C;
int h_stride = W * w_stride;
int b_stride = H * h_stride;
uint grid_idx = elem / C_padded * 2;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int ix_nw = floor(ix);
int iy_nw = floor(iy);
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int batch_idx = elem / C_padded / gH / gW * b_stride;
int channel_idx = elem % C_padded;
int base_idx = batch_idx + channel_idx;
T gix = T(0);
T giy = T(0);
if (channel_idx < C) {
int cot_index = elem / C_padded * C + channel_idx;
T cot = cotangent[cot_index];
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
T I_nw = x[offset];
gix -= I_nw * (iy_se - iy) * cot;
giy -= I_nw * (ix_se - ix) * cot;
}
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
T I_ne = x[offset];
gix += I_ne * (iy_sw - iy) * cot;
giy -= I_ne * (ix - ix_sw) * cot;
}
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
T I_sw = x[offset];
gix -= I_sw * (iy - iy_ne) * cot;
giy += I_sw * (ix_ne - ix) * cot;
}
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
T I_se = x[offset];
gix += I_se * (iy - iy_nw) * cot;
giy += I_se * (ix - ix_nw) * cot;
}
}
T gix_mult = W / 2;
T giy_mult = H / 2;
// Reduce across each simdgroup first.
// This is much faster than relying purely on atomics.
gix = simd_sum(gix);
giy = simd_sum(giy);
if (thread_index_in_simdgroup == 0) {
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
}
"""
kernel = mx.fast.metal_kernel(
name="grid_sample_grad",
input_names=["x", "grid", "cotangent"],
output_names=["x_grad", "grid_grad"],
source=source,
atomic_outputs=True,
)
@grid_sample.vjp
def grid_sample_vjp(primals, cotangent, _):
x, grid = primals
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
assert D == 2, "Last dim of `grid` must be size 2."
# pad the output channels to simd group size
# so that our `simd_sum`s don't overlap.
simdgroup_size = 32
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
grid_size = B * gN * gM * C_padded
outputs = kernel(
inputs=[x, grid, cotangent],
template=[("T", x.dtype)],
output_shapes=[x.shape, grid.shape],
output_dtypes=[x.dtype, x.dtype],
grid=(grid_size, 1, 1),
threadgroup=(256, 1, 1),
init_value=0,
)
return outputs[0], outputs[1]
There's an even larger speed up for the vjp:
``676.4ms -> 16.7ms => 40x speed up``
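As a quick sanity check (a sketch, not part of the example code), the custom
VJP can be compared against the VJP of the pure-MLX ``grid_sample_ref`` using
:func:`vjp`; the tolerances below are illustrative and may need adjusting:
.. code-block:: python

    x = mx.random.uniform(shape=(2, 32, 32, 8))
    grid = mx.random.uniform(low=-1, high=1, shape=(2, 16, 16, 2))
    cotan = mx.ones((2, 16, 16, 8))  # same shape as the forward output

    # vjp returns (outputs, vjps); we only need the vjps here.
    _, (dx_ref, dgrid_ref) = mx.vjp(grid_sample_ref, [x, grid], [cotan])
    _, (dx, dgrid) = mx.vjp(grid_sample, [x, grid], [cotan])

    assert mx.allclose(dx, dx_ref, atol=1e-4).item()
    assert mx.allclose(dgrid, dgrid_ref, atol=1e-4).item()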

View File

@@ -0,0 +1,811 @@
Custom Extensions in MLX
========================
You can extend MLX with custom operations on the CPU or GPU. This guide
explains how to do that with a simple example.
Introducing the Example
-----------------------
Let's say you would like an operation that takes in two arrays, ``x`` and
``y``, scales them both by coefficients ``alpha`` and ``beta`` respectively,
and then adds them together to get the result ``z = alpha * x + beta * y``.
You can do that in MLX directly:
.. code-block:: python
import mlx.core as mx
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
return alpha * x + beta * y
This function performs that operation while leaving the implementation and
function transformations to MLX.
However, you may want to customize the underlying implementation, perhaps to
make it faster. In this tutorial we will go through adding custom extensions.
It will cover:
* The structure of the MLX library.
* Implementing a CPU operation.
* Implementing a GPU operation using metal.
* Adding the ``vjp`` and ``jvp`` function transformation.
* Building a custom extension and binding it to python.
Operations and Primitives
-------------------------
Operations in MLX build the computation graph. Primitives provide the rules for
evaluating and transforming the graph. Let's start by discussing operations in
more detail.
Operations
^^^^^^^^^^^
Operations are the front-end functions that operate on arrays. They are defined
in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.
We would like an operation :meth:`axpby` that takes in two arrays, ``x`` and
``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
C++:
.. code-block:: C++
/**
* Scale and sum two vectors element-wise
* z = alpha * x + beta * y
*
* Use NumPy-style broadcasting between x and y
* Inputs are upcasted to floats if needed
**/
array axpby(
const array& x, // Input array x
const array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
StreamOrDevice s = {} // Stream on which to schedule the operation
);
The simplest way to implement this is with existing operations:
.. code-block:: C++
array axpby(
const array& x, // Input array x
const array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
// Scale x and y on the provided stream
auto ax = multiply(array(alpha), x, s);
auto by = multiply(array(beta), y, s);
// Add and return
return add(ax, by, s);
}
The operations themselves do not contain the implementations that act on the
data, nor do they contain the rules of transformations. Rather, they are an
easy-to-use interface that uses :class:`Primitive` building blocks.
Primitives
^^^^^^^^^^^
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create output arrays given input arrays. Further, a
:class:`Primitive` has methods to run on the CPU or GPU and for function
transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
more concrete:
.. code-block:: C++
class Axpby : public Primitive {
public:
explicit Axpby(Stream stream, float alpha, float beta)
: Primitive(stream), alpha_(alpha), beta_(beta){};
/**
* A primitive must know how to evaluate itself on the CPU/GPU
* for the given inputs and populate the output array.
*
* To avoid unnecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
void eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) override;
void eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) override;
/** The Jacobian-vector product. */
std::vector<array> jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) override;
/** The vector-Jacobian product. */
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
/**
* The primitive must know how to vectorize itself across
* the given axes. The output is a pair containing the array
* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
*/
std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
/** The name of the primitive. */
const char* name() const override {
return "Axpby";
}
/** Equivalence check **/
bool is_equivalent(const Primitive& other) const override;
private:
float alpha_;
float beta_;
};
The :class:`Axpby` class derives from the base :class:`Primitive` class. The
:class:`Axpby` treats ``alpha`` and ``beta`` as parameters. It then provides
implementations of how the output array is produced given the inputs through
:meth:`Axpby::eval_cpu` and :meth:`Axpby::eval_gpu`. It also provides rules
of transformations in :meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and
:meth:`Axpby::vmap`.
Using the Primitive
^^^^^^^^^^^^^^^^^^^
Operations can use this :class:`Primitive` to add a new :class:`array` to the
computation graph. An :class:`array` can be constructed by providing its data
type, shape, the :class:`Primitive` that computes it, and the :class:`array`
inputs that are passed to the primitive.
Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
.. code-block:: C++
array axpby(
const array& x, // Input array x
const array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
// Promote dtypes between x and y as needed
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
// Upcast to float32 for non-floating point inputs x and y
auto out_dtype = issubdtype(promoted_dtype, float32)
? promoted_dtype
: promote_types(promoted_dtype, float32);
// Cast x and y up to the determined dtype (on the same stream s)
auto x_casted = astype(x, out_dtype, s);
auto y_casted = astype(y, out_dtype, s);
// Broadcast the shapes of x and y (on the same stream s)
auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
auto out_shape = broadcasted_inputs[0].shape();
// Construct the array as the output of the Axpby primitive
// with the broadcasted and upcasted arrays as inputs
return array(
/* const std::vector<int>& shape = */ out_shape,
/* Dtype dtype = */ out_dtype,
/* std::shared_ptr<Primitive> primitive = */
std::make_shared<Axpby>(to_stream(s), alpha, beta),
/* const std::vector<array>& inputs = */ broadcasted_inputs);
}
This operation now handles the following:
#. Upcast inputs and resolve the output data type.
#. Broadcast the inputs and resolve the output shape.
#. Construct the primitive :class:`Axpby` using the given stream, ``alpha``, and ``beta``.
#. Construct the output :class:`array` using the primitive and the inputs.
Implementing the Primitive
--------------------------
No computation happens when we call the operation alone. The operation only
builds the computation graph. When we evaluate the output array, MLX schedules
the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
:meth:`Axpby::eval_gpu` depending on the stream/device specified by the user.
.. warning::
When :meth:`Primitive::eval_cpu` or :meth:`Primitive::eval_gpu` are called,
no memory has been allocated for the output array. It falls on the implementation
of these functions to allocate memory as needed.
Implementing the CPU Back-end
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Let's start by implementing :meth:`Axpby::eval_cpu`.
The method will go over each element of the output array, find the
corresponding input elements of ``x`` and ``y`` and perform the operation
point-wise. This is captured in the templated function :meth:`axpby_impl`.
.. code-block:: C++
template <typename T>
void axpby_impl(
const mx::array& x,
const mx::array& y,
mx::array& out,
float alpha_,
float beta_,
mx::Stream stream) {
out.set_data(mx::allocator::malloc(out.nbytes()));
// Get the CPU command encoder and register input and output arrays
auto& encoder = mx::cpu::get_command_encoder(stream);
encoder.set_input_array(x);
encoder.set_input_array(y);
encoder.set_output_array(out);
// Launch the CPU kernel
encoder.dispatch([x_ptr = x.data<T>(),
y_ptr = y.data<T>(),
out_ptr = out.data<T>(),
size = out.size(),
shape = out.shape(),
x_strides = x.strides(),
y_strides = y.strides(),
alpha_,
beta_]() {
// Cast alpha and beta to the relevant types
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Do the element-wise operation for each output
for (size_t out_idx = 0; out_idx < size; out_idx++) {
// Map linear indices to offsets in x and y
auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);
// We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additional mapping
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
}
});
}
Our implementation should work for all incoming floating point arrays.
Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
``complex64``. We throw an error if we encounter an unexpected type.
.. code-block:: C++
void Axpby::eval_cpu(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Dispatch to the correct dtype
if (out.dtype() == mx::float32) {
return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
} else if (out.dtype() == mx::float16) {
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
} else if (out.dtype() == mx::bfloat16) {
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
} else if (out.dtype() == mx::complex64) {
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
} else {
throw std::runtime_error(
"Axpby is only supported for floating point types.");
}
}
Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
you do not plan on running the operation on the GPU or using transforms on
computation graphs that contain :class:`Axpby`, you can stop implementing the
primitive here.
Implementing the GPU Back-end
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Apple silicon devices address their GPUs using the Metal_ shading language, and
GPU kernels in MLX are written using Metal.
.. note::
Here are some helpful resources if you are new to Metal:
* A walkthrough of the metal compute pipeline: `Metal Example`_
* Documentation for metal shading language: `Metal Specification`_
* Using metal from C++: `Metal-cpp`_
Let's keep the GPU kernel simple. We will launch exactly as many threads as
there are elements in the output. Each thread will pick the element it needs
from ``x`` and ``y``, do the point-wise operation, and update its assigned
element in the output.
.. code-block:: C++
template <typename T>
[[kernel]] void axpby_general(
device const T* x [[buffer(0)]],
device const T* y [[buffer(1)]],
device T* out [[buffer(2)]],
constant const float& alpha [[buffer(3)]],
constant const float& beta [[buffer(4)]],
constant const int* shape [[buffer(5)]],
constant const int64_t* x_strides [[buffer(6)]],
constant const int64_t* y_strides [[buffer(7)]],
constant const int& ndim [[buffer(8)]],
uint index [[thread_position_in_grid]]) {
// Convert linear indices to offsets in array
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
// Do the operation and update the output
out[index] =
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
}
We then need to instantiate this template for all floating point types and give
each instantiation a unique host name so we can identify it.
.. code-block:: C++
instantiate_kernel("axpby_general_float32", axpby_general, float)
instantiate_kernel("axpby_general_float16", axpby_general, float16_t)
instantiate_kernel("axpby_general_bfloat16", axpby_general, bfloat16_t)
instantiate_kernel("axpby_general_complex64", axpby_general, complex64_t)
The logic to determine the kernel, set the inputs, resolve the grid dimensions,
and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
below.
.. code-block:: C++
/** Evaluate primitive on GPU */
void Axpby::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
// Prepare inputs
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Each primitive carries the stream it should execute on
// and each stream carries its device identifiers
auto& s = stream();
// We get the needed metal device using the stream
auto& d = metal::device(s.device);
// Allocate output memory
out.set_data(allocator::malloc(out.nbytes()));
// Resolve name of kernel
std::string kname = "axpby_general_" + type_to_name(out);
// Load the metal library
auto lib = d.get_library("mlx_ext", current_binary_dir());
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname, lib);
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
// Kernel parameters are registered with buffer indices corresponding to
// those in the kernel declaration at axpby.metal
int ndim = out.ndim();
size_t nelem = out.size();
// Encode input arrays to kernel
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(y, 1);
// Encode output arrays to kernel
compute_encoder.set_output_array(out, 2);
// Encode alpha and beta
compute_encoder.set_bytes(alpha_, 3);
compute_encoder.set_bytes(beta_, 4);
// Encode shape, strides and ndim
compute_encoder.set_vector_bytes(x.shape(), 5);
compute_encoder.set_vector_bytes(x.strides(), 6);
compute_encoder.set_vector_bytes(y.strides(), 7);
compute_encoder.set_bytes(ndim, 8);
// We launch 1 thread for each input and make sure that the number of
// threads in any given threadgroup is not higher than the max allowed
size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup());
// Fix the 3D size of each threadgroup (in terms of threads)
MTL::Size group_dims = MTL::Size(tgp_size, 1, 1);
// Fix the 3D size of the launch grid (in terms of threads)
MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
// Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
We can now call the :meth:`axpby` operation on both the CPU and the GPU!
A few things to note about MLX and Metal before moving on. MLX keeps track of
the active ``command_buffer`` and the ``MTLCommandBuffer`` to which it is
associated. We rely on :meth:`d.get_command_encoder` to give us the active
metal compute command encoder instead of building a new one and calling
:meth:`compute_encoder->end_encoding` at the end. MLX adds kernels (compute
pipelines) to the active command buffer until some specified limit is hit or
the command buffer needs to be flushed for synchronization.
Primitive Transforms
^^^^^^^^^^^^^^^^^^^^^
Next, let's add implementations for transformations in a :class:`Primitive`.
These transformations can be built on top of other operations, including the
one we just defined:
.. code-block:: C++
/** The Jacobian-vector product. */
std::vector<array> Axpby::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents
// The jvp transform on the primitive can be built with ops
// that are scheduled on the same stream as the primitive
// If argnums = {0}, we only push along x in which case the
// jvp is just the tangent scaled by alpha
// Similarly, if argnums = {1}, the jvp is just the tangent
// scaled by beta
if (argnums.size() == 1) {
auto scale = argnums[0] == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, tangents[0].dtype());
return {multiply(scale_arr, tangents[0], stream())};
}
// If argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta
else {
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
}
}
.. code-block:: C++
/** The vector-Jacobian product. */
std::vector<array> Axpby::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& /* unused */) {
// Reverse mode diff
std::vector<array> vjps;
for (auto arg : argnums) {
auto scale = arg == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, cotangents[0].dtype());
vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
}
return vjps;
}
Note that a transformation does not need to be fully defined to start using
the :class:`Primitive`.
.. code-block:: C++
/** Vectorize primitive along given axis */
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
throw std::runtime_error("[Axpby] vmap not implemented.");
}
Building and Binding
--------------------
Let's look at the overall directory structure first.
| extensions
| ├── axpby
| │ ├── axpby.cpp
| │ ├── axpby.h
| │ └── axpby.metal
| ├── mlx_sample_extensions
| │ └── __init__.py
| ├── bindings.cpp
| ├── CMakeLists.txt
| └── setup.py
* ``extensions/axpby/`` defines the C++ extension library
* ``extensions/mlx_sample_extensions`` sets out the structure for the
associated Python package
* ``extensions/bindings.cpp`` provides Python bindings for our operation
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
Python bindings
* ``extensions/setup.py`` holds the ``setuptools`` rules to build and install
the Python package
Binding to Python
^^^^^^^^^^^^^^^^^^
We use nanobind_ to build a Python API for the C++ library. Since bindings for
components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are
already provided, adding our :meth:`axpby` is simple.
.. code-block:: C++
NB_MODULE(_ext, m) {
m.doc() = "Sample extension for MLX";
m.def(
"axpby",
&axpby,
"x"_a,
"y"_a,
"alpha"_a,
"beta"_a,
nb::kw_only(),
"stream"_a = nb::none(),
R"(
Scale and sum two vectors element-wise
``z = alpha * x + beta * y``
Follows numpy style broadcasting between ``x`` and ``y``
Inputs are upcasted to floats if needed
Args:
x (array): Input array.
y (array): Input array.
alpha (float): Scaling factor for ``x``.
beta (float): Scaling factor for ``y``.
Returns:
array: ``alpha * x + beta * y``
)");
}
Most of the complexity in the above example comes from additional bells and
whistles such as the literal names and doc-strings.
.. warning::
:mod:`mlx.core` must be imported before importing
:mod:`mlx_sample_extensions` as defined by the nanobind module above to
ensure that the casters for :mod:`mlx.core` components like
:class:`mlx.core.array` are available.
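For example (a minimal sketch, assuming the package has been installed as
described below), the import order should be:
.. code-block:: python

    import mlx.core as mx  # import first so the mlx.core casters are registered
    from mlx_sample_extensions import axpby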
.. _Building with CMake:
Building with CMake
^^^^^^^^^^^^^^^^^^^^
Building the C++ extension library only requires that you ``find_package(MLX
CONFIG)`` and then link it to your library.
.. code-block:: cmake
# Add library
add_library(mlx_ext)
# Add sources
target_sources(
mlx_ext
PUBLIC
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
)
# Add include headers
target_include_directories(
mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
)
# Link to mlx
target_link_libraries(mlx_ext PUBLIC mlx)
We also need to build the attached Metal library. For convenience, we provide a
:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
automatically imported with the MLX package).
Here is what that looks like in practice:
.. code-block:: cmake
# Build metallib
if(MLX_BUILD_METAL)
mlx_build_metallib(
TARGET mlx_ext_metallib
TITLE mlx_ext
SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
)
add_dependencies(
mlx_ext
mlx_ext_metallib
)
endif()
Finally, we build the nanobind_ bindings
.. code-block:: cmake
nanobind_add_module(
_ext
NB_STATIC STABLE_ABI LTO NOMINSIZE
NB_DOMAIN mlx
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
)
target_link_libraries(_ext PRIVATE mlx_ext)
if(BUILD_SHARED_LIBS)
target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
endif()
Building with ``setuptools``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Once we have set out the CMake build rules as described above, we can use the
build utilities defined in :mod:`mlx.extension`:
.. code-block:: python
from mlx import extension
from setuptools import setup
if __name__ == "__main__":
setup(
name="mlx_sample_extensions",
version="0.0.0",
description="Sample C++ and Metal extensions for MLX primitives.",
ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
cmdclass={"build_ext": extension.CMakeBuild},
packages=["mlx_sample_extensions"],
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
extras_require={"dev":[]},
zip_safe=False,
python_requires=">=3.8",
)
.. note::
We treat ``extensions/mlx_sample_extensions`` as the package directory
even though it only contains a ``__init__.py`` to ensure the following:
* :mod:`mlx.core` must be imported before importing :mod:`_ext`
* The C++ extension library and the metal library are co-located with the python
bindings and copied together if the package is installed
To build the package, first install the build dependencies with ``pip install
-r requirements.txt``. You can then build in place for development using
``python setup.py build_ext -j8 --inplace`` (run in ``extensions/``).
This results in the directory structure:
| extensions
| ├── mlx_sample_extensions
| │ ├── __init__.py
| │ ├── libmlx_ext.dylib # C++ extension library
| │ ├── mlx_ext.metallib # Metal library
| │ └── _ext.cpython-3x-darwin.so # Python Binding
| ...
When you try to install using the command ``python -m pip install .`` (in
``extensions/``), the package will be installed with the same structure as
``extensions/mlx_sample_extensions``, and the C++ and Metal libraries will be
copied along with the Python binding since they are specified as
``package_data``.
Usage
-----
After installing the extension as described above, you should be able to simply
import the Python package and play with it as you would any other MLX operation.
Let's look at a simple script and its results:
.. code-block:: python
import mlx.core as mx
from mlx_sample_extensions import axpby
a = mx.ones((3, 4))
b = mx.ones((3, 4))
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c is correct: {mx.all(c == 6.0).item()}")
Output:
.. code-block::
c shape: [3, 4]
c dtype: float32
c is correct: True
Results
^^^^^^^
Let's run a quick benchmark and see how our new ``axpby`` operation compares
with the naive :meth:`simple_axpby` we first defined.
.. code-block:: python
import mlx.core as mx
from mlx_sample_extensions import axpby
import time
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
return alpha * x + beta * y
M = 4096
N = 4096
x = mx.random.normal((M, N))
y = mx.random.normal((M, N))
alpha = 4.0
beta = 2.0
mx.eval(x, y)
def bench(f):
# Warm up
for i in range(5):
z = f(x, y, alpha, beta)
mx.eval(z)
# Timed run
s = time.time()
for i in range(100):
z = f(x, y, alpha, beta)
mx.eval(z)
e = time.time()
return 1000 * (e - s) / 100
simple_time = bench(simple_axpby)
custom_time = bench(axpby)
print(f"Simple axpby: {simple_time:.3f} ms | Custom axpby: {custom_time:.3f} ms")
The results are ``Simple axpby: 1.559 ms | Custom axpby: 0.774 ms``. We see
modest improvements right away!
This operation can now be used to build other operations, in
:class:`mlx.nn.Module` calls, and as part of graph transformations like
:meth:`grad`.
Scripts
-------
.. admonition:: Download the code
The full example code is available in `mlx <https://github.com/ml-explore/mlx/tree/main/examples/extensions/>`_.
.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
.. _Metal: https://developer.apple.com/documentation/metal?language=objc
.. _Metal-cpp: https://developer.apple.com/metal/cpp/
.. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
.. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc
.. _nanobind: https://nanobind.readthedocs.io/en/latest/

View File

@@ -0,0 +1,68 @@
Metal Debugger
==============
.. currentmodule:: mlx.core
Profiling is a key step for performance optimization. You can build MLX with
the ``MLX_METAL_DEBUG`` option to improve the Metal debugging and
optimization workflow. The ``MLX_METAL_DEBUG`` debug option:
* Records source during Metal compilation, for later inspection while
debugging.
* Labels Metal objects such as command queues, improving capture readability.
To build with debugging enabled in Python, prepend
``CMAKE_ARGS="-DMLX_METAL_DEBUG=ON"`` to the build call.
The :func:`metal.start_capture` function initiates a capture of all MLX GPU
work.
.. note::
To capture a GPU trace you must run the application with
``MTL_CAPTURE_ENABLED=1``.
.. code-block:: python
import mlx.core as mx
a = mx.random.uniform(shape=(512, 512))
b = mx.random.uniform(shape=(512, 512))
mx.eval(a, b)
trace_file = "mlx_trace.gputrace"
# Make sure to run with MTL_CAPTURE_ENABLED=1 and
# that the path trace_file does not already exist.
mx.metal.start_capture(trace_file)
for _ in range(10):
mx.eval(mx.add(a, b))
mx.metal.stop_capture()
You can open and replay the GPU trace in Xcode. The ``Dependencies`` view
has a great overview of all operations. Check out the `Metal debugger
documentation`_ for more information.
.. image:: ../_static/metal_debugger/capture.png
:class: dark-light
Xcode Workflow
--------------
You can skip saving to a path by running within Xcode. First, generate an
Xcode project using CMake.
.. code-block::
mkdir build && cd build
cmake .. -DMLX_METAL_DEBUG=ON -G Xcode
open mlx.xcodeproj
Select the ``metal_capture`` example scheme and run.
.. image:: ../_static/metal_debugger/schema.png
:class: dark-light
.. _`Metal debugger documentation`: https://developer.apple.com/documentation/xcode/metal-debugger

View File

@@ -0,0 +1,121 @@
.. _mlx_in_cpp:
Using MLX in C++
================
You can use MLX in a C++ project with CMake.
.. note::
This guide is based on the following `example using MLX in C++
<https://github.com/ml-explore/mlx/tree/main/examples/cmake_project>`_
First install MLX:
.. code-block:: bash
pip install -U mlx
You can also install the MLX Python package from source or just the C++
library. For more information see the :ref:`documentation on installing MLX
<build_and_install>`.
Next make an example program in ``example.cpp``:
.. code-block:: C++
#include <iostream>
#include "mlx/mlx.h"
namespace mx = mlx::core;
int main() {
auto x = mx::array({1, 2, 3});
auto y = mx::array({1, 2, 3});
std::cout << x + y << std::endl;
return 0;
}
The next step is to set up a CMake file in ``CMakeLists.txt``:
.. code-block:: cmake
cmake_minimum_required(VERSION 3.27)
project(example LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
Depending on how you installed MLX, you may need to tell CMake where to
find it.
If you installed MLX with Python, then add the following to the CMake file:
.. code-block:: cmake
find_package(
Python 3.9
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE MLX_ROOT)
If you installed the MLX C++ package to a system path, then CMake should be
able to find it. If you installed it to a non-standard location or CMake can't
find MLX then set ``MLX_ROOT`` to the location where MLX is installed:
.. code-block:: cmake
set(MLX_ROOT "/path/to/mlx/")
Next, instruct CMake to find MLX:
.. code-block:: cmake
find_package(MLX CONFIG REQUIRED)
Finally, add the ``example.cpp`` program as an executable and link MLX.
.. code-block:: cmake
add_executable(example example.cpp)
target_link_libraries(example PRIVATE mlx)
You can build the example with:
.. code-block:: bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
And run it with:
.. code-block:: bash
./build/example
Note that ``find_package(MLX CONFIG REQUIRED)`` sets the following variables:
.. list-table:: Package Variables
:widths: 20 20
:header-rows: 1
* - Variable
- Description
* - MLX_FOUND
- ``True`` if MLX is found
* - MLX_INCLUDE_DIRS
- Include directory
* - MLX_LIBRARIES
- Libraries to link against
* - MLX_CXX_FLAGS
- Additional compiler flags
* - MLX_BUILD_ACCELERATE
- ``True`` if MLX was built with Accelerate
* - MLX_BUILD_METAL
- ``True`` if MLX was built with Metal

View File

@@ -0,0 +1,77 @@
.. _linear_regression:
Linear Regression
-----------------
Let's implement a basic linear regression model as a starting point to
learn MLX. First, import the core package and set up some problem metadata:
.. code-block:: python
import mlx.core as mx
num_features = 100
num_examples = 1_000
num_iters = 10_000 # iterations of SGD
lr = 0.01 # learning rate for SGD
We'll generate a synthetic dataset by:
1. Sampling the design matrix ``X``.
2. Sampling a ground truth parameter vector ``w_star``.
3. Computing the dependent values ``y`` by adding Gaussian noise to ``X @ w_star``.
.. code-block:: python
# True parameters
w_star = mx.random.normal((num_features,))
# Input examples (design matrix)
X = mx.random.normal((num_examples, num_features))
# Noisy labels
eps = 1e-2 * mx.random.normal((num_examples,))
y = X @ w_star + eps
We will use SGD to find the optimal weights. To start, define the squared loss
and get the gradient function of the loss with respect to the parameters.
.. code-block:: python
def loss_fn(w):
return 0.5 * mx.mean(mx.square(X @ w - y))
grad_fn = mx.grad(loss_fn)
Start the optimization by initializing the parameters ``w`` randomly. Then
repeatedly update the parameters for ``num_iters`` iterations.
.. code-block:: python
w = 1e-2 * mx.random.normal((num_features,))
for _ in range(num_iters):
grad = grad_fn(w)
w = w - lr * grad
mx.eval(w)
Finally, compute the loss of the learned parameters and verify that they are
close to the ground truth parameters.
.. code-block:: python
loss = loss_fn(w)
error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5
print(
f"Loss {loss.item():.5f}, |w-w*| = {error_norm:.5f}, "
)
# Should print something close to: Loss 0.00005, |w-w*| = 0.00364
Complete `linear regression
<https://github.com/ml-explore/mlx/tree/main/examples/python/linear_regression.py>`_
and `logistic regression
<https://github.com/ml-explore/mlx/tree/main/examples/python/logistic_regression.py>`_
examples are available in the MLX GitHub repo.

View File

@@ -0,0 +1,382 @@
LLM inference
==============
MLX enables efficient inference of large-ish transformers on Apple silicon
without compromising on ease of use. In this example we will create an
inference script for the Llama family of transformer models in which the model
is defined in less than 200 lines of python.
Implementing the model
----------------------
We will use the neural network building blocks defined in the :mod:`mlx.nn`
module to concisely define the model architecture.
Attention layer
^^^^^^^^^^^^^^^^
We will start with the Llama attention layer which notably uses the RoPE
positional encoding. [1]_ In addition, our attention layer will optionally use a
key/value cache that will be concatenated with the provided keys and values to
support efficient inference.
Our implementation uses :class:`mlx.nn.Linear` for all the projections and
:class:`mlx.nn.RoPE` for the positional encoding.
.. code-block:: python
import math
import mlx.core as mx
import mlx.nn as nn
class LlamaAttention(nn.Module):
def __init__(self, dims: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.rope = nn.RoPE(dims // num_heads, traditional=True)
self.query_proj = nn.Linear(dims, dims, bias=False)
self.key_proj = nn.Linear(dims, dims, bias=False)
self.value_proj = nn.Linear(dims, dims, bias=False)
self.out_proj = nn.Linear(dims, dims, bias=False)
def __call__(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
# Extract some shapes
num_heads = self.num_heads
B, L, D = queries.shape
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
# Add RoPE to the queries and keys and combine them with the cache
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
# Note that we return the keys and values to possibly be used as a cache
return self.out_proj(values_hat), (keys, values)
Encoder layer
^^^^^^^^^^^^^
The other component of the Llama model is the encoder layer which uses RMS
normalization [2]_ and SwiGLU. [3]_ For RMS normalization we will use
:class:`mlx.nn.RMSNorm` that is already provided in :mod:`mlx.nn`.
.. code-block:: python
class LlamaEncoderLayer(nn.Module):
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
super().__init__()
self.attention = LlamaAttention(dims, num_heads)
self.norm1 = nn.RMSNorm(dims)
self.norm2 = nn.RMSNorm(dims)
self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
self.linear3 = nn.Linear(mlp_dims, dims, bias=False)
def __call__(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = a * mx.sigmoid(a) * b
y = self.linear3(y)
x = x + y
return x, cache
Full model
^^^^^^^^^^
To implement any Llama model we simply have to combine ``LlamaEncoderLayer``
instances with an :class:`mlx.nn.Embedding` to embed the input tokens.
.. code-block:: python
class Llama(nn.Module):
def __init__(
self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int
):
super().__init__()
self.embedding = nn.Embedding(vocab_size, dims)
self.layers = [
LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers)
]
self.norm = nn.RMSNorm(dims)
self.out_proj = nn.Linear(dims, vocab_size, bias=False)
def __call__(self, x):
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(self.embedding.weight.dtype)
x = self.embedding(x)
for l in self.layers:
x, _ = l(x, mask)
x = self.norm(x)
return self.out_proj(x)
Note that in the implementation above we use a plain Python list to hold the
encoder layers; ``model.parameters()`` will still include the parameters of
these layers.
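A quick sketch to confirm this (using :func:`mlx.utils.tree_flatten` to list
the flattened parameter names) could be:
.. code-block:: python

    from mlx.utils import tree_flatten

    model = Llama(num_layers=2, vocab_size=100, dims=32, mlp_dims=64, num_heads=4)
    # Names such as "layers.0.attention.query_proj.weight" show that modules
    # stored in a plain Python list are still visited by model.parameters().
    print([name for name, _ in tree_flatten(model.parameters())][:4])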
Generation
^^^^^^^^^^^
Our ``Llama`` module can be used for training but not inference as the
``__call__`` method above processes one input, completely ignores the cache and
performs no sampling whatsoever. In the rest of this subsection, we will
implement the inference function as a python generator that processes the
prompt and then autoregressively yields tokens one at a time.
.. code-block:: python
class Llama(nn.Module):
...
def generate(self, x, temp=1.0):
cache = []
# Make an additive causal mask. We will need that to process the prompt.
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(self.embedding.weight.dtype)
# First we process the prompt x the same way as in __call__ but
# save the caches in cache
x = self.embedding(x)
for l in self.layers:
x, c = l(x, mask=mask)
cache.append(c) # <--- we store the per layer cache in a
# simple python list
x = self.norm(x)
y = self.out_proj(x[:, -1]) # <--- we only care about the last logits
# that generate the next token
y = mx.random.categorical(y * (1/temp))
# y now has size [1]
# Since MLX is lazily evaluated nothing is computed yet.
# Calling y.item() would force the computation to happen at
# this point but we can also choose not to do that and let the
# user choose when to start the computation.
yield y
# Now that the prompt has been processed and the first token
# generated, we feed it back into the model and loop to generate
# the rest of the tokens.
while True:
# Unsqueezing the last dimension to add a sequence length
# dimension of 1
x = y[:, None]
x = self.embedding(x)
for i in range(len(cache)):
# We are overwriting the arrays in the cache list. When
# the computation will happen, MLX will be discarding the
# old cache the moment it is not needed anymore.
x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
x = self.norm(x)
y = self.out_proj(x[:, -1])
y = mx.random.categorical(y * (1/temp))
yield y
Putting it all together
^^^^^^^^^^^^^^^^^^^^^^^
We now have everything we need to create a Llama model and sample tokens from
it. In the following code, we randomly initialize a small Llama model, process
6 tokens of prompt and generate 10 tokens.
.. code-block:: python
model = Llama(num_layers=12, vocab_size=8192, dims=512, mlp_dims=1024, num_heads=8)
# Since MLX is lazily evaluated nothing has actually been materialized yet.
# We could have set the `dims` to 20_000 on a machine with 8GB of RAM and the
# code above would still run. Let's actually materialize the model.
mx.eval(model.parameters())
prompt = mx.array([[1, 10, 8, 32, 44, 7]]) # <-- Note the double brackets because we
# have a batch dimension even
# though it is 1 in this case
generated = [t for i, t in zip(range(10), model.generate(prompt, 0.8))]
# Since we haven't evaluated anything, nothing is computed yet. The list
# `generated` contains the arrays that hold the computation graph for the
# full processing of the prompt and the generation of 10 tokens.
#
# We can evaluate them one at a time, or all together. Concatenate them or
# print them. They would all result in very similar runtimes and give exactly
# the same results.
mx.eval(generated)
Converting the weights
----------------------
This section assumes that you have access to the original Llama weights and the
SentencePiece model that comes with them. We will write a small script to
convert the PyTorch weights to MLX compatible ones and write them in a NPZ file
that can be loaded directly by MLX.
.. code-block:: python
import argparse
from itertools import starmap
import numpy as np
import torch
def map_torch_to_mlx(key, value):
if "tok_embedding" in key:
key = "embedding.weight"
elif "norm" in key:
key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")
elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
key = key.replace("wq", "query_proj")
key = key.replace("wk", "key_proj")
key = key.replace("wv", "value_proj")
key = key.replace("wo", "out_proj")
elif "w1" in key or "w2" in key or "w3" in key:
# The FFN is a separate submodule in PyTorch
key = key.replace("feed_forward.w1", "linear1")
key = key.replace("feed_forward.w3", "linear2")
key = key.replace("feed_forward.w2", "linear3")
elif "output" in key:
key = key.replace("output", "out_proj")
elif "rope" in key:
return None, None
return key, value.numpy()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
parser.add_argument("torch_weights")
parser.add_argument("output_file")
args = parser.parse_args()
state = torch.load(args.torch_weights)
np.savez(
args.output_file,
**{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}
)
Weight loading and benchmarking
-------------------------------
After converting the weights to be compatible with our implementation, all that is
left is to load them from disk and we can finally use the LLM to generate text.
We can load numpy format files using the :func:`mlx.core.load` operation.
To create a parameter dictionary from the key/value representation of NPZ files
we will use the :func:`mlx.utils.tree_unflatten` helper method as follows:
.. code-block:: python
from mlx.utils import tree_unflatten
model.update(tree_unflatten(list(mx.load(weight_file).items())))
:meth:`mlx.utils.tree_unflatten` will take keys from the NPZ file that look
like ``layers.2.attention.query_proj.weight`` and will transform them to
.. code-block:: python
{"layers": [..., ..., {"attention": {"query_proj": {"weight": ...}}}]}
which can then be used to update the model. Note that the method above incurs
several unnecessary copies from disk to numpy and then from numpy to MLX. It
will be replaced in the future with direct loading to MLX.
You can download the full example code in `mlx-examples`_. Assuming the
existence of ``weights.pth`` and ``tokenizer.model`` in the current working
directory, we can play around with our inference script as follows (the timings
are representative of an M1 Ultra and the 7B parameter Llama model):
.. code-block:: bash
$ python convert.py weights.pth llama-7B.mlx.npz
$ python llama.py llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely'
[INFO] Loading model from disk: 5.247 s
Press enter to start generation
------
, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down,
------
[INFO] Prompt processing: 0.437 s
[INFO] Full generation: 4.330 s
We observe that 4.3 seconds are required to generate 100 tokens and 0.4 seconds
of those are spent processing the prompt. This amounts to a little over **39 ms
per token**.
By running with a much bigger prompt we can see that the per token generation
time as well as the prompt processing time remains almost constant.
.. code-block:: bash
$ python llama.py llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down, nobody took the least notice of him. I stopped at last, at a little distance, as if I had been in doubt, and after looking on a few minutes, walked straight up to him. He slowly raised his eyes, and fixed them upon me for a moment, without speaking, and then resumed his place and posture as before. I stood looking at him for a while, feeling very much pain at heart, and then said to him, “What are you doing there?” Something like a smile passed over his face, as he said slowly, “I am waiting for someone; but it has been three quarters of an hour now, and he has not come.” “What is it you are waiting for?” said I. Still he made no immediate reply, but again put his face down upon his hands, and did not'
[INFO] Loading model from disk: 5.247 s
Press enter to start generation
------
take his eyes from the ground. “What is it you are waiting for?” said I. “I am not accustomed to be thus questioned,” said he. “You look like a reasonable man—tell me, then, what are you waiting for?” “You would not understand,” he replied; “and how could you help me, if I were to tell you?” “I should not only understand, but would do all that I could,” said I. He did not
------
[INFO] Prompt processing: 0.579 s
[INFO] Full generation: 4.690 s
$ python llama.py --num-tokens 500 llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down, nobody took the least notice of him. I stopped at last, at a little distance, as if I had been in doubt, and after looking on a few minutes, walked straight up to him. He slowly raised his eyes, and fixed them upon me for a moment, without speaking, and then resumed his place and posture as before. I stood looking at him for a while, feeling very much pain at heart, and then said to him, “What are you doing there?” Something like a smile passed over his face, as he said slowly, “I am waiting for someone; but it has been three quarters of an hour now, and he has not come.” “What is it you are waiting for?” said I. Still he made no immediate reply, but again put his face down upon his hands, and did not'
[INFO] Loading model from disk: 5.628 s
Press enter to start generation
------
take his eyes from the ground. “What is it you are waiting for?” said I. “I am not accustomed to be thus questioned,” said he. “You look like a reasonable man—tell me, then, what are you waiting for?” “You would not understand,” he replied; “and how could you help me, if I were to tell you?” “I should not only understand, but would do all that I could,” said I. He did not reply, but still went on looking at the ground, and took hold of his bundle with a nervous trembling. I waited some time, and then resumed. “It is of no use to say you would not understand, if I were to tell you,” said he. “I have not told you why I am waiting for him,” said I. “And I am sure I should not understand,” replied he. “I will tell you then,” said I, “and, perhaps, you would not be surprised.” “No matter,” said he, “I shall be surprised anyhow; so tell me why you are waiting for him.” “He is my friend,” said I. “Yes,” said he, with a slight smile, “I know.” “He has been kind to me,” said I, “and I am waiting for him. I want to see him, and could have waited as I am now, for a much longer time.” “He will not soon come,” said he. “Unless he sees you here, he will not know of your having waited, and he will be very unlikely to come.” “No matter,” said I, “I shall wait for him.” “This is a strange thing,” said he, still with the same amused smile. “How did you know,” said I, “that he was coming? How should you be waiting?” “That is my secret,” said he. “And you expect him?” “Yes,” said I. “Are you disappointed then, if he does not come?” “No,” said I, “it is his secret, not mine.” “If he comes,” said he, “do you mean to go straight away?” “Yes,” said I, “I cannot be happy if I do not go straight away after him.” “Did you know this place before?” asked he. “Yes,” said I. “Is there any shop to buy food here?” “
------
[INFO] Prompt processing: 0.633 s
[INFO] Full generation: 21.475 s
Scripts
-------
.. admonition:: Download the code
The full example code is available in `mlx-examples`_.
.. _mlx-examples: https://github.com/ml-explore/mlx-examples/tree/main/llms/llama
.. [1] Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B. and Liu, Y., 2021.
Roformer: Enhanced transformer with rotary position embedding. arXiv
preprint arXiv:2104.09864.
.. [2] Zhang, B. and Sennrich, R., 2019. Root mean square layer normalization.
Advances in Neural Information Processing Systems, 32.
.. [3] Shazeer, N., 2020. Glu variants improve transformer. arXiv preprint
arXiv:2002.05202.

View File

@@ -0,0 +1,134 @@
.. _mlp:
Multi-Layer Perceptron
----------------------
In this example we'll learn to use ``mlx.nn`` by implementing a simple
multi-layer perceptron to classify MNIST.
As a first step import the MLX packages we need:
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
The model is defined as the ``MLP`` class, which inherits from
:class:`mlx.nn.Module`. We follow the standard idiom to make a new module:
1. Define an ``__init__`` where the parameters and/or submodules are set up. See
   the :ref:`Module class docs<module_class>` for more information on how
   :class:`mlx.nn.Module` registers parameters.
2. Define a ``__call__`` where the computation is implemented.
.. code-block:: python
class MLP(nn.Module):
    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x):
        for l in self.layers[:-1]:
            x = mx.maximum(l(x), 0.0)
        return self.layers[-1](x)
We define the loss function which takes the mean of the per-example cross
entropy loss. The ``mlx.nn.losses`` sub-package has implementations of some
commonly used loss functions.
.. code-block:: python
def loss_fn(model, X, y):
    return mx.mean(nn.losses.cross_entropy(model(X), y))
We also need a function to compute the accuracy of the model on the validation
set:
.. code-block:: python
def eval_fn(model, X, y):
    return mx.mean(mx.argmax(model(X), axis=1) == y)
Next, set up the problem parameters and load the data. To load the data, you
need our `mnist data loader
<https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py>`_, which
we will import as ``mnist``.
.. code-block:: python
num_layers = 2
hidden_dim = 32
num_classes = 10
batch_size = 256
num_epochs = 10
learning_rate = 1e-1
# Load the data
import mnist
train_images, train_labels, test_images, test_labels = map(
    mx.array, mnist.mnist()
)
Since we're using SGD, we need an iterator which shuffles and constructs
minibatches of examples in the training set:
.. code-block:: python
def batch_iterate(batch_size, X, y):
    perm = mx.array(np.random.permutation(y.size))
    for s in range(0, y.size, batch_size):
        ids = perm[s : s + batch_size]
        yield X[ids], y[ids]
Finally, we put it all together by instantiating the model, the
:class:`mlx.optimizers.SGD` optimizer, and running the training loop:
.. code-block:: python
# Load the model
model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
mx.eval(model.parameters())

# Get a function which gives the loss and gradient of the
# loss with respect to the model's trainable parameters
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

# Instantiate the optimizer
optimizer = optim.SGD(learning_rate=learning_rate)

for e in range(num_epochs):
    for X, y in batch_iterate(batch_size, train_images, train_labels):
        loss, grads = loss_and_grad_fn(model, X, y)

        # Update the optimizer state and model parameters
        # in a single call
        optimizer.update(model, grads)

        # Force a graph evaluation
        mx.eval(model.parameters(), optimizer.state)

    accuracy = eval_fn(model, test_images, test_labels)
    print(f"Epoch {e}: Test accuracy {accuracy.item():.3f}")
.. note::
The :func:`mlx.nn.value_and_grad` function is a convenience function to get
the gradient of a loss with respect to the trainable parameters of a model.
This should not be confused with :func:`mlx.core.value_and_grad`.
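A minimal sketch of the distinction follows; the tiny linear model and loss
below are made up purely for illustration. :func:`mlx.core.value_and_grad`
differentiates with respect to array arguments, while
:func:`mlx.nn.value_and_grad` returns gradients as a tree matching the model's
trainable parameters.
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn

# mlx.core.value_and_grad: gradient with respect to an array argument
def f(x):
    return (x ** 2).sum()

value, dfdx = mx.value_and_grad(f)(mx.array([1.0, 2.0, 3.0]))

# mlx.nn.value_and_grad: gradient with respect to a model's trainable
# parameters, returned as a tree with the same structure as model.parameters()
tiny = nn.Linear(3, 1)

def tiny_loss(model, x):
    return model(x).sum()

loss, grads = nn.value_and_grad(tiny, tiny_loss)(tiny, mx.array([[1.0, 2.0, 3.0]]))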
The model should train to a decent accuracy (about 95%) after just a few passes
over the training set. The `full example <https://github.com/ml-explore/mlx-examples/tree/main/mnist>`_
is available in the MLX GitHub repo.

93
docs/build/html/_sources/index.rst vendored Normal file
View File

@@ -0,0 +1,93 @@
MLX
===
MLX is a NumPy-like array framework designed for efficient and flexible machine
learning on Apple silicon, brought to you by Apple machine learning research.
The Python API closely follows NumPy with a few exceptions. MLX also has a
fully featured C++ API which closely follows the Python API.
The main differences between MLX and NumPy are:
- **Composable function transformations**: MLX has composable function
transformations for automatic differentiation, automatic vectorization,
and computation graph optimization.
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
materialized when needed.
- **Multi-device**: Operations can run on any of the supported devices (CPU,
GPU, ...)
The design of MLX is inspired by frameworks like `PyTorch
<https://pytorch.org/>`_, `Jax <https://github.com/google/jax>`_, and
`ArrayFire <https://arrayfire.org/>`_. A notable difference between these
frameworks and MLX is the *unified memory model*. Arrays in MLX live in shared
memory. Operations on MLX arrays can be performed on any of the supported
device types without performing data copies. Currently supported device types
are the CPU and GPU.
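As a brief, minimal sketch of these two points (lazy computation and the
unified memory model), the following runs the same operation on the CPU and the
GPU without copying the inputs, and forces evaluation explicitly:
.. code-block:: python
import mlx.core as mx

a = mx.random.normal((256,))
b = mx.random.normal((256,))

# The same arrays can be used on either device; no copies are made
c_cpu = mx.add(a, b, stream=mx.cpu)
c_gpu = mx.add(a, b, stream=mx.gpu)

# Computations are lazy: nothing has run yet. Force evaluation explicitly
mx.eval(c_cpu, c_gpu)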
.. toctree::
:caption: Install
:maxdepth: 1
install
.. toctree::
:caption: Usage
:maxdepth: 1
usage/quick_start
usage/lazy_evaluation
usage/unified_memory
usage/indexing
usage/saving_and_loading
usage/function_transforms
usage/compile
usage/numpy
usage/distributed
usage/using_streams
usage/export
.. toctree::
:caption: Examples
:maxdepth: 1
examples/linear_regression
examples/mlp
examples/llama-inference
.. toctree::
:caption: Python API Reference
:maxdepth: 1
python/array
python/data_types
python/devices_and_streams
python/export
python/ops
python/random
python/transforms
python/fast
python/fft
python/linalg
python/metal
python/cuda
python/memory_management
python/nn
python/optimizers
python/distributed
python/tree_utils
.. toctree::
:caption: C++ API Reference
:maxdepth: 1
cpp/ops
.. toctree::
:caption: Further Reading
:maxdepth: 1
dev/extensions
dev/metal_debugger
dev/custom_metal_kernels
dev/mlx_in_cpp

345
docs/build/html/_sources/install.rst vendored Normal file
View File

@@ -0,0 +1,345 @@
.. _build_and_install:
Build and Install
=================
Python Installation
-------------------
MLX is available on PyPI. All you have to do to use MLX with your own Apple
silicon computer is
.. code-block:: shell
pip install mlx
To install from PyPI your system must meet the following requirements:
- Using an M series chip (Apple silicon)
- Using a native Python >= 3.9
- macOS >= 13.5
.. note::
MLX is only available on devices running macOS >= 13.5
It is highly recommended to use macOS 14 (Sonoma)
CUDA
^^^^
MLX has a CUDA backend which you can install with:
.. code-block:: shell
pip install mlx[cuda]
To install the CUDA package from PyPI your system must meet the following
requirements:
- Nvidia architecture >= SM 7.0 (Volta)
- Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35
- Python >= 3.9
CPU-only (Linux)
^^^^^^^^^^^^^^^^
For a CPU-only version of MLX that runs on Linux use:
.. code-block:: shell
pip install mlx[cpu]
To install the CPU-only package from PyPI your system must meet the following
requirements:
- Linux distribution with glibc >= 2.35
- Python >= 3.9
Troubleshooting
^^^^^^^^^^^^^^^
*My OS and Python versions are in the required range but pip still does not find
a matching distribution.*
Probably you are using a non-native Python. The output of
.. code-block:: shell
python -c "import platform; print(platform.processor())"
should be ``arm``. If it is ``i386`` (and you have an M series machine) then
you are using a non-native Python. Switch your Python to a native Python. A
good way to do this is with `Conda <https://stackoverflow.com/q/65415996>`_.
Build from source
-----------------
Build Requirements
^^^^^^^^^^^^^^^^^^
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
- `cmake <https://cmake.org/>`_ -- version 3.25 or later, and ``make``
- Xcode >= 15.0 and macOS SDK >= 14.0
.. note::
Ensure your shell environment is native ``arm``, not ``x86`` via Rosetta. If
the output of ``uname -p`` is ``x86``, see the :ref:`troubleshooting section <build shell>` below.
Python API
^^^^^^^^^^
.. _python install:
To build and install the MLX Python library from source, first clone MLX from
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
.. code-block:: shell
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
Then simply build and install MLX using pip:
.. code-block:: shell
pip install .
For developing, install the package with development dependencies, and use an
editable install:
.. code-block:: shell
pip install -e ".[dev]"
Once the development dependencies are installed, you can build faster with:
.. code-block:: shell
python setup.py build_ext --inplace
Run the tests with:
.. code-block:: shell
python -m unittest discover python/tests
Optional: Install stubs to enable auto completions and type checking from your
IDE:
.. code-block:: shell
python setup.py generate_stubs
C++ API
^^^^^^^
.. _cpp install:
Currently, MLX must be built and installed from source.
As with the Python library, to build and install the MLX C++ library, start
by cloning MLX from `its GitHub repo
<https://github.com/ml-explore/mlx>`_:
.. code-block:: shell
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
Create a build directory and run CMake and make:
.. code-block:: shell
mkdir -p build && cd build
cmake .. && make -j
Run tests with:
.. code-block:: shell
make test
Install with:
.. code-block:: shell
make install
Note that the built ``mlx.metallib`` file should either be in the same
directory as the executable that is statically linked to ``libmlx.a``, or the
preprocessor constant ``METAL_PATH`` should be defined at build time to point
to the built Metal library.
.. list-table:: Build Options
:widths: 25 8
:header-rows: 1
* - Option
- Default
* - MLX_BUILD_TESTS
- ON
* - MLX_BUILD_EXAMPLES
- OFF
* - MLX_BUILD_BENCHMARKS
- OFF
* - MLX_BUILD_METAL
- ON
* - MLX_BUILD_CPU
- ON
* - MLX_BUILD_PYTHON_BINDINGS
- OFF
* - MLX_METAL_DEBUG
- OFF
* - MLX_BUILD_SAFETENSORS
- ON
* - MLX_BUILD_GGUF
- ON
* - MLX_METAL_JIT
- OFF
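For example, a sketch of toggling a couple of these options when configuring a
from-source build; the particular choices here are only illustrative:
.. code-block:: shell
mkdir -p build && cd build
cmake .. -DMLX_BUILD_TESTS=OFF -DMLX_BUILD_EXAMPLES=ON
make -j
When installing the Python package with pip, the same options can be forwarded
through the ``CMAKE_ARGS`` environment variable, as shown in the CUDA section
below.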
.. note::
If you have multiple Xcode installations and wish to use
a specific one while building, you can do so by adding the
following environment variable before building
.. code-block:: shell
export DEVELOPER_DIR="/path/to/Xcode.app/Contents/Developer/"
Further, you can use the following command to find out which
macOS SDK will be used
.. code-block:: shell
xcrun -sdk macosx --show-sdk-version
Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~
To produce a smaller binary use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``
and ``BUILD_SHARED_LIBS=ON``.
The MLX CMake build has several additional options to make smaller binaries.
For example, if you don't need the CPU backend or support for safetensors and
GGUF, you can do:
.. code-block:: shell
cmake .. \
-DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
The ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library, which
contains pre-built GPU kernels. It substantially reduces the size of the Metal
library by compiling kernels at run time the first time they are used in MLX
on a given machine. Note that run-time compilation incurs a cold-start cost
which can be anywhere from a few hundred milliseconds to a few seconds
depending on the application. Once a kernel is compiled, it will be cached by
the system. The Metal kernel cache persists across reboots.
Linux
^^^^^
To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
For example on Ubuntu, run the following:
.. code-block:: shell
apt-get update -y
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
From here follow the instructions to install either the :ref:`Python <python
install>` or :ref:`C++ <cpp install>` APIs.
CUDA
^^^^
To build from source on Linux with CUDA, install the BLAS and LAPACK headers
and the CUDA toolkit. For example on Ubuntu, run the following:
.. code-block:: shell
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
dpkg -i cuda-keyring_1.1-1_all.deb
apt-get update -y
apt-get -y install cuda-toolkit-12-9
apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y
When building either the Python or C++ APIs, make sure to pass the CMake flag
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
.. code-block:: shell
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
To build the C++ package run:
.. code-block:: shell
mkdir -p build && cd build
cmake .. -DMLX_BUILD_CUDA=ON && make -j
Troubleshooting
^^^^^^^^^^^^^^^
Metal not found
~~~~~~~~~~~~~~~
You see the following error when you try to build:
.. code-block:: shell
error: unable to find utility "metal", not a developer tool or in PATH
To fix this, first make sure you have Xcode installed:
.. code-block:: shell
xcode-select --install
Then set the active developer directory:
.. code-block:: shell
sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer
x86 Shell
~~~~~~~~~
.. _build shell:
If the output of ``uname -p`` is ``x86`` then your shell is running as x86 via
Rosetta instead of natively.
To fix this, find the application in Finder (``/Applications`` for iTerm,
``/Applications/Utilities`` for Terminal), right-click, and click “Get Info”.
Uncheck “Open using Rosetta”, close the “Get Info” window, and restart your
terminal.
Verify the terminal is now running natively with the following command:
.. code-block:: shell
$ uname -p
arm
Also check that cmake is using the correct architecture:
.. code-block:: shell
$ cmake --system-information | grep CMAKE_HOST_SYSTEM_PROCESSOR
CMAKE_HOST_SYSTEM_PROCESSOR "arm64"
If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"``
but the build errors out with "Building for x86_64 on macOS is not supported.",
wipe your build cache with ``rm -rf build/`` and try again.

View File

@@ -0,0 +1,28 @@
mlx.core.Device
===============
.. currentmodule:: mlx.core
.. autoclass:: Device
.. automethod:: __init__
.. rubric:: Methods
.. autosummary::
~Device.__init__
.. rubric:: Attributes
.. autosummary::
~Device.type

View File

@@ -0,0 +1,28 @@
mlx.core.Dtype
==============
.. currentmodule:: mlx.core
.. autoclass:: Dtype
.. automethod:: __init__
.. rubric:: Methods
.. autosummary::
~Dtype.__init__
.. rubric:: Attributes
.. autosummary::
~Dtype.size

View File

@@ -0,0 +1,29 @@
mlx.core.DtypeCategory
======================
.. currentmodule:: mlx.core
.. autoclass:: DtypeCategory
.. automethod:: __init__
.. rubric:: Attributes
.. autosummary::
~DtypeCategory.complexfloating
~DtypeCategory.floating
~DtypeCategory.inexact
~DtypeCategory.signedinteger
~DtypeCategory.unsignedinteger
~DtypeCategory.integer
~DtypeCategory.number
~DtypeCategory.generic

View File

@@ -0,0 +1,6 @@
mlx.core.abs
============
.. currentmodule:: mlx.core
.. autofunction:: abs

View File

@@ -0,0 +1,6 @@
mlx.core.add
============
.. currentmodule:: mlx.core
.. autofunction:: add

View File

@@ -0,0 +1,6 @@
mlx.core.addmm
==============
.. currentmodule:: mlx.core
.. autofunction:: addmm

View File

@@ -0,0 +1,6 @@
mlx.core.all
============
.. currentmodule:: mlx.core
.. autofunction:: all

View File

@@ -0,0 +1,6 @@
mlx.core.allclose
=================
.. currentmodule:: mlx.core
.. autofunction:: allclose

View File

@@ -0,0 +1,6 @@
mlx.core.any
============
.. currentmodule:: mlx.core
.. autofunction:: any

View File

@@ -0,0 +1,6 @@
mlx.core.arange
===============
.. currentmodule:: mlx.core
.. autofunction:: arange

View File

@@ -0,0 +1,6 @@
mlx.core.arccos
===============
.. currentmodule:: mlx.core
.. autofunction:: arccos

View File

@@ -0,0 +1,6 @@
mlx.core.arccosh
================
.. currentmodule:: mlx.core
.. autofunction:: arccosh

View File

@@ -0,0 +1,6 @@
mlx.core.arcsin
===============
.. currentmodule:: mlx.core
.. autofunction:: arcsin

View File

@@ -0,0 +1,6 @@
mlx.core.arcsinh
================
.. currentmodule:: mlx.core
.. autofunction:: arcsinh

View File

@@ -0,0 +1,6 @@
mlx.core.arctan
===============
.. currentmodule:: mlx.core
.. autofunction:: arctan

View File

@@ -0,0 +1,6 @@
mlx.core.arctan2
================
.. currentmodule:: mlx.core
.. autofunction:: arctan2

View File

@@ -0,0 +1,6 @@
mlx.core.arctanh
================
.. currentmodule:: mlx.core
.. autofunction:: arctanh

View File

@@ -0,0 +1,6 @@
mlx.core.argmax
===============
.. currentmodule:: mlx.core
.. autofunction:: argmax

View File

@@ -0,0 +1,6 @@
mlx.core.argmin
===============
.. currentmodule:: mlx.core
.. autofunction:: argmin

View File

@@ -0,0 +1,6 @@
mlx.core.argpartition
=====================
.. currentmodule:: mlx.core
.. autofunction:: argpartition

View File

@@ -0,0 +1,6 @@
mlx.core.argsort
================
.. currentmodule:: mlx.core
.. autofunction:: argsort

View File

@@ -0,0 +1,6 @@
mlx.core.array.T
================
.. currentmodule:: mlx.core
.. autoproperty:: array.T

View File

@@ -0,0 +1,6 @@
mlx.core.array.abs
==================
.. currentmodule:: mlx.core
.. automethod:: array.abs

View File

@@ -0,0 +1,6 @@
mlx.core.array.all
==================
.. currentmodule:: mlx.core
.. automethod:: array.all

View File

@@ -0,0 +1,6 @@
mlx.core.array.any
==================
.. currentmodule:: mlx.core
.. automethod:: array.any

View File

@@ -0,0 +1,6 @@
mlx.core.array.argmax
=====================
.. currentmodule:: mlx.core
.. automethod:: array.argmax

View File

@@ -0,0 +1,6 @@
mlx.core.array.argmin
=====================
.. currentmodule:: mlx.core
.. automethod:: array.argmin

View File

@@ -0,0 +1,6 @@
mlx.core.array.astype
=====================
.. currentmodule:: mlx.core
.. automethod:: array.astype

View File

@@ -0,0 +1,6 @@
mlx.core.array.at
=================
.. currentmodule:: mlx.core
.. autoproperty:: array.at

View File

@@ -0,0 +1,6 @@
mlx.core.array.conj
===================
.. currentmodule:: mlx.core
.. automethod:: array.conj

View File

@@ -0,0 +1,6 @@
mlx.core.array.cos
==================
.. currentmodule:: mlx.core
.. automethod:: array.cos

View File

@@ -0,0 +1,6 @@
mlx.core.array.cummax
=====================
.. currentmodule:: mlx.core
.. automethod:: array.cummax

View File

@@ -0,0 +1,6 @@
mlx.core.array.cummin
=====================
.. currentmodule:: mlx.core
.. automethod:: array.cummin

View File

@@ -0,0 +1,6 @@
mlx.core.array.cumprod
======================
.. currentmodule:: mlx.core
.. automethod:: array.cumprod

View File

@@ -0,0 +1,6 @@
mlx.core.array.cumsum
=====================
.. currentmodule:: mlx.core
.. automethod:: array.cumsum

View File

@@ -0,0 +1,6 @@
mlx.core.array.diag
===================
.. currentmodule:: mlx.core
.. automethod:: array.diag

View File

@@ -0,0 +1,6 @@
mlx.core.array.diagonal
=======================
.. currentmodule:: mlx.core
.. automethod:: array.diagonal

View File

@@ -0,0 +1,6 @@
mlx.core.array.dtype
====================
.. currentmodule:: mlx.core
.. autoproperty:: array.dtype

View File

@@ -0,0 +1,6 @@
mlx.core.array.exp
==================
.. currentmodule:: mlx.core
.. automethod:: array.exp

View File

@@ -0,0 +1,6 @@
mlx.core.array.flatten
======================
.. currentmodule:: mlx.core
.. automethod:: array.flatten

View File

@@ -0,0 +1,6 @@
mlx.core.array.imag
===================
.. currentmodule:: mlx.core
.. autoproperty:: array.imag

View File

@@ -0,0 +1,6 @@
mlx.core.array.item
===================
.. currentmodule:: mlx.core
.. automethod:: array.item

View File

@@ -0,0 +1,6 @@
mlx.core.array.itemsize
=======================
.. currentmodule:: mlx.core
.. autoproperty:: array.itemsize

View File

@@ -0,0 +1,6 @@
mlx.core.array.log
==================
.. currentmodule:: mlx.core
.. automethod:: array.log

View File

@@ -0,0 +1,6 @@
mlx.core.array.log10
====================
.. currentmodule:: mlx.core
.. automethod:: array.log10

View File

@@ -0,0 +1,6 @@
mlx.core.array.log1p
====================
.. currentmodule:: mlx.core
.. automethod:: array.log1p

View File

@@ -0,0 +1,6 @@
mlx.core.array.log2
===================
.. currentmodule:: mlx.core
.. automethod:: array.log2

View File

@@ -0,0 +1,6 @@
mlx.core.array.logcumsumexp
===========================
.. currentmodule:: mlx.core
.. automethod:: array.logcumsumexp

View File

@@ -0,0 +1,6 @@
mlx.core.array.logsumexp
========================
.. currentmodule:: mlx.core
.. automethod:: array.logsumexp

View File

@@ -0,0 +1,6 @@
mlx.core.array.max
==================
.. currentmodule:: mlx.core
.. automethod:: array.max

View File

@@ -0,0 +1,6 @@
mlx.core.array.mean
===================
.. currentmodule:: mlx.core
.. automethod:: array.mean

View File

@@ -0,0 +1,6 @@
mlx.core.array.min
==================
.. currentmodule:: mlx.core
.. automethod:: array.min

View File

@@ -0,0 +1,6 @@
mlx.core.array.moveaxis
=======================
.. currentmodule:: mlx.core
.. automethod:: array.moveaxis

View File

@@ -0,0 +1,6 @@
mlx.core.array.nbytes
=====================
.. currentmodule:: mlx.core
.. autoproperty:: array.nbytes

View File

@@ -0,0 +1,6 @@
mlx.core.array.ndim
===================
.. currentmodule:: mlx.core
.. autoproperty:: array.ndim

View File

@@ -0,0 +1,6 @@
mlx.core.array.prod
===================
.. currentmodule:: mlx.core
.. automethod:: array.prod

View File

@@ -0,0 +1,6 @@
mlx.core.array.real
===================
.. currentmodule:: mlx.core
.. autoproperty:: array.real

View File

@@ -0,0 +1,6 @@
mlx.core.array.reciprocal
=========================
.. currentmodule:: mlx.core
.. automethod:: array.reciprocal

View File

@@ -0,0 +1,6 @@
mlx.core.array.reshape
======================
.. currentmodule:: mlx.core
.. automethod:: array.reshape

View File

@@ -0,0 +1,6 @@
mlx.core.array.round
====================
.. currentmodule:: mlx.core
.. automethod:: array.round

View File

@@ -0,0 +1,6 @@
mlx.core.array.rsqrt
====================
.. currentmodule:: mlx.core
.. automethod:: array.rsqrt

View File

@@ -0,0 +1,81 @@
mlx.core.array
==============
.. currentmodule:: mlx.core
.. autoclass:: array
.. automethod:: __init__
.. rubric:: Methods
.. autosummary::
~array.__init__
~array.abs
~array.all
~array.any
~array.argmax
~array.argmin
~array.astype
~array.conj
~array.cos
~array.cummax
~array.cummin
~array.cumprod
~array.cumsum
~array.diag
~array.diagonal
~array.exp
~array.flatten
~array.item
~array.log
~array.log10
~array.log1p
~array.log2
~array.logcumsumexp
~array.logsumexp
~array.max
~array.mean
~array.min
~array.moveaxis
~array.prod
~array.reciprocal
~array.reshape
~array.round
~array.rsqrt
~array.sin
~array.split
~array.sqrt
~array.square
~array.squeeze
~array.std
~array.sum
~array.swapaxes
~array.tolist
~array.transpose
~array.var
~array.view
.. rubric:: Attributes
.. autosummary::
~array.T
~array.at
~array.dtype
~array.imag
~array.itemsize
~array.nbytes
~array.ndim
~array.real
~array.shape
~array.size

View File

@@ -0,0 +1,6 @@
mlx.core.array.shape
====================
.. currentmodule:: mlx.core
.. autoproperty:: array.shape

View File

@@ -0,0 +1,6 @@
mlx.core.array.sin
==================
.. currentmodule:: mlx.core
.. automethod:: array.sin

View File

@@ -0,0 +1,6 @@
mlx.core.array.size
===================
.. currentmodule:: mlx.core
.. autoproperty:: array.size

View File

@@ -0,0 +1,6 @@
mlx.core.array.split
====================
.. currentmodule:: mlx.core
.. automethod:: array.split

View File

@@ -0,0 +1,6 @@
mlx.core.array.sqrt
===================
.. currentmodule:: mlx.core
.. automethod:: array.sqrt

View File

@@ -0,0 +1,6 @@
mlx.core.array.square
=====================
.. currentmodule:: mlx.core
.. automethod:: array.square

View File

@@ -0,0 +1,6 @@
mlx.core.array.squeeze
======================
.. currentmodule:: mlx.core
.. automethod:: array.squeeze

View File

@@ -0,0 +1,6 @@
mlx.core.array.std
==================
.. currentmodule:: mlx.core
.. automethod:: array.std

View File

@@ -0,0 +1,6 @@
mlx.core.array.sum
==================
.. currentmodule:: mlx.core
.. automethod:: array.sum

View File

@@ -0,0 +1,6 @@
mlx.core.array.swapaxes
=======================
.. currentmodule:: mlx.core
.. automethod:: array.swapaxes

View File

@@ -0,0 +1,6 @@
mlx.core.array.tolist
=====================
.. currentmodule:: mlx.core
.. automethod:: array.tolist

View File

@@ -0,0 +1,6 @@
mlx.core.array.transpose
========================
.. currentmodule:: mlx.core
.. automethod:: array.transpose

View File

@@ -0,0 +1,6 @@
mlx.core.array.var
==================
.. currentmodule:: mlx.core
.. automethod:: array.var

View File

@@ -0,0 +1,6 @@
mlx.core.array.view
===================
.. currentmodule:: mlx.core
.. automethod:: array.view

View File

@@ -0,0 +1,6 @@
mlx.core.array\_equal
=====================
.. currentmodule:: mlx.core
.. autofunction:: array_equal

View File

@@ -0,0 +1,6 @@
mlx.core.as\_strided
====================
.. currentmodule:: mlx.core
.. autofunction:: as_strided

View File

@@ -0,0 +1,6 @@
mlx.core.async\_eval
====================
.. currentmodule:: mlx.core
.. autofunction:: async_eval

Some files were not shown because too many files have changed in this diff.