Compare commits

...

405 Commits

Author SHA1 Message Date
Awni Hannun
5d74921cf2 add sdpa with sinks 2025-08-31 10:59:50 -07:00
Awni Hannun
8ce49cd39e fix quantized vjp for mxfp4 (#2555) 2025-08-29 10:06:15 -07:00
Awni Hannun
9c68b50853 version bump (#2554) 2025-08-29 06:54:17 -07:00
Awni Hannun
111f1e71af Faster contiguous gather for indices in the first axis (#2552)
* faster contiguous gather for indices in the first axis

* work per thread > 1

* angelos suggestion for scales / biases
2025-08-28 21:26:30 -07:00
Awni Hannun
827003d568 fix METAL quantization in JIT (#2553) 2025-08-28 18:26:25 -07:00
Awni Hannun
d363a76aa4 Bump xcode in circle (#2551)
* bump xcode in circle

* bump xcode in circle

* bump xcode in circle
2025-08-28 13:13:34 -07:00
Awni Hannun
70560b6bd5 Add mode parameter for quantization (#2499)
* add mode parameter for quantization

* mxfp4 quantize/dequantize + start of optional biases

* mxfp4 works

* speedup

* cpu mxfp4

* fix

* fix test tol

* fix

* refactor

* add quant mode enum
2025-08-28 06:45:26 -07:00
Awni Hannun
7ef8a6f2d5 [CUDA] fix sort (#2550)
* [CUDA] fix sort

* fix test
2025-08-27 19:48:43 -07:00
Cheng
31c6f6e33f [CUDA] Use ConcurrentContext in concatenate_gpu (#2549) 2025-08-28 09:30:08 +09:00
Awni Hannun
584d48458e link with nccl (#2546) 2025-08-27 10:01:07 -07:00
Cheng
5cf984ca87 Separate cpu compilation cache by versions (#2548) 2025-08-27 11:25:15 +09:00
Cheng
a9bac3d9e5 Run CPP tests for CUDA build in CI (#2544) 2025-08-27 08:06:46 +09:00
Awni Hannun
5458d43247 add load with path tests (#2543) 2025-08-26 14:24:47 -07:00
Awni Hannun
a4dba65220 Enable cuda graph toggle (#2545)
* enable cuda graph toggle

* increase cache size
2025-08-26 12:50:38 -07:00
Awni Hannun
3dcb286baf Remove stream from average grads so it uses default (#2532)
* Remove stream from average grads so it uses default

* comment
2025-08-25 15:56:29 -07:00
Cheng
4822c3dbe9 [CUDA] Implement DynamicSlice/DynamicSliceUpdate (#2533)
* Move DynamicSlice to gpu/primitives

* Implement compute_dynamic_offset in CUDA
2025-08-26 07:31:39 +09:00
Awni Hannun
2ca75bb529 Remove nccl install in release (#2542) 2025-08-25 15:20:18 -07:00
Awni Hannun
db14e29a0b allow pathlib.Path to save/load functions (#2541) 2025-08-25 14:58:49 -07:00
Awni Hannun
d2f540f4e0 Use nccl header only when nccl is not present (#2539)
* use nccl header only when nccl is not present

* larger machine for cuda build
2025-08-25 14:17:25 -07:00
Cheng
333ffea273 [CUDA] Remove thrust in arange (#2535) 2025-08-24 16:22:36 +09:00
Cheng
f55b6f1f2f Enable COMPILE_WARNING_AS_ERROR for linux builds in CI (#2534) 2025-08-24 15:33:08 +09:00
Awni Hannun
30561229c7 Fix allocation bug in NCCL (#2530) 2025-08-22 14:39:43 -07:00
Awni Hannun
068a4612e9 nccl default for backend=any (#2528)
* nccl default for backend=any

* check num gpus + ensure row contiguous for all reduce

* comment
2025-08-22 12:24:27 -07:00
Andrey Portnoy
5722c147de [CUDA] Update calls to cudaMemAdvise and cudaGraphAddDependencies for CUDA 13 (#2525)
* [CUDA] Update cudaMemAdvise and cudaGraphAddDependencies for CUDA 13

These functions' signatures changed in CUDA 13, so we differentiate
between CUDA 13 and preceding releases at compile time.

* Mention NVIDIA in ACKNOWLEDGMENTS.md
2025-08-21 19:57:20 -07:00
Cheng
f6819a1f26 Fix warning 186-D from nvcc (#2527) 2025-08-22 10:29:55 +09:00
Awni Hannun
f93f87c802 nccl dep + default for cuda (#2526) 2025-08-21 17:57:49 -07:00
Anastasiia Filippova
9392fc3f88 NCCL backend (#2476) 2025-08-21 11:56:15 -07:00
Awni Hannun
e843c4d8d5 fix power (#2523) 2025-08-21 06:46:01 -07:00
Angelos Katharopoulos
0c5fc63a36 Fix docs omission (#2524) 2025-08-20 17:56:06 -07:00
Angelos Katharopoulos
e397177f6e Custom cuda kernel (#2517) 2025-08-20 17:20:22 -07:00
Cheng
f4c8888cbe [CUDA] Fix stride of singleton dims before passing to cuDNN (#2521) 2025-08-21 08:55:26 +09:00
Angelos Katharopoulos
25c1e03205 Fix overflow in large filter small channels (#2520) 2025-08-20 08:03:29 -07:00
russellizadi
512281781c Remove state return from function example in compile documentation (#2518) 2025-08-20 00:45:05 -07:00
Cheng
ac85ddfdb7 [CUDA] Add GEMM-based fallback convolution kernels (#2511)
* Add gemm_conv

* Add gemm_grouped_conv
2025-08-20 10:06:22 +09:00
Cheng
65d0d40232 Split cuDNN helpers into a separate header (#2491)
* Add RAII managed CudaGraph class

* Implement forward rms_norm with cuDNN

* Revert back to old rms norm kernel
2025-08-20 09:29:28 +09:00
Awni Hannun
cea9369610 fix lapack svd (#2515) 2025-08-18 15:07:59 -07:00
Awni Hannun
e7c6e1db82 no segfault with uninitialized array.at (#2514) 2025-08-18 08:33:38 -07:00
Awni Hannun
c5fcd5b61b fix custom kernel test (#2510) 2025-08-18 06:45:59 -07:00
Angelos Katharopoulos
1df9887998 Ensure no oob read in gemv_masked (#2508) 2025-08-17 08:42:33 -07:00
Angelos Katharopoulos
73f22d6226 Ensure small sort doesn't use indices if not argsort (#2506) 2025-08-17 08:42:20 -07:00
Cheng
c422050ca7 Update cuDNN Frontend to v1.14 (#2505) 2025-08-17 19:13:01 +09:00
Cheng
1ba18ff7d9 [CUDA] Fix conv grads with groups (#2495)
* Put reshape utils in one file

* [CUDA] Fix conv grads with groups

* Put the reshape utils in gpu/copy.h
2025-08-16 10:09:18 +09:00
Cheng
37b440faa8 Clean up code handling both std::vector and SmallVector (#2493) 2025-08-16 09:01:10 +09:00
Cheng
888b13ed63 Remove the hack around SmallVector in cpu compile (#2494) 2025-08-16 08:17:24 +09:00
Cheng
4abb218d21 The naive_conv_2d is no longer used (#2496) 2025-08-16 07:57:30 +09:00
Awni Hannun
6441c21a94 Faster general unary op (#2472)
* faster general unary op

* faster general ops + reorg

* fix + comment

* binary two

* copy general
2025-08-15 15:04:12 -07:00
Cheng
dfb5022eab Rename cu::Matmul to CublasGemm (#2488) 2025-08-13 09:37:40 +09:00
Daniel Yeh
ac207ce7aa make code blocks copyable (#2480)
Co-authored-by: Chen-Chen Yeh <ge96noj@mytum.de>
2025-08-12 12:29:02 -07:00
Abe Leininger
fce53b61d6 Fix reduce sum/prod overflow (#2477) 2025-08-12 00:05:33 -07:00
Angelos Katharopoulos
8ae4a76308 Use CMake <4.1 to avoid the nvpl error (#2489) 2025-08-12 00:03:42 -07:00
Cheng
7fde1b6a1e Fix logsumexp/softmax not fused for some cases (#2474) 2025-08-08 14:07:17 -07:00
Cheng
aa7b47481a [CUDA] Optimize set_mm_device_pointers for small ndim (#2473) 2025-08-08 15:23:30 +09:00
Awni Hannun
56be773610 version (#2470) 2025-08-07 00:36:04 -07:00
Jagrit Digani
a9bdd67baa Add CUDA sdpa vector (#2468) 2025-08-06 21:40:26 -07:00
Angelos Katharopoulos
f2adb5638d Fix typo in metal command encoder (#2471) 2025-08-06 16:58:23 -07:00
Luca Vivona
728d4db582 Support destination arg in tree flatten/unflatten (#2450) 2025-08-06 15:34:59 -07:00
Awni Hannun
db5c7efcf6 revert default cuda install (#2465)
* revert default cuda install

* revert default cuda install
2025-08-06 06:19:12 -07:00
Awni Hannun
7bb96e4249 fix cublas on h100 (#2466) 2025-08-06 06:18:58 -07:00
Awni Hannun
fa89f0b150 faster gather qmm sorted test (#2463) 2025-08-05 06:27:40 -07:00
Awni Hannun
ca973d1e83 fix install tags (#2464) 2025-08-04 20:01:23 -07:00
Cheng
828c5f1137 Use SmallVector for shapes and strides (#2454)
* Use SmallVector for shapes and strides

* Convert SmallVector to tuple
2025-08-05 09:41:03 +09:00
Gaétan Lepage
7d86a5c108 Feat: add USE_SYSTEM_FMT CMake option (#2219) 2025-08-04 16:36:11 -07:00
Awni Hannun
0b807893a7 fix wraps compile (#2461) 2025-08-04 16:14:18 -07:00
Awni Hannun
6ad0889c8a default install cuda on linux (#2462) 2025-08-04 15:33:05 -07:00
Zamderax
737dd6d1ac Add missing <algorithm> header to jit_compiler.cpp (#2460)
Fixes compilation error on Linux where std::find_if is used on line 121
but the <algorithm> header was not included. While this might work on
some platforms due to transitive includes, it's not guaranteed by the
C++ standard.

Resolves issue #2459
2025-08-04 14:00:46 -07:00
Cheng
aaf78f4c6b Use LRU cache for cuda graph (#2448)
* Use LRU cache for cuda graph

* Remove unused destructor
2025-08-02 21:28:57 +09:00
Angelos Katharopoulos
8831064493 Fix arctan2 grads (#2453) 2025-08-01 21:06:04 -07:00
Angelos Katharopoulos
be9bc96da4 [CUDA] Matmul utils initial commit (#2441) 2025-08-01 14:22:25 -07:00
Angelos Katharopoulos
86258f292f [CUDA] Vectorize generated kernels (#2444) 2025-07-31 18:18:57 -07:00
Cheng
b26d88591c [CUDA] Save primitive inputs faster (#2449)
* Add more nvtx loggings

* [CUDA] Saving primitive inputs faster

* Remove unneeded check
2025-08-01 10:16:06 +09:00
Cheng
86c6a15571 [CUDA] Backward convolution (#2431) 2025-08-01 09:54:05 +09:00
junpeiz
8b25ce62d5 Add tests for export including control flow models and quantized models (#2430)
* Add tests for export, including control flow export and quantized model export.

* Skip quantization related test for CUDA backend.
2025-07-31 11:06:26 -07:00
Awni Hannun
da5912e4f2 fix custom metal extension (#2446) 2025-07-31 06:25:36 -07:00
Cheng
daafee676f Fix wrong graph key when using concurrent context (#2447) 2025-07-31 06:01:05 -07:00
Awni Hannun
d32519c8ee fix gemv regression (#2445) 2025-07-30 14:23:01 -07:00
Awni Hannun
b405591249 fix circular reference (#2443) 2025-07-30 09:37:44 -07:00
Angelos Katharopoulos
3bf81ed1bd [CUDA] Quantized refactoring (#2442) 2025-07-30 08:27:20 -07:00
Cheng
2204182bba Make CI faster (#2440) 2025-07-30 02:26:36 -07:00
Cheng
3628e5d497 Use load_vector in arg_reduce (#2439) 2025-07-30 17:40:26 +09:00
Cheng
a0ae49d397 Move arange to its own file (#2438) 2025-07-30 13:05:51 +09:00
Cheng
254476718b Remove the kernel arg from get_launch_args (#2437) 2025-07-30 11:43:02 +09:00
Awni Hannun
3adba92ebe Cuda faster softmax (#2435)
* faster softmax and logsumexp

* faster softmax and logsumexp

* format
2025-07-29 17:18:12 -07:00
Awni Hannun
ef631d63af faster rms norm (#2433) 2025-07-29 13:12:00 -07:00
Cheng
970dbe8e25 Use ccache in CI (#2414)
* Detect ccache

* Use ccache in CI

* Separate cache for different images

* Test both 12.2 and 12.9 for PRs
2025-07-29 08:43:22 +09:00
Awni Hannun
641be9463b Add more CUDA architectures for PyPi package (#2427)
* add cuda sm 90

* add more archs
2025-07-28 12:35:15 -07:00
Awni Hannun
ab0e608862 [CUDA] More sizes for gemv (#2429)
* route more to gemv

* route more sizes to custom gemv
2025-07-28 12:35:01 -07:00
Awni Hannun
1588659062 no occupancy query for launch params (#2426) 2025-07-28 09:09:41 -07:00
Awni Hannun
b9e88fb976 [CUDA] Fix segfault on exit (#2424)
* fix cuda segfault on exit

* comment
2025-07-27 08:08:13 -07:00
Awni Hannun
4ad53414dd fix cuda pypi package (#2423)
* fix cuda pypi package

* patch bump
2025-07-25 15:20:29 -07:00
Awni Hannun
d1165b215e version (#2420) 2025-07-25 13:29:28 -07:00
Awni Hannun
dcb8319f3d update install docs and requirements (#2419) 2025-07-25 12:13:19 -07:00
Awni Hannun
5597fa089c Fix qvm splitk (#2415) 2025-07-25 11:50:24 -07:00
Awni Hannun
9acec364c2 [CUDA] Always use batched matmul (#2404)
* cuda batched mm

* addmm as well

* comment
2025-07-24 20:46:02 -07:00
Skonor
7d9d6ef456 docs: fix adam and adamw eps placement (#2416)
Co-authored-by: Mikhail Gorbunov <m_gorbunov@apple.com>
2025-07-24 16:40:45 -07:00
Cheng
6f5874a2f2 [CUDA] Initial implementation of Convolution with cuDNN (#2385)
* Link with cuDNN

* Initial implementation

* Remove backend apis

* Fix recording cudnn conv

* More unused backend apis

* Fix C++ conv tests

* include cudnn as python dep

* Install libcudnn9-dev-cuda-12 in CI

* cudnn only accepts contiguous inputs

* Switch to backend apis

* Plan needs to be kept alive

* Turn off tf32

* Add cache

* Test the native cuda graph api

* Set cudnn stream before execution

* Make LRUCache more like a normal container

* Do error check for cublas handle

* Zero-initilizing array

* Use tf32 for conv

* Skip TestConv.test_torch_conv_2D test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-07-25 08:12:10 +09:00
Awni Hannun
70dc336785 Test on cuda 12.2 and 12.9 (#2413) 2025-07-24 06:06:15 -07:00
Awni Hannun
4e504039f5 [Metal] Release metal events (#2412)
* release metal events

* fix

* fix
2025-07-23 19:53:42 -07:00
Awni Hannun
d1f4d291e8 Fix uv install and add dev release (#2411)
* fix uv install and add dev release

* fix docstring

* pin cuda deps

* cuda release on cpu-only machine
2025-07-23 16:54:19 -07:00
Awni Hannun
e1840853ce full row mask in sdpa consistently gives nan (#2406) 2025-07-23 16:37:03 -07:00
Cheng
0f5ce173da [CUDA] --compress-mode requires CUDA 12.8 (#2407) 2025-07-23 06:11:11 -07:00
Cheng
588854195f Remove unused code in Convolution::vjp (#2408) 2025-07-23 06:11:00 -07:00
Fangjun Kuang
28d068bce6 Fix an error in the comment for mx.dequantize (#2409) 2025-07-23 06:10:50 -07:00
Awni Hannun
d107d8d495 add cuda gemv (#2400) 2025-07-22 08:24:13 -07:00
Awni Hannun
1e496ddb82 [CUDA] Simplify allocator (#2392)
* simplify allocator and fixe race with small pool

* Don't use shared event in worker

* use cuda buffer in small pool

* comment

* comment
2025-07-22 08:24:01 -07:00
Awni Hannun
74eccbf3fa use size option in binary (#2399) 2025-07-22 07:00:53 -07:00
Awni Hannun
08638223ca Fix including stubs in wheel (#2398)
* fix including stubs in wheel

* fix bool_
2025-07-22 06:30:17 -07:00
Cheng
56cc858af9 Add contiguous_copy_cpu util for copying array (#2397) 2025-07-21 07:30:35 -07:00
Cheng
f55c4ed1d6 Remove thrust iterators (#2396) 2025-07-21 07:30:27 -07:00
Awni Hannun
93d70419e7 [CUDA] speedup handling scalars (#2389)
* speedup scalars in cuda

* comment
2025-07-18 21:47:31 -07:00
Awni Hannun
63f663d9c6 fix cuda manylinux version to match others (#2388) 2025-07-18 21:02:16 -07:00
Awni Hannun
84b4d96efa fix release build + patch bump (#2387) 2025-07-18 14:47:37 -07:00
Awni Hannun
aec67f2fa6 patch bump (#2386) 2025-07-18 12:25:48 -07:00
Gökdeniz Gülmez
deee214a95 Adding support for the Muon Optimizer (#1914)
* initial commit with workong optmimizer

* update ACKNOWLEDGMENTS.md

* nits and adding it to test

* nits

* G.astype(mx.bfloat16) to G.astype(G.dtype)

* G.ndim >= 2 to assert G.ndim == 2

* remove coments

* replace with  mx.addmm

* remove comments

* format

* nits

* match muon

* fix addmm

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-07-18 12:25:28 -07:00
Cheng
45adec102c Add contiguous_copy_gpu util for copying array (#2379) 2025-07-18 06:44:25 -07:00
Cheng
31fc530c76 [CUDA] Add more ways finding CCCL headers in JIT (#2382) 2025-07-17 15:25:34 -07:00
Awni Hannun
fbb3f65a1a fix resource leaks in matmul and graph (#2383) 2025-07-17 06:50:15 -07:00
Angelos Katharopoulos
6b1b8ea91b [CUDA] Add work per thread to compile (#2368) 2025-07-17 06:47:52 -07:00
Awni Hannun
b2273733ea Test with CUDA 12.2 (#2375)
* Test with CUDA 12.0

* try older image

* fix cpu sort
2025-07-16 13:00:37 -07:00
Awni Hannun
f409b229a4 fix ring distributed test (#2380) 2025-07-16 11:25:24 -07:00
Cheng
30571e2326 Rename the copy util in cpu/copy.h to copy_cpu (#2378) 2025-07-16 07:34:24 -07:00
Awni Hannun
d7734edd9f fix complex reduce + nan propagation in min and max (#2377) 2025-07-15 18:19:47 -07:00
Awni Hannun
2ba69bc8fa lower memory uniform sampling (#2361)
* lower memory uniform

* use fp32

* fix
2025-07-15 14:22:07 -07:00
Cheng
cb349a291c [CUDA] Use cuda::std::complex in place of cuComplex (#2372) 2025-07-15 00:36:13 -07:00
Awni Hannun
f0a0b077a0 Install linux with mlx[cuda] and mlx[cpu] (#2356)
* install linux with mlx[cuda] and mlx[cpu]

* temp for testing

* cleanup circle, fix cuda repair

* update circle

* update circle

* decouple python bindings from core libraries
2025-07-14 17:17:33 -07:00
Awni Hannun
49114f28ab fix flaky test (#2371) 2025-07-14 17:16:18 -07:00
Awni Hannun
e7d2ebadd2 [CUDA] Affine quantize (#2354)
* affine quantize and dequantize kernels

* format

* fix

* format
2025-07-14 15:45:44 -07:00
Awni Hannun
e569803d7c update linux build (#2370) 2025-07-14 15:13:56 -07:00
Cheng
d34f887abc Add Primitive::name and remove Primitive::print (#2365) 2025-07-14 14:06:35 -07:00
Angelos Katharopoulos
5201df5030 Fix imag() vjp (#2367) 2025-07-14 13:11:16 -07:00
Cheng
2d3c26c565 [CUDA] Do not put kernels in annoymous namespace (#2362) 2025-07-12 14:24:45 -07:00
Cheng
6325f60d52 [CUDA] Bundle CCCL for JIT compilation (#2357)
* Ship CCCL for JIT compilation

* Remove cexpf
2025-07-11 18:45:37 -07:00
Awni Hannun
42cc9cfbc7 fix copy dispatch (#2360) 2025-07-11 10:59:35 -07:00
Cheng
8347575ba1 [CUDA] Implement Scan kernel (#2347)
* Contiguous scan

* Strided scan

* Enable tests

* Fix failing logaddexp test

* Use cexpf in Metal
2025-07-10 16:54:12 -07:00
Angelos Katharopoulos
b6eec20260 Fix edge check in qmm_n QuantizedLoader (#2355) 2025-07-10 16:28:50 -07:00
Angelos Katharopoulos
0eb035b4b1 Fix type promotion in Adam with bias correction (#2350) 2025-07-10 11:14:42 -07:00
Cheng
afb9817599 [CUDA] Put version in ptx cache dir path (#2352) 2025-07-10 07:24:21 -07:00
Cheng
8fb3e7a26c [CUDA] Set current device before cudaGraphLaunch (#2351) 2025-07-10 07:24:02 -07:00
jhavukainen
8c7bc30ce4 Align mlx::core::min op nan propagation with NumPy (#2346) 2025-07-10 06:20:43 -07:00
Cheng
85873cb162 [CUDA] Do vectorized store/load in contiguous elementwise ops (#2342)
* Do vectorized store/load in unary ops

* Do vectorized store/load in binary_two ops

* Do vectorized store/load in copy ops

* Do vectorized store/load in ternary ops

* Use int32_t for IdxT

* binary => binary_two in binary_two.cu

* Fix tests on large arrays

* Use uint as index type

* Contig uses uint as index and non-contig uses int
2025-07-09 18:48:43 -07:00
Awni Hannun
e14ee12491 add zero for argsort vjp (#2345) 2025-07-09 14:37:14 -07:00
jhavukainen
8b9a3f3cea Align mlx::core::max op nan propagation with NumPy (#2339)
* Make max op NaN propagation rules align with numpy

* Adding benchmarks and testing for max op nanpropagation

* Pre-commit formatting

* Fix max complex64 nan propagation and add test

* Improve the cpp unittest

* Only check nans on non-integral types in simd_reduce_impl.

* Cleanup using namespace alias

* Add cpu Max nanpropagation. Fix a small fib in cpu max dispatch data types for int8/int16.

* Make the max nanpropagation test more meaningful for integer types

* Remove tuple unpacking syntax to comply with earlier python versions. Add cuda skip to nanpropagation tests, fix cuda implementation in a separate PR.
2025-07-09 11:26:27 -07:00
Awni Hannun
fb4e8b896b patch bump (#2343) 2025-07-08 14:26:07 -07:00
Cheng
2ca533b279 Fix compilation with CUDA 11 (#2331) 2025-07-07 20:00:43 -07:00
Angelos Katharopoulos
4a9b29a875 MoE backward improvements (#2335) 2025-07-07 17:59:53 -07:00
Awni Hannun
a4fcc893cd auto build linux release (#2341) 2025-07-07 09:29:23 -07:00
Cheng
9d10239af7 [CUDA] Do vectorized store/load in binary ops (#2330) 2025-07-07 08:44:14 -07:00
Cheng
19facd4b20 Build with all cpu cores by default (#2336) 2025-07-07 06:06:45 -07:00
Angelos Katharopoulos
f5299f72cd Fix layernorm race condition (#2340) 2025-07-07 06:06:01 -07:00
Cheng
0e0d9ac522 [CUDA] Add MLX_CUDA_GRAPH_CACHE_SIZE env for setting graph cache size (#2329) 2025-07-05 08:33:29 -07:00
Awni Hannun
8917022deb fix graphs for older cuda (#2328) 2025-07-02 19:37:58 -07:00
Awni Hannun
ec0d5db67b [CUDA] Switch to CUDA graphs (#2317)
* cuda graph prototype

fix signal bug + start to add dependencies

capture more

capture more ops

remaining ops

fix reduce and rope deps

add concurrent context

try update, but not working

cosistent topology order

use node api

use node api directly to reduce overhead

fix bug

use kernels in unary

cache graph

format

fix synchronization

format

* comment
2025-07-02 15:59:13 -07:00
Cheng
e76e9b87f0 Fix compilation error from integral_constant (#2326) 2025-07-02 06:04:38 -07:00
Awni Hannun
cfb6a244ea allow parameters to be deleted (#2325) 2025-07-01 21:27:23 -07:00
Awni Hannun
58f3860306 patch bump (#2324) 2025-07-01 12:12:16 -07:00
Awni Hannun
dd4f53db63 use fp32 for testing, add more complex ops (#2322) 2025-07-01 07:30:00 -07:00
Angelos Katharopoulos
3d5e17e507 MLX_SWITCH macros to templates (#2320) 2025-07-01 01:33:44 -07:00
Awni Hannun
33bf1a244b Fix module update in strict mode (#2321)
* fix module update in strict mode

* allow GELU to be pickled
2025-06-29 11:12:29 -07:00
Angelos Katharopoulos
772f471ff2 [CUDA] Fix reductions (#2314) 2025-06-27 12:59:20 -07:00
Angelos Katharopoulos
2c11d10f8d Split broadcast so it is always fused in compile (#2318) 2025-06-26 22:08:18 -07:00
Angelos Katharopoulos
656ed7f780 Fix get 2d grid dims (#2316) 2025-06-25 13:03:09 -07:00
Awni Hannun
81bb9a2a9e Compile float64 functions on CPU (#2311) 2025-06-24 10:18:52 -07:00
Angelos Katharopoulos
5adf185f86 Fix update_modules() when providing a subset (#2308) 2025-06-20 17:19:46 -07:00
Awni Hannun
c9a9180584 Cuda perf tuning (#2307)
* perf tuning

* fix adding inputs arrays in matmul / srot

* format

* fix
2025-06-20 14:50:57 -07:00
Awni Hannun
76831ed83d Build CUDA release in Circle (#2306)
* cuda release

* add license
2025-06-19 15:26:36 -07:00
Angelos Katharopoulos
b3d7b85376 Make ptx cache settable by environment variable (#2304) 2025-06-17 23:55:56 -07:00
Awni Hannun
cad5c0241c [CUDA] synch properly waits for all tasks to finish and clear (#2303)
* cuda synch properly waits for all tasks to finish and clear

* fix copy
2025-06-17 12:03:25 -07:00
Awni Hannun
b8022c578a divmod, partition, sort fixes (#2302) 2025-06-16 18:49:32 -07:00
Awni Hannun
bc53f8293f Cuda bug fixes 2 (#2298)
* more bug fixes

* more bug fixes

* format
2025-06-16 13:14:46 -07:00
Awni Hannun
c552ff2451 [CUDA] Fix back-end bugs and enable corresponding tests (#2296)
* Fix some cuda back-end bugs and enable corresponding tests

* more fixes

* enable more tests

* format
2025-06-16 08:45:40 -07:00
Awni Hannun
4fda5fbdf9 add python testing for cuda with ability to skip list of tests (#2295) 2025-06-15 10:56:48 -07:00
Angelos Katharopoulos
580776559b RoPE for CUDA (#2293)
* First working CUDA rope

* Fix random
2025-06-15 06:08:07 -07:00
Awni Hannun
a14aaa7c9d Fix cuda arg reduce (#2291) 2025-06-14 17:54:00 -07:00
Awni Hannun
a6d780154f fix cuda gemm for bf16 (#2288) 2025-06-13 22:10:46 -07:00
Awni Hannun
6871e2eeb7 fix cuda jit (#2287) 2025-06-13 19:21:46 -07:00
Awni Hannun
8402a2acf4 Fix complex power and print (#2286)
* fix complex power and print

* fix complex matmul shape
2025-06-13 11:13:00 -07:00
Jagrit Digani
fddb6933e1 Collection of refactors (#2274)
* Refactor gemv into a function

* Refactor splitk step 1

* Refactor split k axpby

* Rearrange steel_gemm_regular

* Redirect steel_gemm_regular

* Add axpby routing to steel_matmul_regular

* Refactor AddMM step 1

* Redirect steel_gemm

* Update addmm

* Comments and format

* Some cleanup

* Add architecture gen to device

* Update no copy condition in normalization to account for axis size 1
2025-06-13 10:44:56 -07:00
Cheng
c8b4787e4e CUDA backend: indexing ops (#2277) 2025-06-12 21:44:19 -07:00
Awni Hannun
2188199ff8 [CUDA] ternary with select op (#2283)
* cuda ternary with select op

* comment + fix

* fix
2025-06-12 20:24:43 -07:00
Awni Hannun
aa07429bad Fix cuda build (#2284) 2025-06-12 17:48:05 -07:00
Awni Hannun
918761a25a [CUDA] RMSNorm and VJP (#2280)
* rms norm start

* nit
2025-06-12 17:09:49 -07:00
Cheng
a4fc671d3e CUDA backend: compile (#2276)
* CUDA backend: compile

* Rename kernels/ to device/
2025-06-12 17:08:39 -07:00
Awni Hannun
f5f65ef48c Make sliceUpdate general (#2282)
* Make sliceUpdate general

* fix
2025-06-12 16:48:54 -07:00
Cheng
c2dd81a8aa Fix warnings from latest CUDA toolkit (#2275) 2025-06-12 06:03:01 -07:00
Cheng
d7e680ffe4 CUDA backend: layernorm (#2271) 2025-06-11 15:48:32 -07:00
Cheng
c371baf53a CUDA backend: softmax (#2272) 2025-06-11 13:55:22 -07:00
Cheng
ccf78f566c CUDA backend: argreduce (#2270) 2025-06-11 13:26:17 -07:00
Cheng
c9fa68664a CUDA backend: reduce (#2269) 2025-06-11 11:22:25 -07:00
Awni Hannun
c35f4d089a start cuda circle config (#2256)
* rebase

* fix metal kernel linking issue on cuda

* start cuda circle config
2025-06-10 21:19:47 -07:00
Angelos Katharopoulos
8590c0941e Add load_safe to the general conv loaders (#2258) 2025-06-10 20:58:16 -07:00
Cheng
095163b8d1 Fix building cpp benchmarks on Linux (#2268) 2025-06-10 17:10:24 -07:00
Cheng
99c33d011d rebase + nit (#2260)
Co-authored-by: Awni Hannun <awni@apple.com>
2025-06-10 10:51:51 -07:00
Awni Hannun
62fecf3e13 fix conv export (#2265) 2025-06-10 09:34:01 -07:00
Cheng
7c4eb5d03e CUDA backend: random (#2261) 2025-06-10 08:59:56 -07:00
Cheng
bae9a6b404 CUDA backend: sort (#2262)
Co-authored-by: Awni Hannun <awni@apple.com>
2025-06-10 08:59:47 -07:00
Christopher Fleetwood
004c1d8ef2 Report number of missing parameters (#2264)
* chore: inform

* chore: format

---------

Co-authored-by: FL33TW00D <FL33TW00D@users.noreply.github.com>
2025-06-10 06:37:50 -07:00
Cheng
7ebb2e0193 CUDA backend: binary ops (#2259) 2025-06-10 06:37:40 -07:00
Awni Hannun
9ce77798b1 fix export to work with gather/scatter axis (#2263) 2025-06-09 20:37:27 -07:00
Cheng
f8bad60609 CUDA backend: unary ops (#2158) 2025-06-09 06:45:08 -07:00
Emmanuel Ferdman
5866b3857b Refactor the lu test (#2250)
Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
2025-06-07 06:12:08 -07:00
Awni Hannun
1ca616844b Fix unintuitive metal kernel caching (#2242)
* Fix unintuitive metal kernel caching

* alternative solution
2025-06-06 20:08:15 -07:00
Angelos Katharopoulos
2e8cf0b450 Change layernorms to two pass algorithm (#2246) 2025-06-06 13:34:56 -07:00
Cheng
24f89173d1 CUDA backend: matmul (#2241) 2025-06-06 12:24:04 -07:00
Awni Hannun
c6a20b427a Improve metal elementwise kernels (#2247)
* improve metal elementwise kernels

* compile and copy

* fix jit
2025-06-06 11:37:40 -07:00
Awni Hannun
a5ac9244c4 fix linux linking error (#2248) 2025-06-06 10:41:51 -07:00
Awni Hannun
c763fe1be0 default strict mode for module update and update_modules (#2239) 2025-06-05 15:27:02 -07:00
Cheng
52dc8c8cd5 Add profiler annotations in common primitives for CUDA backend (#2244) 2025-06-04 19:55:12 -07:00
Angelos Katharopoulos
aede70e81d Perf regression fix (#2243) 2025-06-03 17:55:12 -07:00
Cheng
85a8beb5e4 Avoid atomic updates across CPU/GPU in CUDA event (#2231) 2025-06-03 16:49:06 -07:00
Cheng
0bb89e9e5f Share more common code in Compiled (#2240)
* Share more common code in Compiled

* Remove build_lib_name
2025-06-03 16:48:50 -07:00
Cheng
5685ceb3c7 Avoid invoking allocator::malloc when creating CUDA event (#2232) 2025-06-03 16:48:40 -07:00
Suryash Malviya
0408ba0a76 Optimizing Complex Matrix Multiplication using Karatsuba’s Algorithm (#2220)
* Implementing Complex Matmul using Karatsuba Algorithm

* Implemented Karatsuba's Algorithm for complex matmul and pre-commit them

* fix

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-06-02 15:58:46 -07:00
Awni Hannun
cbad6c3093 version (#2237) 2025-06-02 15:58:33 -07:00
Cheng
1b021f6984 Fast primitives decide when to use the fallback (#2216) 2025-06-02 13:26:37 -07:00
Cheng
95b7551d65 Do not check event.is_signaled() in eval_impl (#2230) 2025-06-02 13:23:34 -07:00
Cheng
db5a7c6192 Add memory cache to CUDA backend (#2221)
* Move BufferCache out of allocator

* Add memory cache to cuda backend allocator

* Simplify BufferCache assuming buf can not be null
2025-05-30 12:12:54 -07:00
Awni Hannun
6ef2f67e7f 5bit quants (#2226)
* 5bit quants

* 5bit quants
2025-05-30 12:12:10 -07:00
Cheng
f76ee1ffd2 Move some dims utils to common (#2223) 2025-05-29 06:48:30 -07:00
Cheng
54a71f270a Remove unused defines (#2217) 2025-05-23 06:14:58 -07:00
Awni Hannun
55b4062dd8 copyright in docs (#2214) 2025-05-21 17:13:04 -07:00
Cheng
79071bfba4 Fix out-of-bounds default value in logsumexp/softmax (#2213) 2025-05-21 07:25:16 -07:00
Cheng
7774b87cbd Remove redundant simd_sum in logsumexp (#2210) 2025-05-21 07:25:03 -07:00
Cheng
35c87741cf Build for compute capability 70 instead of 75 (#2209) 2025-05-20 19:42:48 -07:00
Jack Wind
4cbe605214 Feat: Allow per-target Metal debug flags (#2201)
* feat: allow per-target Metal debug flags

* formatting fix
2025-05-20 10:22:26 -07:00
Clement Liaw
ab8883dd55 include mlx::core::version() symbols in the mlx static library (#2207) 2025-05-20 07:39:11 -07:00
Awni Hannun
eebe73001a fix large arg reduce (#2206) 2025-05-19 13:10:44 -07:00
Angelos Katharopoulos
0359bf02c9 Nearest upsample (#2202) 2025-05-19 11:23:38 -07:00
Cheng
237f9e58a8 Fix BEFORE keyword in target_include_directories (#2204) 2025-05-19 06:10:44 -07:00
Awni Hannun
8576e6fe36 fix conv2d bug + faster conv 1d (#2195)
* fix conv2d bug + faster conv 1d

* revert sort + flaky test
2025-05-18 06:05:11 -07:00
Angelos Katharopoulos
0654543dcc Add complex eigh (#2191) 2025-05-18 00:18:43 -07:00
Awni Hannun
48ef3e74e2 reduce vjp for all and any (#2193) 2025-05-16 08:38:49 -07:00
Cheng
7d4b378952 Include cuda_bf16.h for bfloat16 overloads (#2192)
* Include cuda_bf16.h for bfloat16 overloads

* Add NO_GPU_MULTI(Eig) in cuda backend
2025-05-16 06:44:42 -07:00
Jack Wind
7ff5c41e06 Add set_threadgroup_memory_length to CommandEncoder (#2183) 2025-05-16 00:28:03 -07:00
Awni Hannun
602f43e3d1 fix conv grad (#2187) 2025-05-15 19:20:36 -07:00
Awni Hannun
a2cadb8218 real and imag properties (#2189) 2025-05-15 18:17:50 -07:00
Awni Hannun
c1eb9d05d9 non-symmetric eig and eigh (#2188) 2025-05-15 13:01:44 -07:00
Angelos Katharopoulos
cf6c939e86 Fix some complex vjps (#2178) 2025-05-14 23:37:12 -07:00
Angelos Katharopoulos
130df35e1b Add random normal distribution for complex numbers (#2182) 2025-05-13 22:43:45 -07:00
Cheng
0751263dec Fix typo in row_reduce_small (#2179) 2025-05-13 20:19:54 -07:00
Cheng
eca2f3eb97 Add remove_index utility (#2173) 2025-05-13 17:09:56 -07:00
Angelos Katharopoulos
3aa9cf3f9e Fix put_along_axis for empty arrays (#2181) 2025-05-13 14:27:53 -07:00
Awni Hannun
8f3d208dce Close a couple edge case bugs: hadamard and addmm on empty inputs (#2177)
* handle hadamard and addmm on empty inputs

* fix
2025-05-12 10:48:57 -07:00
Ivan Fioravanti
caaa3f1f8c Small typos in mx.metal deprecations (#2176) 2025-05-11 06:03:47 -07:00
Awni Hannun
659a51919f patch bump (#2162) 2025-05-09 14:35:14 -07:00
Awni Hannun
6661387066 Fix fft for integer overflow (#2161) 2025-05-09 14:25:12 -07:00
ATurker
a7fae8a176 fix: conv_general differences between gpu, cpu (#2070)
* fix general_conv padding

* fix bugs

* add test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-05-09 10:26:52 -07:00
Cheng
0cae0bdac8 CUDA backend: backbone (#2075) 2025-05-06 21:26:46 -07:00
Awni Hannun
5a1a5d5ed1 fix input coherent kernel launch (#2153) 2025-05-05 17:30:50 -07:00
Cheng
1683975acf Move common gpu primitives to backend/gpu (#2145) 2025-05-05 13:45:29 -07:00
Awni Hannun
af705590ac fix batched vector sdpa (#2152) 2025-05-05 13:13:03 -07:00
Awni Hannun
825124af8f fix bw for elementwise ops (#2151)
* fix bw for elementwise ops

* add compile

* fix

* fix

* fix

* fix
2025-05-05 06:15:04 -07:00
Awni Hannun
9c5e7da507 fix compile merging (#2150) 2025-05-02 15:08:50 -07:00
Angelos Katharopoulos
481349495b GPU Hadamard for large N (#1879) 2025-05-01 17:19:17 -07:00
Awni Hannun
9daa6b003f fix shapeless export (#2148) 2025-05-01 15:02:02 -07:00
Angelos Katharopoulos
a3a632d567 Fix the launcher when ran locally (#2147) 2025-05-01 12:56:09 -07:00
Awni Hannun
e496c5a4b4 fix integer overflow in qmm (#2143) 2025-04-30 09:28:56 -07:00
Cheng
ea890d8710 Remove metal-only tests (#2139) 2025-04-30 09:08:39 -07:00
Awni Hannun
aa5d84f102 Allow quant layer to be unfrozen (#2142) 2025-04-30 09:08:29 -07:00
Awni Hannun
f1606486d2 Generalize gpu backend (#2138)
* generalize gpu backend

* fix no_gpu build

* fix no_gpu build

* generalize gpu backend
2025-04-30 09:08:17 -07:00
Cheng
87720a8908 Fix building with uv (#2141) 2025-04-30 06:04:07 -07:00
Aashiq Dheeraj
bb6565ef14 add fftshift and ifftshift fft helpers (#2135)
* add fftshift and ifftshift fft helpers

* address comments

* axes have to be iterable

* fix fp error in roll + add test

---------

Co-authored-by: Aashiq Dheeraj <aashiq@aashiq-mbp-m4.local>
2025-04-29 22:13:45 -07:00
Awni Hannun
7bb063bcb3 Enable vjp for quantized scale and bias (#2129)
* Enable vjp for quantized scale and bias

* higher tol
2025-04-29 13:03:09 -07:00
Alex Chi Z.
b36dd472bb return library if it is successfully loaded (#2131) 2025-04-29 07:30:36 -07:00
hdeng-apple
167b759a38 Fix typos (#2136) 2025-04-29 07:26:05 -07:00
charan-003
99b9868859 Clarify dimension notation in conv1d, conv2d, and conv3d docstrings (#2123)
* Clarify dimension notation in conv1d, conv2d, and conv3d docstrings

* Updating transposed convs in conv1d, conv2d, and conv3d

---------

Co-authored-by: Sai Charan Arvapally <saicharan@Sais-MacBook-Pro.local>
2025-04-25 12:18:30 -07:00
1ndig0
6b2d5448f2 Fix the error message in mx.right_shift and mx.left_shift (#2121)
* update right_shift and lef_shift

* simplify

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-04-25 09:14:28 -07:00
Awni Hannun
eaf709b83e patch (#2119) 2025-04-24 16:11:07 -07:00
Angelos Katharopoulos
f0e70afff0 Fix swift pm load (#2117) 2025-04-24 10:58:29 -07:00
hdeng-apple
86984cad68 Remove static initializers (#2059)
* Remove static initializers in device.cpp, load.cpp, pocketfft.h

* Remove static initializer InTracing::trace_stack

* Remove static initializer of CompilerCache cache

* Revert changes in pocketfft.h

* Remove duplicate private section of thread_pool()
2025-04-24 06:14:49 -07:00
Awni Hannun
fbc89e3ced fix pinv (#2110) 2025-04-23 13:08:28 -07:00
hdeng-apple
38c1e720c2 Search mlx.metallib in macOS framework "Resources" dir (#2061)
---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2025-04-23 09:53:13 -07:00
Param Thakkar
600e87e03c Added output_padding parameters in conv_transpose (#2092) 2025-04-23 09:26:33 -07:00
Hyunsung Lee
3836445241 Add broadcast_shapes in python API (#2091) 2025-04-22 18:57:39 -07:00
Yury Popov
1d2c9d6a07 Complex scan (#2094) 2025-04-22 18:56:28 -07:00
Awni Hannun
e8ac6bd2f5 irfft throws instead of segfaults on scalars (#2109) 2025-04-22 10:25:55 -07:00
Awni Hannun
fdadc4f22c Add more complex unary ops (#2101) 2025-04-21 13:04:54 -07:00
Awni Hannun
79b527f45f conv vmap (#2102) 2025-04-21 13:04:39 -07:00
Awni Hannun
dc4eada7f0 Use unordered map for kwargs in export/import (#2087)
* use unordered map for kwargs in export/import

* comment
2025-04-21 07:17:22 -07:00
Cheng
70ebc3b598 Return const ref in array::data_shared_ptr (#2100) 2025-04-21 07:17:09 -07:00
Cheng
b13f2aed16 Introduce macros for dispatching dynamic dtypes as static types (#2073) 2025-04-19 06:16:30 -07:00
Param Thakkar
5f04c0f818 Fixed shift operations issue (#2080)
* Fixed shift operations issue

* Added tests and fixes

* Fixed loop syntax error

* Added tests for bool

* Fixed typo
2025-04-18 14:28:33 -07:00
Awni Hannun
55935ccae7 fix py gc edge case (#2079) 2025-04-18 12:46:53 -07:00
Awni Hannun
b529515eb1 minor bump (#2081) 2025-04-17 14:57:11 -07:00
Angelos Katharopoulos
3cde719eb7 Route to gather qmm only for many tokens per expert (#2082) 2025-04-17 14:53:08 -07:00
Angelos Katharopoulos
5de6d94a90 Gather qmm batched kernel and refactoring of quantized (#2078) 2025-04-17 13:53:11 -07:00
Angelos Katharopoulos
99eefd2ec0 Gather mm new kernel and small refactoring (#2040) 2025-04-14 16:37:36 -07:00
Yury Popov
e9e268336b LogCumSumExp (#2069) 2025-04-13 01:27:29 -07:00
Awni Hannun
7275ac7523 Fix release build (#2072) 2025-04-12 20:41:58 -07:00
Angelos Katharopoulos
c4189a38e4 Add float mask to sdpa vector (#2068) 2025-04-11 17:29:40 -07:00
Awni Hannun
68d1b3256b nit: fix exception handling (#2066) 2025-04-11 14:12:08 -07:00
Awni Hannun
9c6953bda7 Fix stubgen (#2065)
* Fix stubgen

* add multi optim to docs
2025-04-11 12:02:54 -07:00
Awni Hannun
ef7ece9851 fix fft bug (#2062) 2025-04-10 19:41:27 -07:00
Angelos Katharopoulos
ddaa4b7dcb Fix the test and add custom min/max reductions for uncommon MPI types (#2060) 2025-04-10 17:01:17 -07:00
Cheng
dfae2c6989 Fix MSVC build due to use of M_LN2 (#2058) 2025-04-10 07:41:41 -07:00
Anastasiia Filippova
515f104926 Min / max reductions (#2041) 2025-04-09 23:22:20 -07:00
Angelos Katharopoulos
9ecefd56db Do not load the default lib if another is requested (#2055) 2025-04-09 13:31:38 -07:00
Awni Hannun
e5d35aa187 no sdpa in grad (#2054) 2025-04-08 19:13:54 -07:00
Awni Hannun
00794c42bc Fix causal mask sdpa vec (#2053)
* fix sdpa vector causal mask

* test
2025-04-08 09:11:23 -07:00
Cheng
08a1bf3f10 Remove Event::Signal() (#2052) 2025-04-08 06:20:27 -07:00
Awni Hannun
60c4154346 Only request residency once (#2051) 2025-04-07 10:47:51 -07:00
Awni Hannun
f2c85308c1 add a half simd gemm fallback (#2046)
* add a half simd gemm fallback

* nit
2025-04-07 09:31:29 -07:00
Awni Hannun
1a28b69ee2 only add to residency set once (#2049) 2025-04-06 17:38:25 -07:00
Cheng
ba09f01ce8 Remove test of converting negative float to uint (#2048) 2025-04-06 06:21:46 -07:00
Cheng
6cf48872b7 wait_for_one should wait for task to finish (#2047) 2025-04-05 20:05:16 -07:00
Angelos Katharopoulos
7b3b8fa000 Fix ci release (#2045) 2025-04-04 20:25:01 -07:00
Awni Hannun
ec5e2aae61 nit in doc (#2044) 2025-04-04 12:04:17 -07:00
Awni Hannun
86389bf970 patch bump (#2043) 2025-04-03 13:15:18 -07:00
Jagrit Digani
3290bfa690 Add new sdpa function overload (#2035)
* Add new sdpa function overload

* Address comments

* Remove std::varaint from cpp sdpa function
2025-04-03 11:58:28 -07:00
Jagrit Digani
8777fd104f Depthwise Conv2D optimization (#2036)
- Add new specialized kernel for small kernel (kernels size <= 7), small strides (strides <= 2) depthwise 2d convolutions
- Add related tests
2025-04-03 09:42:04 -07:00
Awni Hannun
c41f7565ed fix softmax / logsumexp (#2042) 2025-04-03 08:32:59 -07:00
Awni Hannun
9ba81e3da4 tune quant dispatch (#2031) 2025-04-02 20:05:54 -07:00
Awni Hannun
c23888acd7 Fix build warning (#2033) 2025-04-01 14:42:27 -07:00
Awni Hannun
f98ce25ab9 fix residency set for real (#2032) 2025-04-01 12:59:48 -07:00
Awni Hannun
de5f38fd48 Custom logsumexp (#2028)
* initial custom logsumexp

* more tests

* comments + fix
2025-03-31 07:36:55 -07:00
Angelos Katharopoulos
ec2854b13a Swap -inf for finite_minimum value (#2029) 2025-03-30 21:55:04 -07:00
Stephen Panaro
90823d2938 Add missing funcs to docs (#2021) 2025-03-30 18:29:33 -07:00
Jesper Stemann Andersen
5f5770e3a2 Fix CPU sign for unsigned ints (#2024)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2025-03-30 17:56:59 -07:00
Awni Hannun
28f39e9038 Log for complex numbers in Metal (#2025)
* Log for complex numbers in Metal

* fix log2
2025-03-30 17:04:38 -07:00
Awni Hannun
b2d2b37888 fix residency set clearing (#2027) 2025-03-30 16:27:26 -07:00
Awni Hannun
fe597e141c add pinv to doc (#2020) 2025-03-30 15:54:18 -07:00
Yi Wang
72ca1539e0 Remove unused variable in /setup.py (#2026)
This is a follow up of https://github.com/ml-explore/mlx/pull/2011
2025-03-30 12:52:33 -07:00
Awni Hannun
13b26775f1 use minimum deployment target (#2016) 2025-03-28 14:31:53 -07:00
Awni Hannun
05d7118561 causal vector sdpa (#2018)
* causal vector sdpa

* get rid of memory threshold
2025-03-28 12:36:13 -07:00
Awni Hannun
98b901ad66 enable complex gemm (#2017) 2025-03-28 10:45:13 -07:00
Awni Hannun
5580b47291 iinfo and scalar overflow detection (#2009) 2025-03-27 19:54:56 -07:00
Awni Hannun
bc62932984 sdpa specialization for head dim 256 (#2007) 2025-03-27 19:31:25 -07:00
Awni Hannun
a6b5d6e759 revise cmake minimum for doctest (#2014) 2025-03-27 19:30:58 -07:00
Yi Wang
a8931306e1 Remove unused variable in CMakeBuild (#2011)
Fix https://github.com/ml-explore/mlx/issues/2010
2025-03-27 16:00:51 -07:00
Yi Wang
fecdb8717e Polish CONTRIBUTING>md (#2005) 2025-03-25 19:06:34 -07:00
Awni Hannun
916fd273ea wire cache (#2006) 2025-03-25 18:54:01 -07:00
Yi Wang
0da8506552 Update docs for extensions (#2004) 2025-03-25 18:35:03 -07:00
Cheng
eda7a7b43e Do not join threads during process exit on Windows (#1738) 2025-03-25 06:33:08 -07:00
Chunyang Wen
022eabb734 Remove unused import (#1987) 2025-03-24 20:19:32 -07:00
Awni Hannun
aba899cef8 patch bump (#2000) 2025-03-24 12:47:05 -07:00
Jagrit Digani
6a40e1c176 Fix looping limit in causal attention (#1999) 2025-03-24 12:28:00 -07:00
Jesper Stemann Andersen
9307b2ab8b Fixed 32-bit platform support for distributed/ring implementation (#1996)
Replaced unsigned long integer literals with size_t literals in ring implementation, e.g., 1UL with size_t(1).
2025-03-24 08:08:40 -07:00
Jesper Stemann Andersen
522d8d3917 Added missing netinet/in.h include that fixes build on FreeBSD (#1997)
Defines IPPROTO_TCP.
2025-03-24 08:07:34 -07:00
Awni Hannun
a84cc0123f promote mask when needed (#1998) 2025-03-23 19:58:28 -07:00
Andrey Velichkevich
f018e248cd fix(backend): Include algorithm library in Allocator (#1992)
Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>
2025-03-22 21:27:51 -07:00
Awni Hannun
cfd7237a80 fix docs (#1991) 2025-03-21 19:58:53 -07:00
Angelos Katharopoulos
4eef8102c9 Distributed layers (#1270) 2025-03-21 13:52:17 -07:00
Angelos Katharopoulos
69e4dd506b Add a ring all gather (#1985) 2025-03-21 13:36:51 -07:00
Angelos Katharopoulos
25814a9458 Disable mpi on version mismatch (#1989) 2025-03-21 13:36:26 -07:00
Awni Hannun
2a980a76ce Add stats and limit to common allocator and enable tests (#1988)
* add stats to common allocator and enable tests

* linux memory and default

* fix
2025-03-21 12:28:36 -07:00
Angelos Katharopoulos
d343782c8b Cross platform libmpi loading (#1975) 2025-03-21 11:23:10 -07:00
Awni Hannun
4e1994e9d7 move memory APIs into top level mlx.core (#1982) 2025-03-21 07:25:12 -07:00
jiyzhang
65a38c452b update the formula of smooth_l1_loss (#1986) 2025-03-21 06:25:23 -07:00
Awni Hannun
7b7e2352cd fix malloc or wait deadlock (#1976) 2025-03-20 16:48:43 -07:00
Awni Hannun
1177d28395 patch bump (#1981) 2025-03-20 15:12:22 -07:00
Awni Hannun
005e7efa64 fix mask in sdpa (#1980)
* fix mask in sdpa

* fix attention mask

* Re-enable routing for array mask

---------

Co-authored-by: Jagrit Digani <digani@apple.com>
2025-03-20 14:53:12 -07:00
Jagrit Digani
b42d13ec84 Update attention tests to show diff, disable array masks (#1978) 2025-03-20 14:25:38 -07:00
Jagrit Digani
9adcd1a650 Support fused masking in Attention (#1924)
* Update API to allow mask='causal' in fast::sdpa

* Add fallback

* Update steel::AttnParams

* Fix typo

* WIP, basic causal

* Update tests

* Update benchmarking

* Update masking loop limits

* Add bool masking and update tests

* Update additive mask

* Update benchmarks

* Update benchmarks

* Update tests

* Update for bfloat error

* Update early exit

* Add random seed to tests
2025-03-20 11:01:32 -07:00
Awni Hannun
3c164fca8c Fix multistream GPU deadlock (#1969)
* fix multistream GPU deadlock

* comments
2025-03-20 07:19:47 -07:00
jiyzhang
95e335db7b Update smooth_l1_loss in losses.py (#1974)
According the definition of smooth_l1_loss, the line 

diff = predictions - targets

Should be updated to 

diff = mx.abs(predictions - targets)

After the modification, the result is consistent with PyTorch smooth_l1_loss
2025-03-19 20:19:02 -07:00
Awni Hannun
f90206ad74 Guard nullptr dereference (#1972)
* guard nullptr dereference

* comment
2025-03-19 16:24:10 -07:00
Chunyang Wen
3779150750 refactor: all use schedule (#1973) 2025-03-19 11:24:04 -07:00
Cheng
0a9777aa5c Do not define MLX_VERSION globally (#1966) 2025-03-18 07:12:40 -07:00
Chunyang Wen
45ad06aac8 Fix typo; Fix lint warning when reuse the same name (#1968)
* Fix typo; Fix lint warning when reuse the same name

* Add missing period
2025-03-18 07:12:24 -07:00
Awni Hannun
c6ea2ba329 Use same accumulation precision in gemv as gemm (#1962)
* use same accumulation precision in gemv as gemm

* faster

* fix compile
2025-03-16 07:13:24 -07:00
Awni Hannun
2770a10240 fix grad with inplace updates (#1961) 2025-03-13 19:13:09 -07:00
Awni Hannun
d2a94f9e6a Only compile warnings as errors for circle (#1957) 2025-03-12 13:08:19 -07:00
Awni Hannun
32da94507a fix vmap for flatten (#1955) 2025-03-11 10:42:22 -07:00
Awni Hannun
736a340478 reduce binary size (#1952) 2025-03-11 06:30:44 -07:00
Awni Hannun
117e1355a2 fix copy for large arrays (#1953) 2025-03-10 15:04:25 -07:00
Awni Hannun
3c3e558c60 Support transposed head/seq for kv (#1950)
* support transposed head/seq for kv

* fix flaky test

* nit
2025-03-10 10:53:45 -07:00
Chunyang Wen
cffceda6ee Add type hint for _extra_repr (#1948) 2025-03-10 06:05:36 -07:00
Chunyang Wen
048805ad2c Remove unused modules (#1949) 2025-03-10 06:05:26 -07:00
Chunyang Wen
d14c9fe7ea Add file info when raising errors in save (#1943) 2025-03-08 14:51:04 -08:00
Chunyang Wen
5db90ce822 Fix obsured warning (#1944) 2025-03-08 14:50:39 -08:00
Chunyang Wen
d699cc1330 Fix unreachable warning (#1939)
* Fix unreachable warning

* Update error message
2025-03-07 17:23:04 -08:00
Awni Hannun
c4230747a1 redesign for faster cpu/gpu synch (#1869)
* redesign for faster cpu/gpu synch

* load + more async CPU

* use command encoder API and move more ops to use it

* make fence back-end generic + CPU only fence

* faster build

* fix async eval

* fixes + handle temporaries

* fix / improve cpu conv

* remove unused status, fix siblings

* fix extensions

* fix

* fix no cpu build

* format

* comments

* fix perf regression, remove unecessary abort

* fix events, task limit cpu

* fix waiting

* fix donation / temporaries in normalization
2025-03-06 19:23:38 -08:00
Awni Hannun
5245f12a46 always use json (#1938) 2025-03-06 15:35:56 -08:00
Chunyang Wen
a198b2787e Remove unused modules (#1936) 2025-03-06 14:20:27 -08:00
Chunyang Wen
04edad8c59 Add doc string for path (#1937) 2025-03-06 14:20:09 -08:00
David Wisdom
392b3060b0 Fix typo in randint docstring (#1932)
This commit fixes a typo in the docstring for mlx.core.random.randint() by changing "roadcastable" to "broadcastable".
2025-03-05 21:48:00 -08:00
Chunyang Wen
85b34d59bc Clean unused sys (#1929) 2025-03-05 13:48:03 -08:00
Awni Hannun
f599c11bc8 bump (#1931) 2025-03-05 13:16:53 -08:00
Angelos Katharopoulos
0792ff02ff Only fail when 10 consecutive socket errors occur (#1928) 2025-03-05 13:16:19 -08:00
Alex Barron
fd0d63ba5b Affine quant always in fp32 (#1925)
* do affine quant in fp32

* static cast
2025-03-04 17:50:19 -08:00
Abe Leininger
3835a428c5 Adds nuclear norm support (#1894)
* adjust norm unit test tolerance
2025-03-04 13:26:02 -08:00
Angelos Katharopoulos
9680f72cca Add a multi optimizer (#1916) 2025-03-04 13:16:35 -08:00
Angelos Katharopoulos
a0737273d3 Allow debugging in distributed mode (#1920) 2025-03-04 13:01:10 -08:00
Awni Hannun
e613d0eaf0 SDPA support for small batch (over sequence) queries (#1922)
* batch query sdpa

* batch sdpa for query
2025-03-04 10:59:04 -08:00
Awni Hannun
6bcd6bcf70 fix donation in scan (#1917) 2025-03-03 11:30:59 -08:00
Awni Hannun
ba12e4999a Use a heap for small sizes (#1911)
* use a heap for small sizes

* check if VM
2025-03-03 06:50:57 -08:00
Awni Hannun
4e7cd31d12 Fix slice data size (#1913)
* fix slice data size

* add test
2025-03-02 21:50:42 -08:00
Angelos Katharopoulos
5e6c130d93 RMS norm without scaling (#1915) 2025-02-28 20:26:57 -08:00
Angelos Katharopoulos
5d68082881 Ring docs (#1829) 2025-02-28 11:34:21 -08:00
Angelos Katharopoulos
607181644f Add mlx.distributed_config script (#1902) 2025-02-28 11:16:39 -08:00
Jagrit Digani
89d327075f Enabling fused attention for head dim 128 (#1899)
* Share KV smem

* Fix bfloat error

* Unroll O = S @ V loop

* Perf upgrade

* Remove commented out function

* Add -Wno-c++17-extensions flag to metal flags

* Add -Wno-c++17-extensions flag to metal extension flags
2025-02-26 10:02:06 -08:00
Angelos Katharopoulos
6bf00ef631 Fix ring of 2 and allow scalars in API (#1906) 2025-02-25 17:03:01 -08:00
Awni Hannun
7d042f17fe Double for lapack (#1904)
* double for lapack ops

* add double support for lapack ops
2025-02-25 11:39:36 -08:00
Awni Hannun
28b8079e30 fix double type promotion (#1901) 2025-02-25 06:00:53 -08:00
Awni Hannun
7face5d9fd fix cpu compile (#1897) 2025-02-24 14:10:30 -08:00
Awni Hannun
a44dc4bdb0 fix leaking objc (#1898) 2025-02-24 13:57:59 -08:00
Awni Hannun
2d0f384b6f fix simd erf_inv (#1896) 2025-02-24 13:57:47 -08:00
Awni Hannun
8ff84b5c43 fix version and expose command queue getter (#1892) 2025-02-20 15:25:15 -08:00
Angelos Katharopoulos
10b271d963 Ring update (#1885) 2025-02-20 14:32:31 -08:00
Jesper Stemann Andersen
0ebc8a3d25 Fixed issue where Clang on FreeBSD failed to compile mlx/backend/cpu/quantized.cpp (#1890) 2025-02-20 12:02:12 -08:00
Awni Hannun
bbda0fdbdb Allow non-square lu (#1889) 2025-02-20 08:13:23 -08:00
Jesper Stemann Andersen
c86422bdd4 Added mlx::core::version() returning std::string(MLX_VERSION) (#1819)
* Added version.h providing mlx::core::version() returning std::string(MLX_VERSION)

Also, added MLX_VERSION_MAJOR, MLX_VERSION_MINOR, MLX_VERSION_PATCH, MLX_VERSION_NUMERIC, and accompanying functions.

* Added version.h to mlx.h

* Changed version int functions to be constexpr

* Formatting

* Added handling of MLX_VERSION where only the prefix has major.minor.patch format

* Changed version function to be constexpr
2025-02-19 20:30:19 -08:00
Awni Hannun
c707b2b0a6 Limit compile buffers (#1887)
* limit compile buffers

* maybe not flaky test
2025-02-19 20:28:13 -08:00
Angelos Katharopoulos
78ba24c37d Raise an exception in the rope op if input is integer (#1884) 2025-02-19 14:43:39 -08:00
Angelos Katharopoulos
1a2cb72030 Ensure linspace always contains start and stop (#1883) 2025-02-19 13:53:20 -08:00
Abe Leininger
344a29506e Enforce triangular matrix form in tri_inv (#1876)
* fix tri_inv bug

* Revert "fix tri_inv bug"

This reverts commit b74b290201.

* Make sure that tri_inv returns a triangular matrix

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2025-02-19 12:42:33 -08:00
Angelos Katharopoulos
71de73a668 Fix convs by reverting #1803 (#1882) 2025-02-18 14:36:34 -08:00
570 changed files with 50031 additions and 13231 deletions

View File

@@ -7,15 +7,9 @@ parameters:
nightly_build:
type: boolean
default: false
weekly_build:
type: boolean
default: false
test_release:
type: boolean
default: false
linux_release:
type: boolean
default: false
jobs:
build_documentation:
@@ -24,13 +18,14 @@ jobs:
type: boolean
default: false
macos:
xcode: "15.2.0"
resource_class: macos.m1.medium.gen1
xcode: "26.0.0"
resource_class: m4pro.medium
steps:
- checkout
- run:
name: Install
command: |
xcodebuild -downloadComponent MetalToolchain
brew install python@3.9
brew install doxygen
python3.9 -m venv env
@@ -38,7 +33,7 @@ jobs:
pip install --upgrade pip
pip install --upgrade cmake
pip install -r docs/requirements.txt
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
pip install . -v
- when:
condition:
not: << parameters.upload-docs >>
@@ -70,9 +65,9 @@ jobs:
git push -f origin gh-pages
linux_build_and_test:
docker:
- image: cimg/python:3.9
machine:
image: ubuntu-2204:current
resource_class: large
steps:
- checkout
- run:
@@ -84,34 +79,37 @@ jobs:
- run:
name: Install dependencies
command: |
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install numpy
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
curl -LsSf https://astral.sh/uv/install.sh | sh
- run:
name: Install Python package
command: |
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py develop
uv venv
uv pip install cmake
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
uv pip install -e ".[dev]" -v
- run:
name: Generate package stubs
command: |
echo "stubs"
pip install typing_extensions
python setup.py generate_stubs
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
- run:
name: Run Python tests
command: |
python3 -m unittest discover python/tests -v
source .venv/bin/activate
python -m unittest discover python/tests -v
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- run:
name: Build CPP only
command: |
mkdir -p build && cd build
source .venv/bin/activate
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j `nproc`
- run:
@@ -122,58 +120,64 @@ jobs:
parameters:
xcode_version:
type: string
default: "15.2.0"
default: "26.0.0"
macosx_deployment_target:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: macos.m1.medium.gen1
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
resource_class: m4pro.medium
steps:
- checkout
- run:
name: Install dependencies
command: |
brew install python@3.9
brew install openmpi
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install numpy
pip install torch
pip install tensorflow
pip install unittest-xml-reporting
xcodebuild -downloadComponent MetalToolchain
HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
brew install openmpi uv
- run:
name: Install Python package
command: |
source env/bin/activate
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v
uv venv --python 3.9
uv pip install \
nanobind==2.4.0 \
cmake \
numpy \
torch \
tensorflow \
unittest-xml-reporting
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
uv pip install -e . -v
- run:
name: Generate package stubs
command: |
source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
- run:
name: Run Python tests
command: |
source env/bin/activate
source .venv/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- run:
name: Build example extension
command: |
source env/bin/activate
source .venv/bin/activate
cd examples/extensions
pip install -r requirements.txt
python setup.py build_ext -j8
uv pip install -r requirements.txt
uv run --no-project setup.py build_ext --inplace
uv run --no-project python test.py
- store_test_results:
path: test-results
- run:
name: Build CPP only
command: |
source env/bin/activate
source .venv/bin/activate
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
- run:
name: Run CPP tests
@@ -182,7 +186,7 @@ jobs:
- run:
name: Build small binary
command: |
source env/bin/activate
source .venv/bin/activate
cd build/
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
@@ -194,13 +198,74 @@ jobs:
- run:
name: Run Python tests with JIT
command: |
source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
pip install -e . -v
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
uv pip install -e . -v
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
METAL_DEBUG_ERROR_MODE=0 \
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
uv run --no-project python -m xmlrunner discover \
-v python/tests \
-o test-results/gpu_jit
cuda_build_and_test:
parameters:
image_date:
type: string
default: "2023.11.1"
machine:
image: "linux-cuda-12:<< parameters.image_date >>"
resource_class: gpu.nvidia.small.gen2
steps:
- checkout
- restore_cache:
keys:
- cuda-<< parameters.image_date >>-{{ arch }}-
- run:
name: Install dependencies
command: |
sudo apt-get update
sudo apt-get install libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install libnccl2 libnccl-dev
curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
rm -rf ccache-4.11.3-linux-x86_64
curl -LsSf https://astral.sh/uv/install.sh | sh
- run:
name: Install Python package
command: |
uv venv
uv pip install cmake
DEBUG=1 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
uv pip install -e ".[dev]" -v
- run:
name: Run Python tests
command: |
source .venv/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
- run:
name: Build CPP only
command: |
source .venv/bin/activate
cmake . -B build \
-DMLX_BUILD_CUDA=ON \
-DCMAKE_CUDA_COMPILER=`which nvcc` \
-DCMAKE_BUILD_TYPE=DEBUG
cmake --build build -j `nproc`
- run:
name: Run CPP tests
command: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
- run:
name: CCache report
command: |
ccache --show-stats
ccache --zero-stats
ccache --max-size 400MB
ccache --cleanup
- save_cache:
key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
paths:
- /home/circleci/.cache/ccache
build_release:
parameters:
@@ -209,23 +274,32 @@ jobs:
default: "3.9"
xcode_version:
type: string
default: "15.2.0"
default: "26.0.0"
build_env:
type: string
default: ""
macosx_deployment_target:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: macos.m1.medium.gen1
resource_class: m4pro.medium
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
steps:
- checkout
- run:
name: Install dependencies
command: |
brew install python@<< parameters.python_version >>
brew install openmpi
python<< parameters.python_version >> -m venv env
source env/bin/activate
pip install --upgrade pip
xcodebuild -downloadComponent MetalToolchain
mkdir -p ~/miniconda3
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm ~/miniconda3/miniconda.sh
source ~/miniconda3/bin/activate
conda init --all
conda create -n env python=<< parameters.python_version >> -y
conda activate env
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install --upgrade setuptools
@@ -235,30 +309,38 @@ jobs:
- run:
name: Install Python package
command: |
source env/bin/activate
DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
conda activate env
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
pip install . -v
- run:
name: Generate package stubs
command: |
source env/bin/activate
conda activate env
pip install typing_extensions
python setup.py generate_stubs
python setup.py generate_stubs
- run:
name: Build Python package
command: |
source env/bin/activate
<< parameters.build_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
python -m build -w
conda activate env
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
- when:
condition:
equal: ["3.9", << parameters.python_version >>]
steps:
- run:
name: Build common package
command: |
conda activate env
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload package
command: |
source env/bin/activate
conda activate env
twine upload dist/*
- store_artifacts:
path: dist/
@@ -268,52 +350,100 @@ jobs:
python_version:
type: string
default: "3.9"
extra_env:
build_env:
type: string
default: "DEV_RELEASE=1"
docker:
- image: ubuntu:20.04
default: ""
machine:
image: ubuntu-2204:current
resource_class: large
steps:
- checkout
- run:
name: Build wheel
command: |
PYTHON=python<< parameters.python_version >>
apt-get update
apt-get upgrade -y
DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
apt-get install -y apt-utils
apt-get install -y software-properties-common
add-apt-repository -y ppa:deadsnakes/ppa
apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
apt-get install -y libblas-dev liblapack-dev liblapacke-dev
apt-get install -y build-essential git
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
sudo apt-get update
TZ=Etc/UTC sudo apt-get -y install tzdata
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
$PYTHON -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install --upgrade setuptools
pip install numpy
pip install auditwheel
pip install patchelf
pip install build
pip install twine
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
pip install . -v
<< parameters.build_env >> pip install ".[dev]" -v
pip install typing_extensions
python setup.py generate_stubs
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python -m build --wheel
auditwheel show dist/*
auditwheel repair dist/* --plat manylinux_2_31_x86_64
python setup.py generate_stubs
python setup.py clean --all
MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
bash python/scripts/repair_linux.sh
- when:
condition:
equal: ["3.9", << parameters.python_version >>]
steps:
- run:
name: Build common package
command: |
source env/bin/activate
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
python -m build -w
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload packages
command: |
source env/bin/activate
twine upload wheelhouse/*.whl
- store_artifacts:
path: wheelhouse/
build_cuda_release:
parameters:
build_env:
type: string
default: ""
machine:
image: ubuntu-2204:current
resource_class: xlarge
steps:
- checkout
- run:
name: Upload package
name: Build wheel
command: |
source env/bin/activate
twine upload wheelhouse/*
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install zip
pip install auditwheel
pip install patchelf
pip install build
pip install twine
export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
python -m build -w
bash python/scripts/repair_cuda.sh
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload package
command: |
twine upload wheelhouse/*.whl
- store_artifacts:
path: wheelhouse/
@@ -325,21 +455,23 @@ workflows:
pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >>
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- mac_build_and_test:
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
macosx_deployment_target: ["13.5", "15.0"]
- linux_build_and_test
- cuda_build_and_test:
matrix:
parameters:
image_date: ["2023.11.1", "2025.05.1"]
- build_documentation
build_pypi_release:
when:
and:
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- build_release:
@@ -351,8 +483,9 @@ workflows:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["PYPI_RELEASE=1"]
xcode_version: ["26.0.0"]
- build_documentation:
filters:
tags:
@@ -360,6 +493,25 @@ workflows:
branches:
ignore: /.*/
upload-docs: true
- build_linux_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
build_env: ["PYPI_RELEASE=1"]
- build_cuda_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
build_env: ["PYPI_RELEASE=1"]
prb:
when:
@@ -375,9 +527,14 @@ workflows:
requires: [ hold ]
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
macosx_deployment_target: ["13.5", "15.0"]
- linux_build_and_test:
requires: [ hold ]
- cuda_build_and_test:
requires: [ hold ]
matrix:
parameters:
image_date: ["2023.11.1", "2025.05.1"]
nightly_build:
when:
and:
@@ -388,27 +545,33 @@ workflows:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0"]
weekly_build:
macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["26.0.0"]
- build_linux_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
- build_cuda_release
build_dev_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.weekly_build >>
- << pipeline.parameters.test_release >>
jobs:
- build_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"]
linux_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.linux_release >>
jobs:
xcode_version: ["26.0.0"]
- build_linux_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
extra_env: ["PYPI_RELEASE=1"]
build_env: ["DEV_RELEASE=1"]
- build_cuda_release:
matrix:
parameters:
build_env: ["DEV_RELEASE=1"]

1
.gitignore vendored
View File

@@ -36,6 +36,7 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST
uv.lock
# vim
*.swp

View File

@@ -19,11 +19,17 @@ MLX was developed with contributions from the following individuals:
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
</a>
# Organizations
MLX has received contributions from the following companies:
- NVIDIA Corporation & Affiliates
# Third-Party Software
MLX leverages several third-party software, listed here together with

View File

@@ -1,6 +1,24 @@
cmake_minimum_required(VERSION 3.25)
project(mlx LANGUAGES C CXX)
if(NOT MLX_VERSION)
file(STRINGS "mlx/version.h" _mlx_h_version REGEX "^#define MLX_VERSION_.*$")
string(REGEX MATCH "#define MLX_VERSION_MAJOR ([0-9]+)" _ "${_mlx_h_version}")
set(_major ${CMAKE_MATCH_1})
string(REGEX MATCH "#define MLX_VERSION_MINOR ([0-9]+)" _ "${_mlx_h_version}")
set(_minor ${CMAKE_MATCH_1})
string(REGEX MATCH "#define MLX_VERSION_PATCH ([0-9]+)" _ "${_mlx_h_version}")
set(_patch ${CMAKE_MATCH_1})
set(MLX_PROJECT_VERSION "${_major}.${_minor}.${_patch}")
set(MLX_VERSION ${MLX_PROJECT_VERSION})
else()
string(REGEX REPLACE "^([0-9]+\.[0-9]+\.[0-9]+).*" "\\1" MLX_PROJECT_VERSION
${MLX_VERSION})
endif()
project(
mlx
LANGUAGES C CXX
VERSION ${MLX_PROJECT_VERSION})
# ----------------------------- Setup -----------------------------
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
@@ -16,18 +34,16 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.23.0)
endif()
add_compile_definitions("MLX_VERSION=${MLX_VERSION}")
option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
# --------------------- Processor tests -------------------------
message(
@@ -50,10 +66,17 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
message(WARNING "Building for x86_64 arch is not officially supported.")
endif()
endif()
else()
set(MLX_BUILD_METAL OFF)
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
endif()
if(MLX_USE_CCACHE)
find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
endif()
endif()
# ----------------------------- Lib -----------------------------
@@ -63,7 +86,6 @@ include(FetchContent)
cmake_policy(SET CMP0135 NEW)
add_library(mlx)
set_target_properties(mlx PROPERTIES COMPILE_WARNING_AS_ERROR ON)
if(MLX_BUILD_METAL)
set(METAL_LIB "-framework Metal")
@@ -71,6 +93,10 @@ if(MLX_BUILD_METAL)
set(QUARTZ_LIB "-framework QuartzCore")
endif()
if(MLX_BUILD_CUDA)
enable_language(CUDA)
endif()
if(MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF)
@@ -114,6 +140,12 @@ elseif(MLX_BUILD_METAL)
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
endif()
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
# With newer clang/gcc versions following libs are implicitly linked, but when
# building on old distributions they need to be explicitly listed.
target_link_libraries(mlx PRIVATE dl pthread)
endif()
if(WIN32)
if(MSVC)
# GGUF does not build with MSVC.
@@ -200,23 +232,13 @@ else()
set(MLX_BUILD_ACCELERATE OFF)
endif()
find_package(MPI)
if(MPI_FOUND)
execute_process(
COMMAND zsh "-c" "mpirun --version"
OUTPUT_VARIABLE MPI_VERSION
ERROR_QUIET)
if(${MPI_VERSION} MATCHES ".*Open MPI.*")
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
elseif(MPI_VERSION STREQUAL "")
set(MPI_FOUND FALSE)
message(
WARNING "MPI found but mpirun is not available. Building without MPI.")
else()
set(MPI_FOUND FALSE)
message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
endif()
endif()
message(STATUS "Downloading json")
FetchContent_Declare(
json
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
FetchContent_MakeAvailable(json)
target_include_directories(
mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
@@ -224,12 +246,19 @@ target_include_directories(
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>)
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt)
# Do not add mlx_EXPORTS define for shared library.
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
if(USE_SYSTEM_FMT)
find_package(fmt REQUIRED)
else()
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt)
endif()
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
if(MLX_BUILD_PYTHON_BINDINGS)

View File

@@ -5,26 +5,26 @@ possible.
## Pull Requests
1. Fork and submit pull requests to the repo.
1. Fork and submit pull requests to the repo.
2. If you've added code that should be tested, add tests.
3. If a change is likely to impact efficiency, run some of the benchmarks before
and after the change. Examples of benchmarks can be found in `benchmarks/python/`.
4. If you've changed APIs, update the documentation.
5. Every PR should have passing tests and at least one review.
5. Every PR should have passing tests and at least one review.
6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
This should install hooks for running `black` and `clang-format` to ensure
consistent style for C++ and python code.
You can also run the formatters manually as follows:
```
clang-format -i file.cpp
```
```
black file.py
```
```shell
clang-format -i file.cpp
```
```shell
black file.py
```
or run `pre-commit run --all-files` to check all files in the repo.
## Issues

View File

@@ -1,4 +1,6 @@
include CMakeLists.txt
include mlx.pc.in
recursive-include mlx/ *
include cmake/*
include python/src/*
include python/mlx/py.typed # support type hinting as in PEP-561

View File

@@ -11,10 +11,10 @@ brought to you by Apple machine learning research.
Some key features of MLX include:
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
the Python API. MLX has higher-level packages like `mlx.nn` and
the Python API. MLX has higher-level packages like `mlx.nn` and
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building
more complex models.
@@ -68,18 +68,23 @@ in the documentation.
## Installation
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on
macOS, run:
**With `pip`**:
```
```bash
pip install mlx
```
**With `conda`**:
To install the CUDA backend on Linux, run:
```bash
pip install mlx[cuda]
```
conda install -c conda-forge mlx
To install a CPU-only Linux package, run:
```bash
pip install mlx[cpu]
```
Checkout the

View File

@@ -1,5 +1,6 @@
// Copyright © 2023 Apple Inc.
#include <cstring>
#include <iostream>
#include <sstream>

View File

@@ -192,6 +192,22 @@ void time_reductions() {
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
TIME(argmin_along_1);
auto indices = mx::array({1});
auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});
std::vector<int> axes{0};
auto b = scatter(a, {indices}, updates, axes);
mx::eval(b);
auto max_along_0 = [&b]() { return mx::max(b, 0, false); };
TIME(max_along_0);
auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
TIME(max_along_1);
auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
TIME(min_along_0);
auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
TIME(min_along_1);
}
void time_gather_scatter() {

View File

@@ -5,6 +5,7 @@ import os
import time
import torch
import torch.cuda
import torch.mps
@@ -44,8 +45,10 @@ def bench(f, *args):
def sync_if_needed(x):
if x.device != torch.device("cpu"):
if x.device == torch.device("mps"):
torch.mps.synchronize()
elif x.device == torch.device("cuda"):
torch.cuda.synchronize()
@torch.no_grad()
@@ -99,6 +102,14 @@ def reduction(op, axis, x):
sync_if_needed(x)
@torch.no_grad()
def sum_and_add(axis, x, y):
z = x.sum(axis=axis, keepdims=True)
for i in range(50):
z = (z + y).sum(axis=axis, keepdims=True)
sync_if_needed(x)
@torch.no_grad()
def softmax(axis, x):
ys = []
@@ -340,7 +351,11 @@ if __name__ == "__main__":
args.axis.pop(0)
torch.set_num_threads(1)
device = "cpu" if args.cpu else "mps"
device = "mps"
if torch.cuda.is_available():
device = "cuda"
if args.cpu:
device = "cpu"
types = args.dtype
if not types:
@@ -460,5 +475,8 @@ if __name__ == "__main__":
elif args.benchmark == "selu":
print(bench(selu, x))
elif args.benchmark == "sum_and_add":
print(bench(sum_and_add, axis, *xs))
else:
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")

View File

@@ -0,0 +1,107 @@
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_2D(strides, padding, groups)
f_pt = make_pt_conv_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
dtype = "float32"
shapes = (
(4, 32, 32, 21, 3, 3, 128),
(4, 32, 32, 21, 3, 3, 37),
(4, 32, 32, 370, 3, 3, 370),
(4, 32, 32, 370, 7, 7, 128),
(2, 320, 640, 21, 7, 7, 21),
)
for N, H, W, C, kh, kw, O in shapes:
time_mlx, time_torch = bench_shape(
N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -1,7 +1,6 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
from time import time
import mlx.core as mx
import torch

View File

@@ -0,0 +1,74 @@
# Copyright © 2025 Apple Inc.
import mlx.core as mx
from time_utils import time_fn
N = 1024
D = 1024
M = 1024
E = 32
I = 4
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
def gather_mm_simulate(x, w, indices):
x, idx, inv_order = gather_sort(x, indices)
for i in range(2):
y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
x = y[:, None]
x = scatter_unsort(x, inv_order, indices.shape)
return x
def time_gather_mm():
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
w1 = mx.random.normal((E, M, D)) / 1024**0.5
w2 = mx.random.normal((E, D, M)) / 1024**0.5
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
mx.eval(x, w1, w2, indices, sorted_indices)
def gather_mm(x, w1, w2, indices, sort):
idx = indices
inv_order = None
if sort:
x, idx, inv_order = gather_sort(x, indices)
x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
if sort:
x = scatter_unsort(x, inv_order, indices.shape)
return x
time_fn(gather_mm, x, w1, w2, indices, False)
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
time_fn(gather_mm, x, w1, w2, indices, True)
x = mx.random.normal((N * I, D)) / 1024**0.5
w1 = mx.random.normal((M, D)) / 1024**0.5
w2 = mx.random.normal((D, M)) / 1024**0.5
mx.eval(x, w1, w2)
def equivalent_matmul(x, w1, w2):
x = x @ w1.T
x = x @ w2.T
return x
time_fn(equivalent_matmul, x, w1, w2)
if __name__ == "__main__":
time_gather_mm()

View File

@@ -0,0 +1,84 @@
# Copyright © 2025 Apple Inc.
import mlx.core as mx
from time_utils import time_fn
N = 1024
D = 1024
M = 1024
E = 32
I = 4
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
def gather_mm_simulate(x, w, indices):
x, idx, inv_order = gather_sort(x, indices)
for i in range(2):
y = mx.concatenate(
[
mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
for i, j in enumerate(idx.tolist())
],
axis=0,
)
x = y[:, None]
x = scatter_unsort(x, inv_order, indices.shape)
return x
def time_gather_qmm():
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
w1 = mx.random.normal((E, M, D)) / 1024**0.5
w2 = mx.random.normal((E, D, M)) / 1024**0.5
w1 = mx.quantize(w1)
w2 = mx.quantize(w2)
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
mx.eval(x, w1, w2, indices, sorted_indices)
def gather_mm(x, w1, w2, indices, sort):
idx = indices
inv_order = None
if sort:
x, idx, inv_order = gather_sort(x, indices)
x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
if sort:
x = scatter_unsort(x, inv_order, indices.shape)
return x
time_fn(gather_mm, x, w1, w2, indices, False)
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
time_fn(gather_mm, x, w1, w2, indices, True)
x = mx.random.normal((N * I, D)) / 1024**0.5
w1 = mx.random.normal((M, D)) / 1024**0.5
w2 = mx.random.normal((D, M)) / 1024**0.5
w1 = mx.quantize(w1)
w2 = mx.quantize(w2)
mx.eval(x, w1, w2)
def equivalent_matmul(x, w1, w2):
x = mx.quantized_matmul(x, *w1, transpose=True)
x = mx.quantized_matmul(x, *w2, transpose=True)
return x
time_fn(equivalent_matmul, x, w1, w2)
if __name__ == "__main__":
time_gather_qmm()

View File

@@ -1,5 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from functools import partial
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
@@ -10,32 +12,71 @@ def layer_norm(x, w, b, eps):
x = x.astype(mx.float32)
mu = mx.mean(x, -1, keepdims=True)
v = mx.var(x, -1, keepdims=True)
return (x - mu) * mx.rsqrt(v + eps) * w + b
y = (x - mu) * mx.rsqrt(v + eps)
if w is not None:
y = y * w
if b is not None:
y = y + b
return y
def time_layer_norm():
def time_layer_norm(N, dt):
L = 1024
f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0, 1, 2))
g2 = mx.grad(f2, argnums=(0, 1, 2))
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
w = mx.random.uniform(shape=(N,)).astype(dt)
b = mx.random.uniform(shape=(N,)).astype(dt)
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
mx.eval(x, w, b, y)
def layer_norm_loop(g, x, w, b):
def layer_norm_loop(f, x, w, b):
for _ in range(32):
x = f(x, w, b)
return x
time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b)
time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b)
def layer_norm_grad_loop(g, x, w, b):
gx, gw, gb = x, w, b
for _ in range(32):
gx, gw, gb = g(gx, gw, gb, y)
return gx, gw, gb
time_fn(layer_norm_loop, g1, x, w, b)
time_fn(layer_norm_loop, g2, x, w, b)
time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
time_fn(layer_norm_grad_loop, g1, x, w, b)
time_fn(layer_norm_grad_loop, g2, x, w, b)
time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b)
time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b)
f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0,))
g2 = mx.grad(f2, argnums=(0,))
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
w = mx.random.uniform(shape=(N,)).astype(dt)
b = mx.random.uniform(shape=(N,)).astype(dt)
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
mx.eval(x, w, b, y)
def layer_norm_grad_x_loop(g, x):
gx = x
for _ in range(32):
gx = g(gx, y)
return gx
time_fn(layer_norm_grad_x_loop, g1, x)
time_fn(layer_norm_grad_x_loop, g2, x)
time_fn(layer_norm_grad_x_loop, mx.compile(g1), x)
time_fn(layer_norm_grad_x_loop, mx.compile(g2), x)
if __name__ == "__main__":
time_layer_norm()
for dt in [mx.float32, mx.float16, mx.bfloat16]:
for n in [1024, 2048, 4096, 8192, 8192 + 1024]:
print(dt, n)
time_layer_norm(n, dt)

View File

@@ -9,7 +9,10 @@ def rms_norm(x, w, eps):
ot = x.dtype
x = x.astype(mx.float32)
n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
return (x * n).astype(ot) * w
y = (x * n).astype(ot)
if w is not None:
y = y * w
return y
def time_rms_norm():
@@ -34,6 +37,27 @@ def time_rms_norm():
time_fn(rms_norm_loop, mx.compile(g1), x, w)
time_fn(rms_norm_loop, mx.compile(g2), x, w)
f1 = lambda x, y: (rms_norm(x, None, 1e-5) * y).sum()
f2 = lambda x, y: (mx.fast.rms_norm(x, None, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0,))
g2 = mx.grad(f2, argnums=(0,))
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
mx.eval(x, w, y)
def rms_norm_loop(g, x):
gx = x
for _ in range(32):
gx = g(gx, y)
return gx
time_fn(rms_norm_loop, g1, x)
time_fn(rms_norm_loop, g2, x)
time_fn(rms_norm_loop, mx.compile(g1), x)
time_fn(rms_norm_loop, mx.compile(g2), x)
if __name__ == "__main__":
time_rms_norm()

View File

@@ -28,11 +28,34 @@ def bench(f, *args):
return (e - s) * 1e-9
def mlx_sdpa_fused_inner(q, k, v, scale):
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
np_dtype = getattr(np, dtype)
shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)
shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)
scale = 1.0 / math.sqrt(D)
q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
q_mx = mx.array(q_np)
k_mx = mx.array(k_np)
v_mx = mx.array(v_np)
if mask is not None:
if mask == "additive":
mask_np = np.random.normal(0.0, 1.0, (B, qH, qL, kL)).astype(np_dtype)
mask = mx.array(mask_np)
elif mask == "bool":
mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5
mask = mx.array(mask_np)
return q_mx, k_mx, v_mx, scale, mask
def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
q_dtype = q.dtype
q = q * mx.array(scale, q_dtype)
n_q_heads = q.shape[-3]
@@ -41,6 +64,7 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
B = q.shape[0]
L = q.shape[2]
kL = k.shape[2]
if n_repeats > 1:
q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
@@ -48,10 +72,27 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
v = mx.expand_dims(v, 2)
scores = q @ mx.swapaxes(k, -1, -2)
if f32softmax:
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype)
else:
scores = mx.softmax(scores, axis=-1)
if mask is not None:
if mask == "causal":
q_offset = max(0, kL - L)
q_indices = mx.arange(q_offset, q_offset + L)
k_indices = mx.arange(kL)
mask = q_indices[:, None] >= k_indices[None]
if n_repeats > 1 and mask.ndim >= 3:
if mask.shape[-3] == 1:
mask = mx.expand_dims(mask, -3)
else:
mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))
if mask.dtype == mx.bool_:
scores = mx.where(mask, scores, -np.float32(np.inf))
else:
scores += mask
scores = mx.softmax(scores, axis=-1, precise=True)
out = scores @ v
if n_repeats > 1:
@@ -60,74 +101,55 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
return out
def mlx_spda_unfused(q, k, v, scale, transpose):
q_out = q
def mlx_fused_attn(q, k, v, scale, mask):
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
def do_attention(f, q, k, v, scale, mask=None, transpose=False):
if transpose:
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
q_t = mx.transpose(q, (0, 2, 1, 3))
k_t = mx.transpose(k, (0, 2, 1, 3))
v_t = mx.transpose(v, (0, 2, 1, 3))
o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)
return mx.transpose(o_t, (0, 2, 1, 3))
else:
return f(q, k, v, scale=scale, mask=mask)
def do_attention_bench(f, q, k, v, scale, mask=None, transpose=False):
q_out = q
for i in range(N_iter_func):
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = do_attention(f, q_out, k, v, scale, mask=mask, transpose=transpose)
mx.eval(q_out)
return q_out
def mlx_spda_fused(q, k, v, scale, transpose):
q_out = q
if transpose:
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
for i in range(N_iter_func):
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
mx.eval(q_out)
return q_out
def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
shape_q = (
(B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
)
shape_kv = (
(B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
def bench_shape(
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, dtype, transpose=True, mask_in=None
):
q_mx, k_mx, v_mx, scale, mask = prepare_inputs(
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, mask_in, transpose, dtype
)
q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype)
k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
time_mlx_unfused = bench(
do_attention_bench, mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose
)
time_mlx_fused = bench(
do_attention_bench, mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
)
scale = math.sqrt(1.0 / head_dim)
o_mlx_fused = do_attention(mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose)
o_mlx_unfused = do_attention(
mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
)
q_mx = mx.array(q_np)
k_mx = mx.array(k_np)
v_mx = mx.array(v_np)
atol = 1e-5 if dtype == "float32" else 2e-4
time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose)
time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose)
if transpose:
q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
v_mx = mx.transpose(v_mx, (0, 2, 1, 3))
o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)
atol = 1e-5 if np_dtype == np.float32 else 1e-4
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol, rtol=atol):
print(
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}, mask: {mask_in}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
)
return time_mlx_fused, time_mlx_unfused
@@ -151,39 +173,51 @@ if __name__ == "__main__":
( 1, 128, 128, 64, 32, 32),
( 1, 256, 256, 64, 32, 32),
( 1, 512, 512, 64, 32, 32),
( 1, 1024, 1024, 64, 32, 32),
( 1, 2048, 2048, 64, 32, 32),
( 1, 4096, 4096, 64, 32, 32),
( 1, 1024, 1024, 64, 32, 8),
( 1, 2048, 2048, 64, 32, 8),
( 1, 4096, 4096, 64, 32, 8),
)
shapes_80 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 80, 32, 32),
( 1, 2048, 2048, 80, 32, 32),
( 1, 4096, 4096, 80, 32, 32),
( 1, 1024, 1024, 80, 32, 8),
( 1, 2048, 2048, 80, 32, 8),
( 1, 4096, 4096, 80, 32, 8),
)
shapes_128 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 128, 32, 32),
( 1, 2048, 2048, 128, 32, 32),
( 1, 4096, 4096, 128, 32, 32),
( 1, 1024, 1024, 128, 32, 8),
( 1, 2048, 2048, 128, 32, 8),
( 1, 4096, 4096, 128, 32, 8),
)
# fmt: on
shapes = shapes_64 + shapes_80 + shapes_128
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")
masks = [None, "bool", "causal"]
print(
" B, qsl, ksl, hdim, n_qh, n_kvh, t, dtype, mask, t_unfs, t_fuse, diff%"
)
for dtype in dtypes:
for transpose in transposes:
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
np_dtype = getattr(np, dtype)
time_mlx_fused, time_mlx_unfused = bench_shape(
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose
)
diff = time_mlx_unfused / time_mlx_fused - 1.0
t_str = 1 if transpose else 0
print(
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
)
for mask_in in masks:
time_mlx_fused, time_mlx_unfused = bench_shape(
B,
qsl,
ksl,
head_dim,
n_q_heads,
n_kv_heads,
dtype,
transpose,
mask_in,
)
diff = time_mlx_unfused / time_mlx_fused - 1.0
t_str = 1 if transpose else 0
print(
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:1d}, {dtype}, {str(mask_in):>8}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
)

View File

@@ -51,6 +51,20 @@ def time_maximum():
time_fn(mx.maximum, a, b)
def time_max():
a = mx.random.uniform(shape=(32, 1024, 1024))
a[1, 1] = mx.nan
mx.eval(a)
time_fn(mx.max, a, 0)
def time_min():
a = mx.random.uniform(shape=(32, 1024, 1024))
a[1, 1] = mx.nan
mx.eval(a)
time_fn(mx.min, a, 0)
def time_negative():
a = mx.random.uniform(shape=(10000, 1000))
mx.eval(a)
@@ -108,6 +122,8 @@ if __name__ == "__main__":
time_add()
time_matmul()
time_min()
time_max()
time_maximum()
time_exp()
time_negative()

54
cmake/FindNCCL.cmake Normal file
View File

@@ -0,0 +1,54 @@
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
# directories.
set(NCCL_ROOT_DIR
$ENV{NCCL_ROOT_DIR}
CACHE PATH "Folder contains NVIDIA NCCL")
find_path(
NCCL_INCLUDE_DIRS
NAMES nccl.h
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
${CUDA_TOOLKIT_ROOT_DIR}/include)
if($ENV{USE_STATIC_NCCL})
message(
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
set(NCCL_LIBNAME "libnccl_static.a")
else()
set(NCCL_LIBNAME "nccl")
endif()
find_library(
NCCL_LIBRARIES
NAMES ${NCCL_LIBNAME}
HINTS ${NCCL_LIB_DIR}
${NCCL_ROOT_DIR}
${NCCL_ROOT_DIR}/lib
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
${NCCL_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
NCCL_LIBRARIES)
if(NCCL_FOUND)
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
message(
STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
file(
STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
LIMIT_COUNT 1)
if(NCCL_MAJOR_VERSION_DEFINED)
string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
endif()
message(
STATUS
"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
endif()

View File

@@ -1,5 +1,7 @@
include(CMakeParseArguments)
# clang format off
#
# ##############################################################################
# Build metal library
#
@@ -9,11 +11,14 @@ include(CMakeParseArguments)
# Args: TARGET: Custom target to be added for the metal library TITLE: Name of
# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
# files (like headers)
# files (like headers) DEBUG: Boolean, if true, enables debug compile options
# for this specific library. If not provided, uses global MLX_METAL_DEBUG.
#
# clang format on
macro(mlx_build_metallib)
# Parse args
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG)
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
@@ -21,7 +26,11 @@ macro(mlx_build_metallib)
set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
# Collect compile options
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
if(MLX_METAL_DEBUG OR MTLLIB_DEBUG)
set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only
-frecord-sources)
endif()
# Prepare metallib build command
add_custom_command(

View File

@@ -13,7 +13,7 @@ EXCLUDE_PATTERNS = */private/*
CREATE_SUBDIRS = NO
FULL_PATH_NAMES = YES
RECURSIVE = YES
GENERATE_HTML = YES
GENERATE_HTML = NO
GENERATE_LATEX = NO
GENERATE_XML = YES
XML_PROGRAMLISTING = YES

View File

@@ -1,4 +1,5 @@
sphinx
breathe
sphinx-book-theme
sphinx-copybutton
mlx

View File

@@ -10,7 +10,7 @@ import mlx.core as mx
# -- Project information -----------------------------------------------------
project = "MLX"
copyright = "2023, MLX Contributors"
copyright = "2023, Apple"
author = "MLX Contributors"
version = ".".join(mx.__version__.split(".")[:3])
release = version
@@ -18,6 +18,7 @@ release = version
# -- General configuration ---------------------------------------------------
extensions = [
"sphinx_copybutton",
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.intersphinx",

View File

@@ -8,23 +8,26 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs.
Simple Example
--------------
.. currentmodule:: mlx.core
Let's write a custom kernel that computes ``exp`` elementwise:
.. code-block:: python
def exp_elementwise(a: mx.array):
source = """
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::exp(tmp);
"""
source = """
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp",
input_names=["inp"],
output_names=["out"],
source=source,
)
kernel = mx.fast.metal_kernel(
name="myexp",
input_names=["inp"],
output_names=["out"],
source=source,
)
def exp_elementwise(a: mx.array):
outputs = kernel(
inputs=[a],
template=[("T", mx.float32)],
@@ -39,8 +42,13 @@ Let's write a custom kernel that computes ``exp`` elementwise:
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
Every time you make a kernel, a new Metal library is created and possibly
JIT compiled. To reduce the overhead from that, build the kernel once with
:func:`fast.metal_kernel` and then use it many times.
.. note::
We are only required to pass the body of the Metal kernel in ``source``.
Only pass the body of the Metal kernel in ``source``. The function
signature is generated automatically.
The full function signature will be generated using:
@@ -78,44 +86,52 @@ Putting this all together, the generated function signature for ``myexp`` is as
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
``threadgroup`` size threadgroups. For optimal performance, each thread group
dimension should be less than or equal to the corresponding grid dimension.
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
Passing ``verbose=True`` to :func:`ast.metal_kernel.__call__` will print the
generated code for debugging purposes.
Using Shape/Strides
-------------------
``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
when indexing.
:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
is ``True`` by default. This will copy the array inputs if needed
before the kernel is launched to ensure that the memory layout is row
contiguous. Generally this makes writing the kernel easier, since we don't
have to worry about gaps or the ordering of the dims when indexing.
If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
input array ``a`` if any are present in ``source``.
We can then use MLX's built in indexing utils to fetch the right elements for each thread.
If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
present in ``source``. We can then use MLX's built in indexing utils to fetch
the right elements for each thread.
Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:
Let's convert ``myexp`` above to support arbitrarily strided arrays without
relying on a copy from ``ensure_row_contiguous``:
.. code-block:: python
source = """
uint elem = thread_position_in_grid.x;
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
T tmp = inp[loc];
// Output arrays are always row contiguous
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp_strided",
input_names=["inp"],
output_names=["out"],
source=source,
ensure_row_contiguous=False,
)
def exp_elementwise(a: mx.array):
source = """
uint elem = thread_position_in_grid.x;
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
T tmp = inp[loc];
// Output arrays are always row contiguous
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp_strided",
input_names=["inp"],
output_names=["out"],
source=source
)
outputs = kernel(
inputs=[a],
template=[("T", mx.float32)],
@@ -123,7 +139,6 @@ Let's convert ``myexp`` above to support arbitrarily strided arrays without rely
threadgroup=(256, 1, 1),
output_shapes=[a.shape],
output_dtypes=[a.dtype],
ensure_row_contiguous=False,
)
return outputs[0]
@@ -142,137 +157,139 @@ We'll start with the following MLX implementation using standard ops:
.. code-block:: python
def grid_sample_ref(x, grid):
N, H_in, W_in, _ = x.shape
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
def grid_sample_ref(x, grid):
N, H_in, W_in, _ = x.shape
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
ix_nw = mx.floor(ix).astype(mx.int32)
iy_nw = mx.floor(iy).astype(mx.int32)
ix_nw = mx.floor(ix).astype(mx.int32)
iy_nw = mx.floor(iy).astype(mx.int32)
ix_ne = ix_nw + 1
iy_ne = iy_nw
ix_ne = ix_nw + 1
iy_ne = iy_nw
ix_sw = ix_nw
iy_sw = iy_nw + 1
ix_sw = ix_nw
iy_sw = iy_nw + 1
ix_se = ix_nw + 1
iy_se = iy_nw + 1
ix_se = ix_nw + 1
iy_se = iy_nw + 1
nw = (ix_se - ix) * (iy_se - iy)
ne = (ix - ix_sw) * (iy_sw - iy)
sw = (ix_ne - ix) * (iy - iy_ne)
se = (ix - ix_nw) * (iy - iy_nw)
nw = (ix_se - ix) * (iy_se - iy)
ne = (ix - ix_sw) * (iy_sw - iy)
sw = (ix_ne - ix) * (iy - iy_ne)
se = (ix - ix_nw) * (iy - iy_nw)
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
I_nw *= mask_nw[..., None]
I_ne *= mask_ne[..., None]
I_sw *= mask_sw[..., None]
I_se *= mask_se[..., None]
I_nw *= mask_nw[..., None]
I_ne *= mask_ne[..., None]
I_sw *= mask_sw[..., None]
I_se *= mask_se[..., None]
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
return output
return output
Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
to write a fast GPU kernel for both the forward and backward passes.
First we'll implement the forward pass as a fused kernel:
.. code-block:: python
@mx.custom_function
def grid_sample(x, grid):
source = """
uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
int gH = grid_shape[1];
int gW = grid_shape[2];
assert x.ndim == 4, "`x` must be 4D."
assert grid.ndim == 4, "`grid` must be 4D."
int w_stride = C;
int h_stride = W * w_stride;
int b_stride = H * h_stride;
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
out_shape = (B, gN, gM, C)
uint grid_idx = elem / C * 2;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
assert D == 2, "Last dim of `grid` must be size 2."
int ix_nw = floor(ix);
int iy_nw = floor(iy);
source = """
uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
int gH = grid_shape[1];
int gW = grid_shape[2];
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int w_stride = C;
int h_stride = W * w_stride;
int b_stride = H * h_stride;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
uint grid_idx = elem / C * 2;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
int ix_nw = floor(ix);
int iy_nw = floor(iy);
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int batch_idx = elem / C / gH / gW * b_stride;
int channel_idx = elem % C;
int base_idx = batch_idx + channel_idx;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
"""
int batch_idx = elem / C / gH / gW * b_stride;
int channel_idx = elem % C;
int base_idx = batch_idx + channel_idx;
kernel = mx.fast.metal_kernel(
name="grid_sample",
input_names=["x", "grid"],
output_names=["out"],
source=source,
)
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
@mx.custom_function
def grid_sample(x, grid):
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
assert x.ndim == 4, "`x` must be 4D."
assert grid.ndim == 4, "`grid` must be 4D."
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
"""
kernel = mx.fast.metal_kernel(
name="grid_sample",
input_names=["x", "grid"],
output_names=["out"],
source=source,
)
outputs = kernel(
inputs=[x, grid],
template=[("T", x.dtype)],
output_shapes=[out_shape],
output_dtypes=[x.dtype],
grid=(np.prod(out_shape), 1, 1),
threadgroup=(256, 1, 1),
)
return outputs[0]
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
out_shape = (B, gN, gM, C)
assert D == 2, "Last dim of `grid` must be size 2."
outputs = kernel(
inputs=[x, grid],
template=[("T", x.dtype)],
output_shapes=[out_shape],
output_dtypes=[x.dtype],
grid=(np.prod(out_shape), 1, 1),
threadgroup=(256, 1, 1),
)
return outputs[0]
For a reasonably sized input such as:
.. code-block:: python
x.shape = (8, 1024, 1024, 64)
grid.shape = (8, 256, 256, 2)
x.shape = (8, 1024, 1024, 64)
grid.shape = (8, 256, 256, 2)
On an M1 Max, we see a big performance improvement:
@@ -281,11 +298,11 @@ On an M1 Max, we see a big performance improvement:
Grid Sample VJP
---------------
Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
its custom vjp transform so MLX can differentiate it.
Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
define its custom vjp transform so MLX can differentiate it.
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
requires a few extra ``mx.fast.metal_kernel`` features:
requires a few extra :func:`fast.metal_kernel` features:
* ``init_value=0``
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
@@ -299,128 +316,129 @@ We can then implement the backwards pass as follows:
.. code-block:: python
@grid_sample.vjp
def grid_sample_vjp(primals, cotangent, _):
x, grid = primals
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
source = """
uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
// Pad C to the nearest larger simdgroup size multiple
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
assert D == 2, "Last dim of `grid` must be size 2."
int gH = grid_shape[1];
int gW = grid_shape[2];
source = """
uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
// Pad C to the nearest larger simdgroup size multiple
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
int w_stride = C;
int h_stride = W * w_stride;
int b_stride = H * h_stride;
int gH = grid_shape[1];
int gW = grid_shape[2];
uint grid_idx = elem / C_padded * 2;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int w_stride = C;
int h_stride = W * w_stride;
int b_stride = H * h_stride;
int ix_nw = floor(ix);
int iy_nw = floor(iy);
uint grid_idx = elem / C_padded * 2;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_nw = floor(ix);
int iy_nw = floor(iy);
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
int batch_idx = elem / C_padded / gH / gW * b_stride;
int channel_idx = elem % C_padded;
int base_idx = batch_idx + channel_idx;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
T gix = T(0);
T giy = T(0);
if (channel_idx < C) {
int cot_index = elem / C_padded * C + channel_idx;
T cot = cotangent[cot_index];
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
int batch_idx = elem / C_padded / gH / gW * b_stride;
int channel_idx = elem % C_padded;
int base_idx = batch_idx + channel_idx;
T I_nw = x[offset];
gix -= I_nw * (iy_se - iy) * cot;
giy -= I_nw * (ix_se - ix) * cot;
}
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
T gix = T(0);
T giy = T(0);
if (channel_idx < C) {
int cot_index = elem / C_padded * C + channel_idx;
T cot = cotangent[cot_index];
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
T I_ne = x[offset];
gix += I_ne * (iy_sw - iy) * cot;
giy -= I_ne * (ix - ix_sw) * cot;
}
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
T I_nw = x[offset];
gix -= I_nw * (iy_se - iy) * cot;
giy -= I_nw * (ix_se - ix) * cot;
}
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
T I_sw = x[offset];
gix -= I_sw * (iy - iy_ne) * cot;
giy += I_sw * (ix_ne - ix) * cot;
}
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
T I_ne = x[offset];
gix += I_ne * (iy_sw - iy) * cot;
giy -= I_ne * (ix - ix_sw) * cot;
}
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
T I_se = x[offset];
gix += I_se * (iy - iy_nw) * cot;
giy += I_se * (ix - ix_nw) * cot;
}
}
T I_sw = x[offset];
gix -= I_sw * (iy - iy_ne) * cot;
giy += I_sw * (ix_ne - ix) * cot;
}
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
T gix_mult = W / 2;
T giy_mult = H / 2;
T I_se = x[offset];
gix += I_se * (iy - iy_nw) * cot;
giy += I_se * (ix - ix_nw) * cot;
}
}
// Reduce across each simdgroup first.
// This is much faster than relying purely on atomics.
gix = simd_sum(gix);
giy = simd_sum(giy);
T gix_mult = W / 2;
T giy_mult = H / 2;
if (thread_index_in_simdgroup == 0) {
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
}
"""
kernel = mx.fast.metal_kernel(
name="grid_sample_grad",
input_names=["x", "grid", "cotangent"],
output_names=["x_grad", "grid_grad"],
source=source,
atomic_outputs=True,
)
// Reduce across each simdgroup first.
// This is much faster than relying purely on atomics.
gix = simd_sum(gix);
giy = simd_sum(giy);
@grid_sample.vjp
def grid_sample_vjp(primals, cotangent, _):
x, grid = primals
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
if (thread_index_in_simdgroup == 0) {
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
}
"""
kernel = mx.fast.metal_kernel(
name="grid_sample_grad",
input_names=["x", "grid", "cotangent"],
output_names=["x_grad", "grid_grad"],
source=source,
atomic_outputs=True,
)
# pad the output channels to simd group size
# so that our `simd_sum`s don't overlap.
simdgroup_size = 32
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
grid_size = B * gN * gM * C_padded
outputs = kernel(
inputs=[x, grid, cotangent],
template=[("T", x.dtype)],
output_shapes=[x.shape, grid.shape],
output_dtypes=[x.dtype, x.dtype],
grid=(grid_size, 1, 1),
threadgroup=(256, 1, 1),
init_value=0,
)
return outputs[0], outputs[1]
assert D == 2, "Last dim of `grid` must be size 2."
# pad the output channels to simd group size
# so that our `simd_sum`s don't overlap.
simdgroup_size = 32
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
grid_size = B * gN * gM * C_padded
outputs = kernel(
inputs=[x, grid, cotangent],
template=[("T", x.dtype)],
output_shapes=[x.shape, grid.shape],
output_dtypes=[x.dtype, x.dtype],
grid=(grid_size, 1, 1),
threadgroup=(256, 1, 1),
init_value=0,
)
return outputs[0], outputs[1]
There's an even larger speed up for the vjp:

View File

@@ -22,12 +22,12 @@ You can do that in MLX directly:
This function performs that operation while leaving the implementation and
function transformations to MLX.
However you may need to customize the underlying implementation, perhaps to
make it faster or for custom differentiation. In this tutorial we will go
through adding custom extensions. It will cover:
However, you may want to customize the underlying implementation, perhaps to
make it faster. In this tutorial we will go through adding custom extensions.
It will cover:
* The structure of the MLX library.
* Implementing a CPU operation that redirects to Accelerate_ when appropriate.
* Implementing a CPU operation.
* Implementing a GPU operation using metal.
* Adding the ``vjp`` and ``jvp`` function transformation.
* Building a custom extension and binding it to python.
@@ -45,7 +45,7 @@ Operations
Operations are the front-end functions that operate on arrays. They are defined
in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.
We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and
We would like an operation :meth:`axpby` that takes in two arrays, ``x`` and
``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
C++:
@@ -55,7 +55,7 @@ C++:
* Scale and sum two vectors element-wise
* z = alpha * x + beta * y
*
* Follow numpy style broadcasting between x and y
* Use NumPy-style broadcasting between x and y
* Inputs are upcasted to floats if needed
**/
array axpby(
@@ -66,7 +66,7 @@ C++:
StreamOrDevice s = {} // Stream on which to schedule the operation
);
The simplest way to this operation is in terms of existing operations:
The simplest way to implement this is with existing operations:
.. code-block:: C++
@@ -93,9 +93,9 @@ Primitives
^^^^^^^^^^^
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create outputs arrays given a input arrays. Further, a
defines how to create output arrays given input arrays. Further, a
:class:`Primitive` has methods to run on the CPU or GPU and for function
transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be
transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
more concrete:
.. code-block:: C++
@@ -128,7 +128,7 @@ more concrete:
/** The vector-Jacobian product. */
std::vector<array> vjp(
const std::vector<array>& primals,
const array& cotan,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
@@ -138,13 +138,13 @@ more concrete:
* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
*/
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
/** Print the primitive. */
void print(std::ostream& os) override {
os << "Axpby";
/** The name of primitive. */
const char* name() const override {
return "Axpby";
}
/** Equivalence check **/
@@ -153,9 +153,6 @@ more concrete:
private:
float alpha_;
float beta_;
/** Fall back implementation for evaluation on CPU */
void eval(const std::vector<array>& inputs, array& out);
};
The :class:`Axpby` class derives from the base :class:`Primitive` class. The
@@ -188,7 +185,7 @@ Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
// Upcast to float32 for non-floating point inputs x and y
auto out_dtype = is_floating_point(promoted_dtype)
auto out_dtype = issubdtype(promoted_dtype, float32)
? promoted_dtype
: promote_types(promoted_dtype, float32);
@@ -234,49 +231,57 @@ the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
Implementing the CPU Back-end
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Let's start by implementing a naive and generic version of
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
:class:`Axpby` earlier called :meth:`Axpby::eval`.
Let's start by implementing :meth:`Axpby::eval_cpu`.
Our naive method will go over each element of the output array, find the
The method will go over each element of the output array, find the
corresponding input elements of ``x`` and ``y`` and perform the operation
point-wise. This is captured in the templated function :meth:`axpby_impl`.
.. code-block:: C++
template <typename T>
void axpby_impl(
const array& x,
const array& y,
array& out,
float alpha_,
float beta_) {
// We only allocate memory when we are ready to fill the output
// malloc_or_wait synchronously allocates available memory
// There may be a wait executed here if the allocation is requested
// under memory-pressured conditions
out.set_data(allocator::malloc_or_wait(out.nbytes()));
template <typename T>
void axpby_impl(
const mx::array& x,
const mx::array& y,
mx::array& out,
float alpha_,
float beta_,
mx::Stream stream) {
out.set_data(mx::allocator::malloc(out.nbytes()));
// Collect input and output data pointers
const T* x_ptr = x.data<T>();
const T* y_ptr = y.data<T>();
T* out_ptr = out.data<T>();
// Get the CPU command encoder and register input and output arrays
auto& encoder = mx::cpu::get_command_encoder(stream);
encoder.set_input_array(x);
encoder.set_input_array(y);
encoder.set_output_array(out);
// Cast alpha and beta to the relevant types
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Launch the CPU kernel
encoder.dispatch([x_ptr = x.data<T>(),
y_ptr = y.data<T>(),
out_ptr = out.data<T>(),
size = out.size(),
shape = out.shape(),
x_strides = x.strides(),
y_strides = y.strides(),
alpha_,
beta_]() {
// Do the element-wise operation for each output
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
// Map linear indices to offsets in x and y
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
// Cast alpha and beta to the relevant types
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additional mapping
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
}
}
// Do the element-wise operation for each output
for (size_t out_idx = 0; out_idx < size; out_idx++) {
// Map linear indices to offsets in x and y
auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);
// We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additional mapping
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
}
});
}
Our implementation should work for all incoming floating point arrays.
Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
@@ -284,112 +289,32 @@ Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
.. code-block:: C++
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Dispatch to the correct dtype
if (out.dtype() == float32) {
return axpby_impl<float>(x, y, out, alpha_, beta_);
} else if (out.dtype() == float16) {
return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == bfloat16) {
return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == complex64) {
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
} else {
throw std::runtime_error(
"[Axpby] Only supports floating point types.");
}
}
This is good as a fallback implementation. We can use the ``axpby`` routine
provided by the Accelerate_ framework for a faster implementation in certain
cases:
#. Accelerate does not provide implementations of ``axpby`` for half precision
floats. We can only use it for ``float32`` types.
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all
elements have fixed strides between them. We only direct to Accelerate
if both ``x`` and ``y`` are row contiguous or column contiguous.
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` in-place.
MLX expects to write the output to a new array. We must copy the elements
of ``y`` into the output and use that as an input to ``axpby``.
Let's write an implementation that uses Accelerate in the right conditions.
It allocates data for the output, copies ``y`` into it, and then calls the
:func:`catlas_saxpby` from accelerate.
.. code-block:: C++
template <typename T>
void axpby_impl_accelerate(
const array& x,
const array& y,
array& out,
float alpha_,
float beta_) {
// Accelerate library provides catlas_saxpby which does
// Y = (alpha * X) + (beta * Y) in place
// To use it, we first copy the data in y over to the output array
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, CopyType::Vector);
// Get x and y pointers for catlas_saxpby
const T* x_ptr = x.data<T>();
T* y_ptr = out.data<T>();
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Call the inplace accelerate operator
catlas_saxpby(
/* N = */ out.size(),
/* ALPHA = */ alpha,
/* X = */ x_ptr,
/* INCX = */ 1,
/* BETA = */ beta,
/* Y = */ y_ptr,
/* INCY = */ 1);
}
For inputs that do not fit the criteria for accelerate, we fall back to
:meth:`Axpby::eval`. With this in mind, let's finish our
:meth:`Axpby::eval_cpu`.
.. code-block:: C++
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Accelerate specialization for contiguous single precision float arrays
if (out.dtype() == float32 &&
((x.flags().row_contiguous && y.flags().row_contiguous) ||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
return;
}
// Fall back to common back-end if specializations are not available
eval(inputs, outputs);
// Dispatch to the correct dtype
if (out.dtype() == mx::float32) {
return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
} else if (out.dtype() == mx::float16) {
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
} else if (out.dtype() == mx::bfloat16) {
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
} else if (out.dtype() == mx::complex64) {
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
} else {
throw std::runtime_error(
"Axpby is only supported for floating point types.");
}
}
Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
you do not plan on running the operation on the GPU or using transforms on
computation graphs that contain :class:`Axpby`, you can stop implementing the
primitive here and enjoy the speed-ups you get from the Accelerate library.
primitive here.
Implementing the GPU Back-end
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -466,17 +391,17 @@ below.
auto& d = metal::device(s.device);
// Allocate output memory
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
// Resolve name of kernel
std::ostringstream kname;
kname << "axpby_" << "general_" << type_to_name(out);
std::stream kname;
kname = "axpby_general_" + type_to_name(out);
// Make sure the metal library is available
d.register_library("mlx_ext");
// Load the metal library
auto lib = d.get_library("mlx_ext", current_binary_dir());
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
auto kernel = d.get_kernel(kname, lib);
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -544,7 +469,7 @@ one we just defined:
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents
// The jvp transform on the primitive can built with ops
// The jvp transform on the primitive can be built with ops
// that are scheduled on the same stream as the primitive
// If argnums = {0}, we only push along x in which case the
@@ -556,7 +481,7 @@ one we just defined:
auto scale_arr = array(scale, tangents[0].dtype());
return {multiply(scale_arr, tangents[0], stream())};
}
// If, argnums = {0, 1}, we take contributions from both
// If argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta
else {
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
@@ -810,7 +735,7 @@ Let's look at a simple script and its results:
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}")
print(f"c is correct: {mx.all(c == 6.0).item()}")
Output:
@@ -818,13 +743,13 @@ Output:
c shape: [3, 4]
c dtype: float32
c correctness: True
c is correct: True
Results
^^^^^^^
Let's run a quick benchmark and see how our new ``axpby`` operation compares
with the naive :meth:`simple_axpby` we first defined on the CPU.
with the naive :meth:`simple_axpby` we first defined.
.. code-block:: python
@@ -832,13 +757,11 @@ with the naive :meth:`simple_axpby` we first defined on the CPU.
from mlx_sample_extensions import axpby
import time
mx.set_default_device(mx.cpu)
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
return alpha * x + beta * y
M = 256
N = 512
M = 4096
N = 4096
x = mx.random.normal((M, N))
y = mx.random.normal((M, N))
@@ -849,24 +772,24 @@ with the naive :meth:`simple_axpby` we first defined on the CPU.
def bench(f):
# Warm up
for i in range(100):
for i in range(5):
z = f(x, y, alpha, beta)
mx.eval(z)
# Timed run
s = time.time()
for i in range(5000):
for i in range(100):
z = f(x, y, alpha, beta)
mx.eval(z)
e = time.time()
return e - s
return 1000 * (e - s) / 100
simple_time = bench(simple_axpby)
custom_time = bench(axpby)
print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")
print(f"Simple axpby: {simple_time:.3f} ms | Custom axpby: {custom_time:.3f} ms")
The results are ``Simple axpby: 0.114 s | Custom axpby: 0.109 s``. We see
The results are ``Simple axpby: 1.559 ms | Custom axpby: 0.774 ms``. We see
modest improvements right away!
This operation is now good to be used to build other operations, in

View File

@@ -70,6 +70,8 @@ are the CPU and GPU.
python/fft
python/linalg
python/metal
python/cuda
python/memory_management
python/nn
python/optimizers
python/distributed

View File

@@ -13,7 +13,7 @@ silicon computer is
pip install mlx
To install from PyPI you must meet the following requirements:
To install from PyPI your system must meet the following requirements:
- Using an M series chip (Apple silicon)
- Using a native Python >= 3.9
@@ -23,12 +23,39 @@ To install from PyPI you must meet the following requirements:
MLX is only available on devices running macOS >= 13.5
It is highly recommended to use macOS 14 (Sonoma)
CUDA
^^^^
MLX is also available on conda-forge. To install MLX with conda do:
MLX has a CUDA backend which you can install with:
.. code-block:: shell
conda install conda-forge::mlx
pip install mlx[cuda]
To install the CUDA package from PyPi your system must meet the following
requirements:
- Nvidia architecture >= SM 7.0 (Volta)
- Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35
- Python >= 3.9
CPU-only (Linux)
^^^^^^^^^^^^^^^^
For a CPU-only version of MLX that runs on Linux use:
.. code-block:: shell
pip install mlx[cpu]
To install the CPU-only package from PyPi your system must meet the following
requirements:
- Linux distribution with glibc >= 2.35
- Python >= 3.9
Troubleshooting
@@ -65,6 +92,8 @@ Build Requirements
Python API
^^^^^^^^^^
.. _python install:
To build and install the MLX python library from source, first, clone MLX from
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
@@ -76,20 +105,20 @@ Then simply build and install MLX using pip:
.. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .
pip install .
For developing, install the package with development dependencies, and use an
editable install:
.. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"
pip install -e ".[dev]"
Once the development dependencies are installed, you can build faster with:
.. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace
python setup.py build_ext --inplace
Run the tests with:
@@ -107,6 +136,8 @@ IDE:
C++ API
^^^^^^^
.. _cpp install:
Currently, MLX must be built and installed from source.
Similarly to the python library, to build and install the MLX C++ library start
@@ -185,6 +216,7 @@ should point to the path to the built metal library.
xcrun -sdk macosx --show-sdk-version
Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~
@@ -213,6 +245,50 @@ be anwywhere from a few hundred millisecond to a few seconds depending on the
application. Once a kernel is compiled, it will be cached by the system. The
Metal kernel cache persists across reboots.
Linux
^^^^^
To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
For example on Ubuntu, run the following:
.. code-block:: shell
apt-get update -y
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
From here follow the instructions to install either the :ref:`Python <python
install>` or :ref:`C++ <cpp install>` APIs.
CUDA
^^^^
To build from source on Linux with CUDA, install the BLAS and LAPACK headers
and the CUDA toolkit. For example on Ubuntu, run the following:
.. code-block:: shell
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
dpkg -i cuda-keyring_1.1-1_all.deb
apt-get update -y
apt-get -y install cuda-toolkit-12-9
apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y
When building either the Python or C++ APIs make sure to pass the cmake flag
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
.. code-block:: shell
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
To build the C++ package run:
.. code-block:: shell
mkdir -p build && cd build
cmake .. -DMLX_BUILD_CUDA=ON && make -j
Troubleshooting
^^^^^^^^^^^^^^^

View File

@@ -19,6 +19,8 @@ Array
array.ndim
array.shape
array.size
array.real
array.imag
array.abs
array.all
array.any
@@ -38,6 +40,7 @@ Array
array.log10
array.log1p
array.log2
array.logcumsumexp
array.logsumexp
array.max
array.mean

9
docs/src/python/cuda.rst Normal file
View File

@@ -0,0 +1,9 @@
CUDA
=====
.. currentmodule:: mlx.core.cuda
.. autosummary::
:toctree: _autosummary
is_available

View File

@@ -13,3 +13,4 @@ Fast
rope
scaled_dot_product_attention
metal_kernel
cuda_kernel

View File

@@ -20,3 +20,5 @@ FFT
irfft2
rfftn
irfftn
fftshift
ifftshift

View File

@@ -16,9 +16,12 @@ Linear Algebra
cross
qr
svd
eigvals
eig
eigvalsh
eigh
lu
lu_factor
pinv
solve
solve_triangular

View File

@@ -0,0 +1,16 @@
Memory Management
=================
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache

View File

@@ -8,13 +8,5 @@ Metal
is_available
device_info
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache
start_capture
stop_capture

View File

@@ -174,6 +174,7 @@ In detail:
value_and_grad
quantize
average_gradients
.. toctree::

View File

@@ -36,10 +36,12 @@ Operations
bitwise_or
bitwise_xor
block_masked_mm
broadcast_arrays
broadcast_to
ceil
clip
concatenate
contiguous
conj
conjugate
convolve
@@ -101,6 +103,7 @@ Operations
log10
log1p
logaddexp
logcumsumexp
logical_not
logical_and
logical_or

View File

@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
optimizer.update(model, grads)
# Save the state
state = tree_flatten(optimizer.state)
mx.save_safetensors("optimizer.safetensors", dict(state))
state = tree_flatten(optimizer.state, destination={})
mx.save_safetensors("optimizer.safetensors", state)
# Later on, for example when loading from a checkpoint,
# recreate the optimizer and load the state
optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
state = tree_unflatten(mx.load("optimizer.safetensors"))
optimizer.state = state
Note, not every optimizer configuation parameter is saved in the state. For

View File

@@ -18,3 +18,5 @@ Common Optimizers
AdamW
Adamax
Lion
MultiOptimizer
Muon

View File

@@ -9,6 +9,7 @@ Transforms
:toctree: _autosummary
eval
async_eval
compile
custom_function
disable_compile

View File

@@ -225,7 +225,7 @@ In some cases returning updated state can be pretty inconvenient. Hence,
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z), state
return mx.exp(z)
fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)]

View File

@@ -5,21 +5,27 @@ Distributed Communication
.. currentmodule:: mlx.core.distributed
MLX utilizes `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ to
provide distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. You can
see a list of the supported operations in the :ref:`API docs<distributed>`.
MLX supports distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. At the
moment we support two different communication backends:
* `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ a
full-featured and mature distributed communications library
* A **ring** backend of our own that uses native TCP sockets and should be
faster for thunderbolt connections.
The list of all currently supported operations and their documentation can be
seen in the :ref:`API docs<distributed>`.
.. note::
A lot of operations may not be supported or not as fast as they should be.
Some operations may not be supported or not as fast as they should be.
We are adding more and tuning the ones we have as we are figuring out the
best way to do distributed computing on Macs using MLX.
Getting Started
---------------
MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. The minimal distributed program in MLX is as simple as:
A distributed program in MLX is as simple as:
.. code:: python
@@ -30,74 +36,79 @@ machine. The minimal distributed program in MLX is as simple as:
print(world.rank(), x)
The program above sums the array ``mx.ones(10)`` across all
distributed processes. If simply run with ``python``, however, only one
process is launched and no distributed communication takes place.
distributed processes. However, when this script is run with ``python`` only
one process is launched and no distributed communication takes place. Namely,
all operations in ``mx.distributed`` are noops when the distributed group has a
size of one. This property allows us to avoid code that checks if we are in a
distributed setting similar to the one below:
To launch the program in distributed mode we need to use ``mpirun`` or
``mpiexec`` depending on the MPI installation. The simplest possible way is the
following:
.. code:: python
import mlx.core as mx
x = ...
world = mx.distributed.init()
# No need for the check we can simply do x = mx.distributed.all_sum(x)
if world.size() > 1:
x = mx.distributed.all_sum(x)
Running Distributed Programs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
MLX provides ``mlx.launch`` a helper script to launch distributed programs.
Continuing with our initial example we can run it on localhost with 4 processes using
.. code:: shell
$ mpirun -np 2 python test.py
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
$ mlx.launch -n 4 my_script.py
3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum which is printed. Launching with ``mpirun -np 4 ...`` would
print 4 etc.
Installing MPI
---------------
MPI can be installed with Homebrew, using the Anaconda package manager or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:
We can also run it on some remote hosts by providing their IPs (provided that
the script exists on all hosts and they are reachable by ssh)
.. code:: shell
$ conda install conda-forge::openmpi
$ mlx.launch --hosts ip1,ip2,ip3,ip4 my_script.py
3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
Installing with Homebrew may require specifying the location of ``libmpi.dyld``
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun``.
Consult the dedicated :doc:`usage guide<launching_distributed>` for more
information on using ``mlx.launch``.
.. code:: shell
Selecting Backend
^^^^^^^^^^^^^^^^^
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
Setting up Remote Hosts
-----------------------
MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist to
debug connectivity issues is the following:
* ``ssh hostname`` works from all machines to all machines without asking for
password or host confirmation
* ``mpirun`` is accessible on all machines. You can call ``mpirun`` using its
full path to force all machines to use a specific path.
* Ensure that the ``hostname`` used by MPI is the one that you have configured
in the ``.ssh/config`` files on all machines.
You can select the backend you want to use when calling :func:`init` by passing
one of ``{'any', 'ring', 'mpi'}``. When passing ``any``, MLX will try to
initialize the ``ring`` backend and if it fails the ``mpi`` backend. If they
both fail then a singleton group is created.
.. note::
For an example hostname ``foo.bar.com`` MPI can use only ``foo`` as
the hostname passed to ssh if the current hostname matches ``*.bar.com``.
After a distributed backend is successfully initialized :func:`init` will
return **the same backend** if called without arguments or with backend set to
``any``.
An easy way to pass the host names to MPI is using a host file. A host file
looks like the following, where ``host1`` and ``host2`` should be the fully
qualified domain names or IPs for these hosts.
The following examples aim to clarify the backend initialization logic in MLX:
.. code::
.. code:: python
host1 slots=1
host2 slots=1
# Case 1: Initialize MPI regardless if it was possible to initialize the ring backend
world = mx.distributed.init(backend="mpi")
world2 = mx.distributed.init() # subsequent calls return the MPI backend!
When using MLX, it is very likely that you want to use 1 slot per host, ie one
process per host. The hostfile also needs to contain the current
host if you want to run on the local host. Passing the host file to
``mpirun`` is simply done using the ``--hostfile`` command line argument.
# Case 2: Initialize any backend
world = mx.distributed.init(backend="any") # equivalent to no arguments
world2 = mx.distributed.init() # same as above
# Case 3: Initialize both backends at the same time
world_mpi = mx.distributed.init(backend="mpi")
world_ring = mx.distributed.init(backend="ring")
world_any = mx.distributed.init() # same as MPI because it was initialized first!
Training Example
----------------
@@ -155,13 +166,179 @@ everything else remaining the same.
optimizer.update(model, grads)
return loss
Tuning All Reduce
-----------------
Utilizing ``nn.average_gradients``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
We are working on improving the performance of all reduce on MLX but for now
the two main things one can do to extract the most out of distributed training with MLX are:
Although the code example above works correctly; it performs one communication
per gradient. It is significantly more efficient to aggregate several gradients
together and perform fewer communication steps.
1. Perform a few large reductions instead of many small ones to improve
bandwidth and latency
2. Pass ``--mca btl_tcp_links 4`` to ``mpirun`` to configure it to use 4 tcp
connections between each host to improve bandwidth
This is the purpose of :func:`mlx.nn.average_gradients`. The final code looks
almost identical to the example above:
.. code:: python
model = ...
optimizer = ...
dataset = ...
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)
grads = mlx.nn.average_gradients(grads) # <---- This line was added
optimizer.update(model, grads)
return loss
for x, y in dataset:
loss = step(model, x, y)
mx.eval(loss, model.parameters())
Getting Started with MPI
------------------------
MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. Launching distributed MLX programs that use MPI can be done with
``mpirun`` as expected. However, in the following examples we will be using
``mlx.launch --backend mpi`` which takes care of some nuisances such as setting
absolute paths for the ``mpirun`` executable and the ``libmpi.dyld`` shared
library.
The simplest possible usage is the following which, assuming the minimal
example in the beginning of this page, should result in:
.. code:: shell
$ mlx.launch --backend mpi -n 2 test.py
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum which is printed. Launching with ``mlx.launch -n 4 ...`` would
print 4 etc.
Installing MPI
^^^^^^^^^^^^^^
MPI can be installed with Homebrew, using the Anaconda package manager or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:
.. code:: shell
$ conda install conda-forge::openmpi
Installing with Homebrew may require specifying the location of ``libmpi.dyld``
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is
done automatically by ``mlx.launch``.
.. code:: shell
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
$ # or simply
$ mlx.launch -n 2 test.py
Setting up Remote Hosts
^^^^^^^^^^^^^^^^^^^^^^^
MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist to
debug connectivity issues is the following:
* ``ssh hostname`` works from all machines to all machines without asking for
password or host confirmation
* ``mpirun`` is accessible on all machines.
* Ensure that the ``hostname`` used by MPI is the one that you have configured
in the ``.ssh/config`` files on all machines.
Tuning MPI All Reduce
^^^^^^^^^^^^^^^^^^^^^
.. note::
For faster all reduce consider using the ring backend either with Thunderbolt
connections or over Ethernet.
Configure MPI to use N tcp connections between each host to improve bandwidth
by passing ``--mca btl_tcp_links N``.
Force MPI to use the most performant network interface by setting ``--mca
btl_tcp_if_include <iface>`` where ``<iface>`` should be the interface you want
to use.
Getting Started with Ring
-------------------------
The ring backend does not depend on any third party library so it is always
available. It uses TCP sockets so the nodes need to be reachable via a network.
As the name suggests the nodes are connected in a ring which means that rank 1
can only communicate with rank 0 and rank 2, rank 2 only with rank 1 and rank 3
and so on and so forth. As a result :func:`send` and :func:`recv` with
arbitrary sender and receiver is not supported in the ring backend.
Defining a Ring
^^^^^^^^^^^^^^^
The easiest way to define and use a ring is via a JSON hostfile and the
``mlx.launch`` :doc:`helper script <launching_distributed>`. For each node one
defines a hostname to ssh into to run commands on this node and one or more IPs
that this node will listen to for connections.
For example the hostfile below defines a 4 node ring. ``hostname1`` will be
rank 0, ``hostname2`` rank 1 etc.
.. code:: json
[
{"ssh": "hostname1", "ips": ["123.123.123.1"]},
{"ssh": "hostname2", "ips": ["123.123.123.2"]},
{"ssh": "hostname3", "ips": ["123.123.123.3"]},
{"ssh": "hostname4", "ips": ["123.123.123.4"]}
]
Running ``mlx.launch --hostfile ring-4.json my_script.py`` will ssh into each
node, run the script which will listen for connections in each of the provided
IPs. Specifically, ``hostname1`` will connect to ``123.123.123.2`` and accept a
connection from ``123.123.123.4`` and so on and so forth.
Thunderbolt Ring
^^^^^^^^^^^^^^^^
Although the ring backend can have benefits over MPI even for Ethernet, its
main purpose is to use Thunderbolt rings for higher bandwidth communication.
Setting up such thunderbolt rings can be done manually, but is a relatively
tedious process. To simplify this, we provide the utility ``mlx.distributed_config``.
To use ``mlx.distributed_config`` your computers need to be accessible by ssh via
Ethernet or Wi-Fi. Subsequently, connect them via thunderbolt cables and then call the
utility as follows:
.. code:: shell
mlx.distributed_config --verbose --hosts host1,host2,host3,host4
By default the script will attempt to discover the thunderbolt ring and provide
you with the commands to configure each node as well as the ``hostfile.json``
to use with ``mlx.launch``. If password-less ``sudo`` is available on the nodes
then ``--auto-setup`` can be used to configure them automatically.
To validate your connection without configuring anything
``mlx.distributed_config`` can also plot the ring using DOT format.
.. code:: shell
mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --dot >ring.dot
dot -Tpng ring.dot >ring.png
open ring.png
If you want to go through the process manually, the steps are as follows:
* Disable the thunderbolt bridge interface
* For the cable connecting rank ``i`` to rank ``i + 1`` find the interfaces
corresponding to that cable in nodes ``i`` and ``i + 1``.
* Set up a unique subnetwork connecting the two nodes for the corresponding
interfaces. For instance if the cable corresponds to ``en2`` on node ``i``
and ``en2`` also on node ``i + 1`` then we may assign IPs ``192.168.0.1`` and
``192.168.0.2`` respectively to the two nodes. For more details you can see
the commands prepared by the utility script.

View File

@@ -7,17 +7,17 @@ Exporting Functions
MLX has an API to export and import functions to and from a file. This lets you
run computations written in one MLX front-end (e.g. Python) in another MLX
front-end (e.g. C++).
front-end (e.g. C++).
This guide walks through the basics of the MLX export API with some examples.
To see the full list of functions check-out the :ref:`API documentation
<export>`.
Basics of Exporting
Basics of Exporting
-------------------
Let's start with a simple example:
.. code-block:: python
def fun(x, y):
@@ -67,7 +67,7 @@ specified as variable positional arguments or as a tuple of arrays:
x = mx.array(1.0)
y = mx.array(1.0)
# Both arguments to fun are positional
mx.export_function("add.mlxfn", fun, x, y)
@@ -133,7 +133,7 @@ parameters are also saved to the ``model.mlxfn`` file.
For enclosed arrays inside an exported function, be extra careful to ensure
they are evaluated. The computation graph that gets exported will include
the computation that produces enclosed inputs.
If the above example was missing ``mx.eval(model.parameters()``, the
exported function would include the random initialization of the
:obj:`mlx.nn.Module` parameters.
@@ -150,8 +150,8 @@ parameters, pass them as inputs to the ``call`` wrapper:
# Set the model's parameters to the input parameters
model.update(tree_unflatten(list(params.items())))
return model(x)
params = dict(tree_flatten(model.parameters()))
params = tree_flatten(model.parameters(), destination={})
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
@@ -169,8 +169,8 @@ to export a function which can be used for inputs with variable shapes:
# Ok
out, = imported_abs(mx.array(-1.0))
# Also ok
# Also ok
out, = imported_abs(mx.array([-1.0, -2.0]))
With ``shapeless=False`` (which is the default), the second call to
@@ -197,7 +197,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
def fun(x, y=None):
constant = mx.array(3.0)
if y is not None:
x += y
x += y
return x + constant
with mx.exporter("fun.mlxfn", fun) as exporter:
@@ -215,7 +215,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
print(out)
In the above example the function constant data, (i.e. ``constant``), is only
saved once.
saved once.
Transformations with Imported Functions
---------------------------------------
@@ -238,7 +238,7 @@ on imported functions just like regular Python functions:
# Prints: array(1, dtype=float32)
print(dfdx(x))
# Compile the imported function
# Compile the imported function
mx.compile(imported_fun)
# Prints: array(0, dtype=float32)
print(compiled_fun(x)[0])
@@ -275,7 +275,7 @@ Import and run the function in C++ with only a few lines of code:
// Prints: array(2, dtype=float32)
std::cout << outputs[0] << std::endl;
Imported functions can be transformed in C++ just like in Python. Use
Imported functions can be transformed in C++ just like in Python. Use
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
mx::array>`` for keyword arguments when calling imported functions in C++.

View File

@@ -107,6 +107,16 @@ same array:
>>> a
array([1, 2, 0], dtype=int32)
Note, unlike NumPy, updates to the same location are nondeterministic:
.. code-block:: shell
>>> a = mx.array([1, 2, 3])
>>> a[[0, 0]] = mx.array([4, 5])
The first element of ``a`` could be ``4`` or ``5``.
Transformations of functions which use in-place updates are allowed and work as
expected. For example:

View File

@@ -0,0 +1,105 @@
:orphan:
.. _usage_launch_distributed:
Launching Distributed Programs
==============================
.. currentmodule:: mlx.core.distributed
Installing the MLX python package provides a helper script ``mlx.launch`` that
can be used to run python scripts distributed on several nodes. It allows
launching using either the MPI backend or the ring backend. See the
:doc:`distributed docs <distributed>` for the different backends.
Usage
-----
The minimal usage example of ``mlx.launch`` is simply
.. code:: shell
mlx.launch --hosts ip1,ip2 my_script.py
or for testing on localhost
.. code:: shell
mlx.launch -n 2 my_script.py
The ``mlx.launch`` command connects to the provided host and launches the input
script on each host. It monitors each of the launched processes and terminates
the rest if one of them fails unexpectedly or if ``mlx.launch`` is terminated.
It also takes care of forwarding the output of each remote process to stdout
and stderr respectively.
Providing Hosts
^^^^^^^^^^^^^^^^
Hosts can be provided as command line arguments, like above, but the way that
allows to fully define a list of hosts is via a JSON hostfile. The hostfile has
a very simple schema. It is simply a list of objects that define each host via
a hostname to ssh to and a list of IPs to utilize for the communication.
.. code:: json
[
{"ssh": "hostname1", "ips": ["123.123.1.1", "123.123.2.1"]},
{"ssh": "hostname2", "ips": ["123.123.1.2", "123.123.2.2"]}
]
You can use ``mlx.distributed_config --over ethernet`` to create a hostfile
with IPs corresponding to the ``en0`` interface.
Setting up Remote Hosts
^^^^^^^^^^^^^^^^^^^^^^^^
In order to be able to launch the script on each host we need to be able to
connect via ssh. Moreover the input script and python binary need to be on each
host and on the same path. A good checklist to debug errors is the following:
* ``ssh hostname`` works without asking for password or host confirmation
* the python binary is available on all hosts at the same path. You can use
``mlx.launch --print-python`` to see what that path is.
* the script you want to run is available on all hosts at the same path
.. _mpi_specifics:
MPI Specifics
-------------
One can use MPI by passing ``--backend mpi`` to ``mlx.launch``. In that case,
``mlx.launch`` is a thin wrapper over ``mpirun``. Moreover,
* The IPs in the hostfile are ignored
* The ssh connectivity requirement is stronger as every node needs to be able
to connect to every other node
* ``mpirun`` needs to be available on every node at the same path
Finally, one can pass arguments to ``mpirun`` using ``--mpi-arg``. For instance
to choose a specific interface for the byte-transfer-layer of MPI we can call
``mlx.launch`` as follows:
.. code:: shell
mlx.launch --backend mpi --mpi-arg '--mca btl_tcp_if_include en0' --hostfile hosts.json my_script.py
.. _ring_specifics:
Ring Specifics
--------------
The ring backend, which is also the default backend, can be explicitly selected
with the argument ``--backend ring``. The ring backend has some specific
requirements and arguments that are different to MPI:
* The argument ``--hosts`` only accepts IPs and not hostnames. If we need to
ssh to a hostname that does not correspond to the IP we want to bind to we
have to provide a hostfile.
* ``--starting-port`` defines the port to bind to on the remote hosts.
Specifically rank 0 for the first IP will use this port and each subsequent
IP or rank will add 1 to this port.
* ``--connections-per-ip`` allows us to increase the number of connections
between neighboring nodes. This corresponds to ``--mca btl_tcp_links 2`` for
``mpirun``.

View File

@@ -10,7 +10,6 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
# ----------------------------- Dependencies -----------------------------
find_package(MLX CONFIG REQUIRED)
find_package(
Python 3.8
COMPONENTS Interpreter Development.Module
@@ -21,6 +20,12 @@ execute_process(
OUTPUT_VARIABLE nanobind_ROOT)
find_package(nanobind CONFIG REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED)
# ----------------------------- Extensions -----------------------------
# Add library

View File

@@ -1,20 +1,15 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023-2025 Apple Inc.
#include <cassert>
#include <dlfcn.h>
#include <iostream>
#include <sstream>
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/utils.h"
#include "axpby/axpby.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <vecLib/cblas_new.h>
#endif
#ifdef _METAL_
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
@@ -22,6 +17,19 @@
namespace my_ext {
// A helper function to find the location of the current binary on disk.
// The Metal library ("mlx_ext.mtllib"), should be in the same directory.
std::string current_binary_dir() {
static std::string binary_dir = []() {
Dl_info info;
if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
throw std::runtime_error("Unable to get current binary dir.");
}
return std::filesystem::path(info.dli_fname).parent_path().string();
}();
return binary_dir;
}
///////////////////////////////////////////////////////////////////////////////
// Operation Implementation
///////////////////////////////////////////////////////////////////////////////
@@ -76,136 +84,65 @@ void axpby_impl(
const mx::array& y,
mx::array& out,
float alpha_,
float beta_) {
// We only allocate memory when we are ready to fill the output
// malloc_or_wait synchronously allocates available memory
// There may be a wait executed here if the allocation is requested
// under memory-pressured conditions
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
float beta_,
mx::Stream stream) {
out.set_data(mx::allocator::malloc(out.nbytes()));
// Collect input and output data pointers
const T* x_ptr = x.data<T>();
const T* y_ptr = y.data<T>();
T* out_ptr = out.data<T>();
// Get the CPU command encoder and register input and output arrays
auto& encoder = mx::cpu::get_command_encoder(stream);
encoder.set_input_array(x);
encoder.set_input_array(y);
encoder.set_output_array(out);
// Cast alpha and beta to the relevant types
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Launch the CPU kernel
encoder.dispatch([x_ptr = x.data<T>(),
y_ptr = y.data<T>(),
out_ptr = out.data<T>(),
size = out.size(),
shape = out.shape(),
x_strides = x.strides(),
y_strides = y.strides(),
alpha_,
beta_]() {
// Cast alpha and beta to the relevant types
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Do the element-wise operation for each output
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
// Map linear indices to offsets in x and y
auto x_offset = mx::elem_to_loc(out_idx, x.shape(), x.strides());
auto y_offset = mx::elem_to_loc(out_idx, y.shape(), y.strides());
// Do the element-wise operation for each output
for (size_t out_idx = 0; out_idx < size; out_idx++) {
// Map linear indices to offsets in x and y
auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);
// We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additional mapping
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
}
// We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additional mapping
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
}
});
}
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(
void Axpby::eval_cpu(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
// Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Dispatch to the correct dtype
if (out.dtype() == mx::float32) {
return axpby_impl<float>(x, y, out, alpha_, beta_);
return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
} else if (out.dtype() == mx::float16) {
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_);
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
} else if (out.dtype() == mx::bfloat16) {
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_);
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
} else if (out.dtype() == mx::complex64) {
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_);
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
} else {
throw std::runtime_error(
"Axpby is only supported for floating point types.");
}
}
///////////////////////////////////////////////////////////////////////////////
// Primitive Accelerate Backend Implementation
///////////////////////////////////////////////////////////////////////////////
#ifdef ACCELERATE_NEW_LAPACK
template <typename T>
void axpby_impl_accelerate(
const mx::array& x,
const mx::array& y,
mx::array& out,
float alpha_,
float beta_) {
// Accelerate library provides catlas_saxpby which does
// Y = (alpha * X) + (beta * Y) in place
// To use it, we first copy the data in y over to the output array
// This specialization requires both x and y be contiguous in the same mode
// i.e: corresponding linear indices in both point to corresponding elements
// The data in the output array is allocated to match the strides in y
// such that x, y, and out are contiguous in the same mode and
// no transposition is needed
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
// We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, mx::CopyType::Vector);
// Get x and y pointers for catlas_saxpby
const T* x_ptr = x.data<T>();
T* y_ptr = out.data<T>();
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Call the inplace accelerate operator
catlas_saxpby(
/* N = */ out.size(),
/* ALPHA = */ alpha,
/* X = */ x_ptr,
/* INCX = */ 1,
/* BETA = */ beta,
/* Y = */ y_ptr,
/* INCY = */ 1);
}
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Accelerate specialization for contiguous single precision float arrays
if (out.dtype() == mx::float32 &&
((x.flags().row_contiguous && y.flags().row_contiguous) ||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
return;
}
// Fall back to common backend if specializations are not available
eval(inputs, outputs);
}
#else // Accelerate not available
/** Evaluate primitive on CPU falling back to common backend */
void Axpby::eval_cpu(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
eval(inputs, outputs);
}
#endif
///////////////////////////////////////////////////////////////////////////////
// Primitive Metal Backend Implementation
///////////////////////////////////////////////////////////////////////////////
@@ -217,7 +154,6 @@ void Axpby::eval_gpu(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
// Prepare inputs
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
@@ -236,25 +172,24 @@ void Axpby::eval_gpu(
// Allocate output memory with strides based on specialization
if (contiguous_kernel) {
out.set_data(
mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()),
mx::allocator::malloc(x.data_size() * out.itemsize()),
x.data_size(),
x.strides(),
x.flags());
} else {
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
out.set_data(mx::allocator::malloc(out.nbytes()));
}
// Resolve name of kernel (corresponds to axpby.metal)
std::ostringstream kname;
kname << "axpby_";
kname << (contiguous_kernel ? "contiguous_" : "general_");
kname << type_to_name(out);
std::string kname = "axpby_";
kname += (contiguous_kernel ? "contiguous_" : "general_");
kname += type_to_name(out);
// Make sure the metal library is available
d.register_library("mlx_ext");
// Load the metal library
auto lib = d.get_library("mlx_ext", current_binary_dir());
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
auto kernel = d.get_kernel(kname, lib);
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2025 Apple Inc.
#pragma once
@@ -74,9 +74,9 @@ class Axpby : public mx::Primitive {
const std::vector<mx::array>& inputs,
const std::vector<int>& axes) override;
/** Print the primitive. */
void print(std::ostream& os) override {
os << "Axpby";
/** The name of primitive. */
const char* name() const override {
return "Axpby";
}
/** Equivalence check **/
@@ -85,11 +85,6 @@ class Axpby : public mx::Primitive {
private:
float alpha_;
float beta_;
/** Fall back implementation for evaluation on CPU */
void eval(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs);
};
} // namespace my_ext

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2025 Apple Inc.
#include <metal_stdlib>

View File

@@ -1,4 +1,4 @@
setuptools>=42
cmake>=3.25
mlx>=0.21.0
nanobind==2.2.0
nanobind==2.4.0

View File

@@ -3,8 +3,10 @@ from mlx_sample_extensions import axpby
a = mx.ones((3, 4))
b = mx.ones((3, 4))
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
c_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
c_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu)
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}")
print(f"c shape: {c_cpu.shape}")
print(f"c dtype: {c_cpu.dtype}")
print(f"c_cpu correct: {mx.all(c_cpu == 6.0).item()}")
print(f"c_gpu correct: {mx.all(c_gpu == 6.0).item()}")

View File

@@ -5,6 +5,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
@@ -19,6 +20,11 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
# Define MLX_VERSION only in the version.cpp file.
add_library(mlx_version OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)
if(MSVC)
# Disable some MSVC warnings to speed up compilation.
target_compile_options(mlx PUBLIC /wd4068 /wd4244 /wd4267 /wd4804)
@@ -43,5 +49,19 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if(MLX_BUILD_METAL)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
target_sources(mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
endif()
if(MLX_BUILD_CUDA)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
else()
target_sources(mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
endif()
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
endif()

View File

@@ -4,12 +4,11 @@
#include <sstream>
#include "mlx/allocator.h"
#include "mlx/scheduler.h"
namespace mlx::core::allocator {
Buffer malloc(size_t size) {
auto buffer = allocator().malloc(size, /* allow_swap */ true);
auto buffer = allocator().malloc(size);
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
@@ -22,45 +21,4 @@ void free(Buffer buffer) {
allocator().free(buffer);
}
Buffer CommonAllocator::malloc(size_t size, bool) {
void* ptr = std::malloc(size + sizeof(size_t));
if (ptr != nullptr) {
*static_cast<size_t*>(ptr) = size;
}
return Buffer{ptr};
}
void CommonAllocator::free(Buffer buffer) {
std::free(buffer.ptr());
}
size_t CommonAllocator::size(Buffer buffer) const {
if (buffer.ptr() == nullptr) {
return 0;
}
return *static_cast<size_t*>(buffer.ptr());
}
Buffer malloc_or_wait(size_t size) {
auto buffer = allocator().malloc(size);
while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) {
scheduler::wait_for_one();
buffer = allocator().malloc(size);
}
// Try swapping if needed
if (size && !buffer.ptr()) {
buffer = allocator().malloc(size, /* allow_swap = */ true);
}
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
return buffer;
}
} // namespace mlx::core::allocator

View File

@@ -32,14 +32,10 @@ Buffer malloc(size_t size);
void free(Buffer buffer);
// Wait for running tasks to finish and free up memory
// if allocation fails
Buffer malloc_or_wait(size_t size);
class Allocator {
/** Abstract base class for a memory allocator. */
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
virtual Buffer malloc(size_t size) = 0;
virtual void free(Buffer buffer) = 0;
virtual size_t size(Buffer buffer) const = 0;
@@ -53,16 +49,4 @@ class Allocator {
Allocator& allocator();
class CommonAllocator : public Allocator {
/** A general CPU allocator. */
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
private:
CommonAllocator() = default;
friend Allocator& allocator();
};
} // namespace mlx::core::allocator

View File

@@ -56,6 +56,18 @@ std::vector<array> array::make_arrays(
return outputs;
}
array array::unsafe_weak_copy(const array& other) {
auto cpy = array(other.shape(), other.dtype(), nullptr, {});
cpy.set_data(
other.buffer(),
other.data_size(),
other.strides(),
other.flags(),
[](auto) {});
cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
return cpy;
}
array::array(std::initializer_list<float> data)
: array_desc_(std::make_shared<ArrayDesc>(
Shape{static_cast<ShapeElem>(data.size())},
@@ -76,35 +88,27 @@ array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
set_data(data, deleter);
}
array::array(
allocator::Buffer data,
Shape shape,
Dtype dtype,
Strides strides,
size_t data_size,
Flags flags,
Deleter deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
set_data(data, data_size, std::move(strides), flags, deleter);
}
void array::detach() {
array_desc_->primitive = nullptr;
for (auto& s : array_desc_->siblings) {
s.array_desc_->primitive = nullptr;
}
for (auto& s : array_desc_->siblings) {
s.array_desc_->inputs.clear();
s.array_desc_->siblings.clear();
s.array_desc_->position = 0;
s.array_desc_->primitive = nullptr;
}
array_desc_->inputs.clear();
array_desc_->siblings.clear();
array_desc_->position = 0;
array_desc_->primitive = nullptr;
}
bool array::is_available() const {
if (status() == Status::available) {
return true;
} else if (status() == Status::evaluated && event().is_signaled()) {
} else if (
status() == Status::evaluated &&
(!event().valid() || event().is_signaled())) {
set_status(Status::available);
return true;
}
@@ -113,7 +117,10 @@ bool array::is_available() const {
void array::wait() {
if (!is_available()) {
event().wait();
if (event().valid()) {
event().wait();
detach_event();
}
set_status(Status::available);
}
}
@@ -174,34 +181,13 @@ void array::copy_shared_buffer(const array& other) {
copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
}
void array::move_shared_buffer(
array other,
const Strides& strides,
Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
array_desc_->data = std::move(other.array_desc_->data);
array_desc_->strides = strides;
array_desc_->flags = flags;
array_desc_->data_size = data_size;
auto char_offset = sizeof(char) * itemsize() * offset;
auto data_ptr = other.array_desc_->data_ptr;
other.array_desc_->data_ptr = nullptr;
array_desc_->data_ptr =
static_cast<void*>(static_cast<char*>(data_ptr) + char_offset);
}
void array::move_shared_buffer(array other) {
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
}
array::~array() {
if (array_desc_ == nullptr) {
return;
}
// Ignore arrays that might be detached during eval
if (status() == array::Status::scheduled) {
// Detached/detaching
if (array_desc_->primitive == nullptr) {
return;
}

View File

@@ -10,6 +10,7 @@
#include "mlx/allocator.h"
#include "mlx/dtype.h"
#include "mlx/event.h"
#include "mlx/small_vector.h"
namespace mlx::core {
@@ -18,8 +19,8 @@ class Primitive;
using Deleter = std::function<void(allocator::Buffer)>;
using ShapeElem = int32_t;
using Shape = std::vector<ShapeElem>;
using Strides = std::vector<int64_t>;
using Shape = SmallVector<ShapeElem>;
using Strides = SmallVector<int64_t>;
class array {
/* An array is really a node in a graph. It contains a shared ArrayDesc
@@ -199,6 +200,13 @@ class array {
const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs);
/**
* Get a new array that refers to the same data as the input but with a
* non-owning pointer to it. Note the array is detached from the graph and has
* no inputs, siblings or primitive.
*/
static array unsafe_weak_copy(const array& other);
/** A unique identifier for an array. */
std::uintptr_t id() const {
return reinterpret_cast<std::uintptr_t>(array_desc_.get());
@@ -217,6 +225,10 @@ class array {
// Not copyable
Data(const Data& d) = delete;
Data& operator=(const Data& d) = delete;
Data(Data&& o) : buffer(o.buffer), d(o.d) {
o.buffer = allocator::Buffer(nullptr);
o.d = [](allocator::Buffer) {};
}
~Data() {
d(buffer);
}
@@ -243,18 +255,6 @@ class array {
bool col_contiguous : 1;
};
/** Build an array from all the info held by the array description. Including
* the buffer, strides, flags.
*/
explicit array(
allocator::Buffer data,
Shape shape,
Dtype dtype,
Strides strides,
size_t data_size,
Flags flags,
Deleter deleter = allocator::free);
/** The array's primitive. */
Primitive& primitive() const {
return *(array_desc_->primitive);
@@ -344,11 +344,11 @@ class array {
return allocator::allocator().size(buffer());
}
// Return a copy of the shared pointer
// to the array::Data struct
std::shared_ptr<Data> data_shared_ptr() const {
// Return the shared pointer to the array::Data struct
const std::shared_ptr<Data>& data_shared_ptr() const {
return array_desc_->data;
}
// Return a raw pointer to the arrays data
template <typename T>
T* data() {
@@ -361,15 +361,10 @@ class array {
}
enum Status {
// The ouptut of a computation which has not been scheduled.
// The output of a computation which has not been scheduled.
// For example, the status of `x` in `auto x = a + b`.
unscheduled,
// The ouptut of a computation which has been scheduled but `eval_*` has
// not yet been called on the array's primitive. A possible
// status of `x` in `auto x = a + b; eval(x);`
scheduled,
// The array's `eval_*` function has been run, but the computation is not
// necessarily complete. The array will have memory allocated and if it is
// not a tracer then it will be detached from the graph.
@@ -406,6 +401,10 @@ class array {
array_desc_->event = std::move(e);
}
void detach_event() const {
array_desc_->event = Event{};
}
// Mark the array as a tracer array (true) or not.
void set_tracer(bool is_tracer) {
array_desc_->is_tracer = is_tracer;
@@ -431,15 +430,6 @@ class array {
void copy_shared_buffer(const array& other);
void move_shared_buffer(
array other,
const Strides& strides,
Flags flags,
size_t data_size,
size_t offset = 0);
void move_shared_buffer(array other);
void overwrite_descriptor(const array& other) {
array_desc_ = other.array_desc_;
}

View File

@@ -1,6 +1,7 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp

View File

@@ -38,25 +38,20 @@ inline void set_binary_op_output_data(
const array& a,
const array& b,
array& out,
BinaryOpType bopt,
bool donate_with_move = false) {
BinaryOpType bopt) {
bool b_donatable = is_donatable(b, out);
bool a_donatable = is_donatable(a, out);
switch (bopt) {
case BinaryOpType::ScalarScalar:
out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
break;
case BinaryOpType::ScalarVector:
if (b_donatable) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
out.copy_shared_buffer(b);
} else {
out.set_data(
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
allocator::malloc(b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
@@ -64,14 +59,10 @@ inline void set_binary_op_output_data(
break;
case BinaryOpType::VectorScalar:
if (a_donatable) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
out.copy_shared_buffer(a);
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
allocator::malloc(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
@@ -79,20 +70,12 @@ inline void set_binary_op_output_data(
break;
case BinaryOpType::VectorVector:
if (a_donatable) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
out.copy_shared_buffer(a);
} else if (b_donatable) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
out.copy_shared_buffer(b);
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
allocator::malloc(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
@@ -100,20 +83,12 @@ inline void set_binary_op_output_data(
break;
case BinaryOpType::General:
if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
out.copy_shared_buffer(a);
} else if (
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
}
break;
}

View File

@@ -0,0 +1,24 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/utils.h"
namespace mlx::core {
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
Strides strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
}
auto flags = in.flags();
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
} // namespace mlx::core

View File

@@ -1,10 +1,11 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/array.h"
namespace mlx::core {
void encode_wait(Event e);
void encode_signal(Event e);
void broadcast(const array& in, array& out);
} // namespace mlx::core

View File

@@ -0,0 +1,157 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cassert>
#include <functional>
#include <map>
namespace mlx::core {
template <typename T>
class BufferCache {
public:
BufferCache(
size_t page_size,
std::function<size_t(T*)> get_size,
std::function<void(T*)> free)
: page_size_(page_size),
get_size_(std::move(get_size)),
free_(std::move(free)) {}
~BufferCache() {
clear();
}
BufferCache(const BufferCache&) = delete;
BufferCache& operator=(const BufferCache&) = delete;
T* reuse_from_cache(size_t size) {
// Find the closest buffer in pool.
auto it = buffer_pool_.lower_bound(size);
if (it == buffer_pool_.end() ||
it->first >= std::min(2 * size, size + 2 * page_size_)) {
return nullptr;
}
// Collect from the cache.
T* buf = it->second->buf;
pool_size_ -= it->first;
// Remove from record.
remove_from_list(it->second);
buffer_pool_.erase(it);
return buf;
}
void recycle_to_cache(T* buf) {
assert(buf);
// Add to cache.
BufferHolder* bh = new BufferHolder(buf);
add_at_head(bh);
size_t size = get_size_(buf);
pool_size_ += size;
buffer_pool_.emplace(size, bh);
}
int release_cached_buffers(size_t min_bytes_to_free) {
if (min_bytes_to_free >= 0.9 * pool_size_) {
return clear();
} else {
int n_release = 0;
size_t total_bytes_freed = 0;
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
// Release buffer.
size_t size = get_size_(tail_->buf);
total_bytes_freed += size;
free_(tail_->buf);
n_release++;
// Remove from record.
auto its = buffer_pool_.equal_range(size);
auto it = std::find_if(its.first, its.second, [this](const auto& el) {
return el.second == tail_;
});
assert(it != buffer_pool_.end());
buffer_pool_.erase(it);
remove_from_list(tail_);
}
pool_size_ -= total_bytes_freed;
return n_release;
}
}
int clear() {
int n_release = 0;
for (auto& [size, holder] : buffer_pool_) {
free_(holder->buf);
n_release++;
delete holder;
}
buffer_pool_.clear();
pool_size_ = 0;
head_ = nullptr;
tail_ = nullptr;
return n_release;
}
size_t cache_size() const {
return pool_size_;
}
size_t page_size() const {
return page_size_;
}
private:
struct BufferHolder {
public:
explicit BufferHolder(T* buf_) : buf(buf_) {}
BufferHolder* prev{nullptr};
BufferHolder* next{nullptr};
T* buf;
};
void add_at_head(BufferHolder* to_add) {
if (!head_) {
head_ = to_add;
tail_ = to_add;
} else {
head_->prev = to_add;
to_add->next = head_;
head_ = to_add;
}
}
void remove_from_list(BufferHolder* to_remove) {
if (to_remove->prev && to_remove->next) { // if middle
to_remove->prev->next = to_remove->next;
to_remove->next->prev = to_remove->prev;
} else if (to_remove->prev && to_remove == tail_) { // if tail
tail_ = to_remove->prev;
tail_->next = nullptr;
} else if (to_remove == head_ && to_remove->next) { // if head
head_ = to_remove->next;
head_->prev = nullptr;
} else if (to_remove == head_ && to_remove == tail_) { // if only element
head_ = nullptr;
tail_ = nullptr;
}
delete to_remove;
}
std::multimap<size_t, BufferHolder*> buffer_pool_;
BufferHolder* head_{nullptr};
BufferHolder* tail_{nullptr};
size_t pool_size_{0};
const size_t page_size_;
std::function<size_t(T*)> get_size_;
std::function<void(T*)> free_;
};
} // namespace mlx::core

View File

@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
@@ -39,24 +40,7 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
// rely on data_size anyway.
size_t data_size = out.size();
return move_or_copy(in, out, strides_, flags, data_size, offset_);
}
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
Strides strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
}
auto flags = in.flags();
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
move_or_copy(in, out, strides, flags, in.data_size());
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
}
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
@@ -69,7 +53,7 @@ void BroadcastAxes::eval(const std::vector<array>& inputs, array& out) {
void Copy::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
move_or_copy(inputs[0], out);
out.copy_shared_buffer(inputs[0]);
}
void CustomTransforms::eval(
@@ -78,7 +62,7 @@ void CustomTransforms::eval(
assert(inputs.size() > outputs.size());
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
i++, j++) {
move_or_copy(inputs[j], outputs[i]);
outputs[i].copy_shared_buffer(inputs[j]);
}
}
@@ -87,7 +71,7 @@ void Depends::eval(
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0; i < outputs.size(); i++) {
move_or_copy(inputs[i], outputs[i]);
outputs[i].copy_shared_buffer(inputs[i]);
}
}
@@ -98,12 +82,12 @@ void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
for (auto ax : axes_) {
strides.insert(strides.begin() + ax, 1);
}
move_or_copy(in, out, strides, in.flags(), in.data_size());
out.copy_shared_buffer(in, strides, in.flags(), in.data_size());
}
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
double numel = 1;
for (auto ax : axes_) {
@@ -210,7 +194,7 @@ void shared_buffer_reshape(
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
}
move_or_copy(in, out, out_strides, flags, in.data_size());
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}
void Split::eval(
@@ -276,12 +260,12 @@ void Squeeze::eval(const std::vector<array>& inputs, array& out) {
strides.push_back(in.strides(i));
}
}
move_or_copy(in, out, strides, in.flags(), in.data_size());
out.copy_shared_buffer(in, strides, in.flags(), in.data_size());
}
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
move_or_copy(inputs[0], out);
out.copy_shared_buffer(inputs[0]);
}
void Transpose::eval(const std::vector<array>& inputs, array& out) {
@@ -315,7 +299,7 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
b_stride *= out.shape(ri);
}
}
move_or_copy(in, out, out_strides, flags, in.data_size());
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}
} // namespace mlx::core

View File

@@ -1,8 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/common/compiled.h"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"
#include "mlx/backend/common/utils.h"
#include "mlx/utils.h"
namespace mlx::core {
@@ -15,6 +14,8 @@ void print_constant(std::ostream& os, const array& x) {
return print_float_constant<float16_t>(os, x);
case bfloat16:
return print_float_constant<bfloat16_t>(os, x);
case float64:
return print_float_constant<double>(os, x);
case complex64:
return print_complex_constant<complex64_t>(os, x);
case int8:
@@ -51,6 +52,8 @@ std::string get_type_string(Dtype d) {
return "float16_t";
case bfloat16:
return "bfloat16_t";
case float64:
return "double";
case complex64:
return "complex64_t";
case bool_:
@@ -79,55 +82,6 @@ std::string get_type_string(Dtype d) {
}
}
std::string build_lib_name(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids) {
NodeNamer namer;
std::ostringstream os;
std::ostringstream constant_hasher;
// Fill the input names. This is not really necessary, I just like having A,
// B, C, ... as the inputs.
for (auto& x : inputs) {
namer.get_name(x);
}
// The primitives describing the tape. For unary and binary primitives this
// must be enough to describe the full computation.
for (auto& a : tape) {
// name and type of output
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
// computation performed
a.primitive().print(os);
// name of inputs to the function
for (auto& inp : a.inputs()) {
os << namer.get_name(inp);
}
}
os << "_";
for (auto& x : inputs) {
if (constant_ids.find(x.id()) != constant_ids.end()) {
os << "C";
print_constant(constant_hasher, x);
} else {
os << (is_scalar(x) ? "S" : "V");
}
}
os << "_";
for (auto& x : inputs) {
if (constant_ids.find(x.id()) != constant_ids.end()) {
continue;
}
os << kindof(x.dtype()) << x.itemsize();
}
os << "_" << std::hash<std::string>{}(constant_hasher.str());
return os.str();
}
bool compiled_check_contiguity(
const std::vector<array>& inputs,
const Shape& shape) {
@@ -159,10 +113,8 @@ bool compiled_check_contiguity(
void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::vector<array>& inputs_,
const std::unordered_set<uintptr_t>& constant_ids_,
bool contiguous,
bool move_buffers /* = false */) {
const std::function<bool(size_t)>& is_constant,
bool contiguous) {
if (contiguous) {
int o = 0;
Strides strides;
@@ -176,13 +128,8 @@ void compiled_allocate_outputs(
// - Donatable
// - Not a constant
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
if (move_buffers) {
outputs[o++].move_shared_buffer(in);
} else {
outputs[o++].copy_shared_buffer(in);
}
in.is_donatable() && is_constant(i)) {
outputs[o++].copy_shared_buffer(in);
}
// Get representative input flags to properly set non-donated outputs
if (strides.empty() && in.size() == outputs[0].size()) {
@@ -193,7 +140,7 @@ void compiled_allocate_outputs(
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
allocator::malloc(data_size * outputs[o].itemsize()),
data_size,
strides,
flags);
@@ -209,21 +156,86 @@ void compiled_allocate_outputs(
// - Not a constant
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
if (move_buffers) {
outputs[o].move_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size());
} else {
outputs[o].copy_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size());
}
is_constant(i)) {
outputs[o].copy_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size());
o++;
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
}
}
}
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
const std::vector<array>& inputs,
const array& out,
const std::function<bool(size_t)>& is_constant) {
const Shape& shape = out.shape();
bool contiguous = compiled_check_contiguity(inputs, shape);
if (contiguous) {
return {true, shape, {}};
}
std::vector<Strides> strides_vec{out.strides()};
for (size_t i = 0; i < inputs.size(); ++i) {
// Skip constants.
if (is_constant(i)) {
continue;
}
// Skip scalar inputs.
const auto& x = inputs[i];
if (is_scalar(x)) {
continue;
}
// Broadcast the inputs to the output shape.
Strides xstrides;
size_t j = 0;
for (; j < shape.size() - x.ndim(); ++j) {
if (shape[j] == 1) {
xstrides.push_back(out.strides()[j]);
} else {
xstrides.push_back(0);
}
}
for (size_t i = 0; i < x.ndim(); ++i, ++j) {
if (x.shape(i) == 1) {
if (shape[j] == 1) {
xstrides.push_back(out.strides()[j]);
} else {
xstrides.push_back(0);
}
} else {
xstrides.push_back(x.strides()[i]);
}
}
strides_vec.push_back(std::move(xstrides));
}
auto tup = collapse_contiguous_dims(shape, strides_vec, INT32_MAX);
return {false, std::move(std::get<0>(tup)), std::move(std::get<1>(tup))};
}
bool compiled_use_large_index(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
bool contiguous) {
if (contiguous) {
size_t max_size = 0;
for (const auto& in : inputs) {
max_size = std::max(max_size, in.data_size());
}
return max_size > UINT32_MAX;
} else {
size_t max_size = 0;
for (const auto& o : outputs) {
max_size = std::max(max_size, o.size());
}
return max_size > UINT32_MAX;
}
}
} // namespace mlx::core

View File

@@ -1,9 +1,8 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <functional>
#include <iomanip>
#include <sstream>
#include <unordered_set>
#include "mlx/array.h"
#include "mlx/primitives.h"
@@ -14,19 +13,17 @@ inline bool is_static_cast(const Primitive& p) {
return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
}
std::string build_lib_name(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids);
std::string get_type_string(Dtype d);
template <typename T>
void print_float_constant(std::ostream& os, const array& x) {
auto old_precision = os.precision();
os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
<< x.item<T>() << std::setprecision(old_precision);
if constexpr (std::is_same_v<T, double>) {
os << std::setprecision(std::numeric_limits<double>::digits10 + 1);
} else {
os << std::setprecision(std::numeric_limits<float>::digits10 + 1);
}
os << x.item<T>() << std::setprecision(old_precision);
}
template <typename T>
@@ -60,9 +57,19 @@ bool compiled_check_contiguity(
void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::vector<array>& inputs_,
const std::unordered_set<uintptr_t>& constant_ids_,
bool contiguous,
bool move_buffers = false);
const std::function<bool(size_t)>& is_constant,
bool contiguous);
// Collapse contiguous dims ignoring scalars and constants.
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
const std::vector<array>& inputs,
const array& out,
const std::function<bool(size_t)>& is_constant);
// Return whether the kernel should use large index.
bool compiled_use_large_index(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
bool contiguous);
} // namespace mlx::core

View File

@@ -2,7 +2,7 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
@@ -22,4 +22,25 @@ enum class CopyType {
GeneralGeneral
};
inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
if (ctype == CopyType::Vector) {
// If the input is donateable, we are doing a vector copy and the types
// have the same size, then the input buffer can hold the output.
if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
return true;
} else {
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
return false;
}
} else {
out.set_data(allocator::malloc(out.nbytes()));
return false;
}
}
} // namespace mlx::core

View File

@@ -99,7 +99,11 @@ inline std::pair<int, int> decompose_hadamard(int n) {
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
}
}
if (n > (1 << 26)) {
throw std::invalid_argument(
"[hadamard] Only supports n = m*2^k where k <= 26");
}
return {n, m};
}
} // namespace mlx::core
} // namespace mlx::core

View File

@@ -3,7 +3,8 @@
#include <algorithm>
#include <utility>
#include "mlx/backend/common/load.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
namespace {
@@ -26,26 +27,31 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
namespace mlx::core {
void load(
array& out,
size_t offset,
const std::shared_ptr<io::Reader>& reader,
bool swap_endianness_) {
reader->read(out.data<char>(), out.nbytes(), offset);
if (swap_endianness_) {
switch (out.itemsize()) {
case 2:
swap_endianness<2>(out.data<uint8_t>(), out.data_size());
break;
case 4:
swap_endianness<4>(out.data<uint8_t>(), out.data_size());
break;
case 8:
swap_endianness<8>(out.data<uint8_t>(), out.data_size());
break;
void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes()));
auto read_task = [out_ptr = out.data<char>(),
size = out.size(),
itemsize = out.itemsize(),
offset = offset_,
reader = reader_,
swap_endianness_ = swap_endianness_]() mutable {
reader->read(out_ptr, size * itemsize, offset);
if (swap_endianness_) {
switch (itemsize) {
case 2:
swap_endianness<2>(reinterpret_cast<uint8_t*>(out_ptr), size);
break;
case 4:
swap_endianness<4>(reinterpret_cast<uint8_t*>(out_ptr), size);
break;
case 8:
swap_endianness<8>(reinterpret_cast<uint8_t*>(out_ptr), size);
break;
}
}
}
};
auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
scheduler::enqueue(stream(), [fut = std::move(fut)]() { fut.wait(); });
}
} // namespace mlx::core

View File

@@ -1,14 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/array.h"
#include "mlx/io/load.h"
namespace mlx::core {
void load(
array& out,
size_t offset,
const std::shared_ptr<io::Reader>& reader,
bool swap_endianess);
} // namespace mlx::core

View File

@@ -0,0 +1,67 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/common/utils.h"
#include "mlx/utils.h"
#include <sstream>
namespace mlx::core {
inline std::tuple<Shape, Strides, Strides> collapse_batches(
const array& a,
const array& b) {
if (a.ndim() == 2) {
return {{1}, {0}, {0}};
}
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
auto [batch_shape, batch_strides] =
collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
auto a_batch_strides = batch_strides[0];
auto b_batch_strides = batch_strides[1];
if (batch_shape.empty()) {
batch_shape.push_back(1);
a_batch_strides.push_back(0);
b_batch_strides.push_back(0);
}
return std::make_tuple(batch_shape, a_batch_strides, b_batch_strides);
}
inline std::tuple<Shape, Strides, Strides, Strides>
collapse_batches(const array& a, const array& b, const array& c) {
if (a.ndim() == 2) {
return {{1}, {0}, {0}, {0}};
}
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
Strides C_bstride{c.strides().begin(), c.strides().end() - 2};
auto [batch_shape, batch_strides] = collapse_contiguous_dims(
A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
auto A_batch_stride = batch_strides[0];
auto B_batch_stride = batch_strides[1];
auto C_batch_stride = batch_strides[2];
if (batch_shape.empty()) {
batch_shape.push_back(1);
A_batch_stride.push_back(0);
B_batch_stride.push_back(0);
C_batch_stride.push_back(0);
}
return std::make_tuple(
batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
}
} // namespace mlx::core

View File

@@ -5,11 +5,9 @@
namespace mlx::core {
std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
Shape shape,
Strides strides,
const std::vector<int>& axes) {
auto shape = x.shape();
auto strides = x.strides();
for (int i = axes.size() - 1; i >= 0; i--) {
int a = axes[i];
shape.erase(shape.begin() + a);
@@ -19,6 +17,15 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
return std::make_pair(shape, strides);
}
std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes) {
auto shape = x.shape();
auto strides = x.strides();
return shapes_without_reduction_axes(
std::move(shape), std::move(strides), axes);
}
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&

View File

@@ -51,5 +51,9 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes);
std::pair<Shape, Strides> shapes_without_reduction_axes(
Shape shape,
Strides strides,
const std::vector<int>& axes);
} // namespace mlx::core

View File

@@ -14,6 +14,10 @@ std::tuple<int64_t, Strides> prepare_slice(
data_offset += start_indices[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides[i];
}
// Normalize the offset
if (data_offset < 0) {
data_offset += in.data_size();
}
return std::make_tuple(data_offset, inp_strides);
}
@@ -32,7 +36,7 @@ void shared_buffer_slice(
flags.col_contiguous = is_col_contiguous;
flags.contiguous = (no_bsx_size == data_size);
move_or_copy(in, out, out_strides, flags, data_size, data_offset);
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
}
void slice(
@@ -54,9 +58,10 @@ void slice(
data_end += end_idx * in.strides()[i];
}
}
// data_end can be -1
size_t data_size =
data_end < 0 ? (data_offset - data_end) : (data_end - data_offset);
if (data_end < 0) {
data_end += in.data_size();
}
size_t data_size = (data_end - data_offset);
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
}

View File

@@ -36,15 +36,10 @@ inline void set_ternary_op_output_data(
const array& b,
const array& c,
array& out,
TernaryOpType topt,
bool donate_with_move = false) {
auto maybe_donate = [&out, donate_with_move](const array& x) {
TernaryOpType topt) {
auto maybe_donate = [&out](const array& x) {
if (is_donatable(x, out)) {
if (donate_with_move) {
out.move_shared_buffer(x);
} else {
out.copy_shared_buffer(x);
}
out.copy_shared_buffer(x);
return true;
}
return false;
@@ -53,12 +48,12 @@ inline void set_ternary_op_output_data(
switch (topt) {
case TernaryOpType::ScalarScalarScalar:
out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
break;
case TernaryOpType::VectorVectorVector:
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
out.set_data(
allocator::malloc_or_wait(out.itemsize() * b.data_size()),
allocator::malloc(out.itemsize() * b.data_size()),
b.data_size(),
b.strides(),
b.flags());
@@ -69,7 +64,7 @@ inline void set_ternary_op_output_data(
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
(b.flags().row_contiguous && maybe_donate(b)) ||
(c.flags().row_contiguous && maybe_donate(c)))) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
}
break;
}

View File

@@ -0,0 +1,26 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/allocator.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
inline void set_unary_output_data(const array& in, array& out) {
if (in.flags().contiguous) {
if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
}
} // namespace mlx::core

View File

@@ -1,29 +1,20 @@
// Copyright © 2023-2024 Apple Inc.
#include <dlfcn.h>
#include "mlx/backend/common/utils.h"
namespace mlx::core {
void move_or_copy(const array& in, array& out) {
if (in.is_donatable()) {
out.move_shared_buffer(in);
} else {
out.copy_shared_buffer(in);
}
}
void move_or_copy(
const array& in,
array& out,
const Strides& strides,
array::Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
if (in.is_donatable()) {
out.move_shared_buffer(in, strides, flags, data_size, offset);
} else {
out.copy_shared_buffer(in, strides, flags, data_size, offset);
}
std::filesystem::path current_binary_dir() {
static std::filesystem::path binary_dir = []() {
Dl_info info;
if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
throw std::runtime_error("Unable to get current binary dir.");
}
return std::filesystem::path(info.dli_fname).parent_path();
}();
return binary_dir;
}
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
@@ -123,4 +114,118 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
}
Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
int pows[3] = {0, 0, 0};
int sum = 0;
while (true) {
int presum = sum;
// Check all the pows
if (dim0 >= (1 << (pows[0] + 1))) {
pows[0]++;
sum++;
}
if (sum == 10) {
break;
}
if (dim1 >= (1 << (pows[1] + 1))) {
pows[1]++;
sum++;
}
if (sum == 10) {
break;
}
if (dim2 >= (1 << (pows[2] + 1))) {
pows[2]++;
sum++;
}
if (sum == presum || sum == pow2) {
break;
}
}
return std::make_tuple(1ul << pows[0], 1ul << pows[1], 1ul << pows[2]);
}
Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides) {
// Dims with strides of 0 are ignored as they
// correspond to broadcasted dimensions
size_t grid_x = 1;
size_t grid_y = 1;
for (int i = 0; i < shape.size(); ++i) {
if (strides[i] == 0) {
continue;
}
if (grid_x * shape[i] < UINT32_MAX) {
grid_x *= shape[i];
} else {
grid_y *= shape[i];
}
}
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
throw std::runtime_error("Unable to safely factor shape.");
}
if (grid_y > grid_x) {
std::swap(grid_x, grid_y);
}
return std::make_tuple(
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}
Dims get_2d_grid_dims_common(
const Shape& shape,
const Strides& strides,
size_t divisor) {
// Compute the 2d grid dimensions such that the total size of the grid is
// divided by divisor.
size_t grid_x = 1;
size_t grid_y = 1;
for (int i = 0; i < shape.size(); ++i) {
if (strides[i] == 0) {
continue;
}
// No need to add this shape we can just remove it from the divisor.
if (divisor % shape[i] == 0) {
divisor /= shape[i];
continue;
}
if (grid_x * shape[i] < UINT32_MAX) {
grid_x *= shape[i];
} else {
grid_y *= shape[i];
}
if (divisor > 1) {
if (grid_x % divisor == 0) {
grid_x /= divisor;
divisor = 1;
} else if (grid_y % divisor == 0) {
grid_y /= divisor;
divisor = 1;
}
}
}
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
throw std::runtime_error("Unable to safely factor shape.");
}
if (grid_y > grid_x) {
std::swap(grid_x, grid_y);
}
if (divisor > 1) {
grid_x = ((grid_x + divisor - 1) / divisor) * divisor;
}
return std::make_tuple(
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
auto [bx, by, bz] = get_block_dims_common(dim0, dim1, dim2);
auto gx = (dim0 + bx - 1) / bx;
auto gy = (dim1 + by - 1) / by;
auto gz = (dim2 + bz - 1) / bz;
return std::make_pair(
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
}
} // namespace mlx::core

View File

@@ -2,12 +2,17 @@
#pragma once
#include <filesystem>
#include <tuple>
#include <vector>
#include "mlx/array.h"
namespace mlx::core {
// Return the directory that contains current shared library.
std::filesystem::path current_binary_dir();
inline int64_t
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
int64_t loc = 0;
@@ -70,6 +75,31 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
const array& a,
int64_t size_cap = std::numeric_limits<int32_t>::max());
// Compute the thread block dimensions which fit the given
// input dimensions.
// - The thread block dimensions will be powers of two
// - The thread block size will be less than 2^pow2
using Dims = std::tuple<uint32_t, uint32_t, uint32_t>;
Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 = 10);
// Computes a 2D grid where each element is < UINT_MAX
// Assumes:
// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
// - shape and strides correspond to a contiguous (no holes) but
// possibly broadcasted array
Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides);
// Same as above but we do an implicit division with divisor.
// Basically, equivalent to factorizing
// Prod(s \forall s in shape if strides[s] > 0) / divisor.
Dims get_2d_grid_dims_common(
const Shape& shape,
const Strides& strides,
size_t divisor);
// Get both the block and a grid of blocks that covers dim0, dim1 and dim2.
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2);
struct ContiguousIterator {
inline void step() {
int dims = shape_.size();
@@ -159,19 +189,17 @@ inline bool is_donatable(const array& in, const array& out) {
in.buffer_size() <= out.nbytes() + donation_extra;
}
void move_or_copy(const array& in, array& out);
void move_or_copy(
const array& in,
array& out,
const Strides& strides,
array::Flags flags,
size_t data_size,
size_t offset = 0);
std::pair<bool, Strides> prepare_reshape(const array& in, const array& out);
void shared_buffer_reshape(
const array& in,
const Strides& out_strides,
array& out);
template <typename T>
inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
vec.erase(std::next(vec.begin(), index));
return vec;
}
} // namespace mlx::core

View File

@@ -40,11 +40,15 @@ add_dependencies(mlx cpu_compiled_preamble)
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eig.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
@@ -56,6 +60,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
@@ -65,13 +70,14 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
if(MLX_BUILD_ACCELERATE)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_fp16.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_bf16.cpp)
endif()
if(IOS)

View File

@@ -2,76 +2,27 @@
#pragma once
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/cpu/encoder.h"
namespace mlx::core {
namespace {
template <typename T>
void arange(T start, T next, array& out, size_t size) {
void arange(T start, T next, array& out, size_t size, Stream stream) {
auto ptr = out.data<T>();
auto step_size = next - start;
for (int i = 0; i < size; ++i) {
ptr[i] = start;
start += step_size;
}
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(out);
encoder.dispatch([ptr, start, step_size, size]() mutable {
for (int i = 0; i < size; ++i) {
ptr[i] = start;
start += step_size;
}
});
}
} // namespace
void arange(
const std::vector<array>& inputs,
array& out,
double start,
double step) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
switch (out.dtype()) {
case bool_:
throw std::runtime_error("Bool type unsupported for arange.");
break;
case uint8:
arange<uint8_t>(start, start + step, out, out.size());
break;
case uint16:
arange<uint16_t>(start, start + step, out, out.size());
break;
case uint32:
arange<uint32_t>(start, start + step, out, out.size());
break;
case uint64:
arange<uint64_t>(start, start + step, out, out.size());
break;
case int8:
arange<int8_t>(start, start + step, out, out.size());
break;
case int16:
arange<int16_t>(start, start + step, out, out.size());
break;
case int32:
arange<int32_t>(start, start + step, out, out.size());
break;
case int64:
arange<int64_t>(start, start + step, out, out.size());
break;
case float16:
arange<float16_t>(start, start + step, out, out.size());
break;
case float32:
arange<float>(start, start + step, out, out.size());
break;
case float64:
arange<double>(start, start + step, out, out.size());
break;
case bfloat16:
arange<bfloat16_t>(start, start + step, out, out.size());
break;
case complex64:
arange<complex64_t>(start, start + step, out, out.size());
break;
}
}
} // namespace mlx::core

View File

@@ -3,6 +3,7 @@
#include <cassert>
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
namespace mlx::core {
@@ -13,19 +14,20 @@ template <typename InT, typename OpT>
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
auto axis_size = in.shape()[axis];
auto axis_stride = in.strides()[axis];
Strides strides = in.strides();
Shape shape = in.shape();
strides.erase(strides.begin() + axis);
shape.erase(shape.begin() + axis);
Strides strides = remove_index(in.strides(), axis);
Shape shape = remove_index(in.shape(), axis);
auto in_ptr = in.data<InT>();
auto out_ptr = out.data<uint32_t>();
for (uint32_t i = 0; i < out.size(); ++i) {
auto loc = elem_to_loc(i, shape, strides);
auto in_ptr = in.data<InT>() + loc;
auto local_in_ptr = in_ptr + loc;
uint32_t ind_v = 0;
InT v = (*in_ptr);
for (uint32_t j = 0; j < axis_size; ++j, in_ptr += axis_stride) {
op(j, (*in_ptr), &ind_v, &v);
InT v = (*local_in_ptr);
for (uint32_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
op(j, (*local_in_ptr), &ind_v, &v);
}
out.data<uint32_t>()[i] = ind_v;
out_ptr[i] = ind_v;
}
}
@@ -64,52 +66,59 @@ void arg_reduce_dispatch(
void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
switch (in.dtype()) {
case bool_:
arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
break;
case uint8:
arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_);
break;
case uint16:
arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_);
break;
case uint32:
arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_);
break;
case uint64:
arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_);
break;
case int8:
arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_);
break;
case int16:
arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_);
break;
case int32:
arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_);
break;
case int64:
arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_);
break;
case float16:
arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_);
break;
case float32:
arg_reduce_dispatch<float>(in, out, reduce_type_, axis_);
break;
case bfloat16:
arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
break;
case float64:
arg_reduce_dispatch<double>(in, out, reduce_type_, axis_);
break;
case complex64:
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
break;
}
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
reduce_type_ = reduce_type_,
axis_ = axis_]() mutable {
switch (in.dtype()) {
case bool_:
arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
break;
case uint8:
arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_);
break;
case uint16:
arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_);
break;
case uint32:
arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_);
break;
case uint64:
arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_);
break;
case int8:
arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_);
break;
case int16:
arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_);
break;
case int32:
arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_);
break;
case int64:
arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_);
break;
case float16:
arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_);
break;
case float32:
arg_reduce_dispatch<float>(in, out, reduce_type_, axis_);
break;
case bfloat16:
arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
break;
case float64:
arg_reduce_dispatch<double>(in, out, reduce_type_, axis_);
break;
case complex64:
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
break;
}
});
}
} // namespace mlx::core

View File

@@ -0,0 +1,11 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/available.h"
namespace mlx::core::cpu {
bool is_available() {
return true;
}
} // namespace mlx::core::cpu

View File

@@ -0,0 +1,9 @@
// Copyright © 2025 Apple Inc.
#pragma once
namespace mlx::core::cpu {
bool is_available();
} // namespace mlx::core::cpu

View File

@@ -8,6 +8,7 @@
#include "mlx/backend/cpu/binary.h"
#include "mlx/backend/cpu/binary_ops.h"
#include "mlx/backend/cpu/binary_two.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -16,51 +17,221 @@ namespace mlx::core {
namespace {
template <typename Op>
void comparison_op(const array& a, const array& b, array& out, Op op) {
switch (a.dtype()) {
case bool_:
binary_op<bool, bool>(a, b, out, op);
break;
case uint8:
binary_op<uint8_t, bool>(a, b, out, op);
break;
case uint16:
binary_op<uint16_t, bool>(a, b, out, op);
break;
case uint32:
binary_op<uint32_t, bool>(a, b, out, op);
break;
case uint64:
binary_op<uint64_t, bool>(a, b, out, op);
break;
case int8:
binary_op<int8_t, bool>(a, b, out, op);
break;
case int16:
binary_op<int16_t, bool>(a, b, out, op);
break;
case int32:
binary_op<int32_t, bool>(a, b, out, op);
break;
case int64:
binary_op<int64_t, bool>(a, b, out, op);
break;
case float16:
binary_op<float16_t, bool>(a, b, out, op);
break;
case float32:
binary_op<float, bool>(a, b, out, op);
break;
case float64:
binary_op<double, bool>(a, b, out, op);
break;
case bfloat16:
binary_op<bfloat16_t, bool>(a, b, out, op);
break;
case complex64:
binary_op<complex64_t, bool>(a, b, out, op);
break;
}
void binary(const array& a, const array& b, array& out, Op op, Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void comparison_op(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (a.dtype()) {
case bool_:
binary_op<bool, bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, bool, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, bool, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, bool, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, bool, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, bool, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, bool, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, bool, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void binary_float(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error(
"[binary_float] Only supports floating point types.");
}
});
}
template <typename Op>
void binary_int(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error("[binary_int] Type not supported");
break;
}
});
}
} // namespace
@@ -69,7 +240,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Add());
binary(a, b, out, detail::Add(), stream());
}
void DivMod::eval_cpu(
@@ -78,70 +249,89 @@ void DivMod::eval_cpu(
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto integral_op = [](auto x, auto y) {
return std::make_pair(x / y, x % y);
};
auto float_op = [](auto x, auto y) {
return std::make_pair(std::trunc(x / y), std::fmod(x, y));
};
switch (outputs[0].dtype()) {
case bool_:
binary_op<bool>(a, b, outputs, integral_op);
case uint8:
binary_op<uint8_t>(a, b, outputs, integral_op);
break;
case uint16:
binary_op<uint16_t>(a, b, outputs, integral_op);
break;
case uint32:
binary_op<uint32_t>(a, b, outputs, integral_op);
break;
case uint64:
binary_op<uint64_t>(a, b, outputs, integral_op);
break;
case int8:
binary_op<int8_t>(a, b, outputs, integral_op);
break;
case int16:
binary_op<int16_t>(a, b, outputs, integral_op);
break;
case int32:
binary_op<int32_t>(a, b, outputs, integral_op);
break;
case int64:
binary_op<int64_t>(a, b, outputs, integral_op);
break;
case float16:
binary_op<float16_t>(a, b, outputs, float_op);
break;
case float32:
binary_op<float>(a, b, outputs, float_op);
break;
case float64:
binary_op<double>(a, b, outputs, float_op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, float_op);
break;
case complex64:
// Should never get here
throw std::runtime_error("[DivMod] Complex type not supported");
break;
}
auto bopt = get_binary_op_type(a, b);
auto& out_a = outputs[0];
auto& out_b = outputs[1];
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out_a);
encoder.set_output_array(out_b);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out_a = array::unsafe_weak_copy(out_a),
out_b = array::unsafe_weak_copy(out_b),
bopt]() mutable {
auto integral_op = [](auto x, auto y) {
return std::make_pair(x / y, x % y);
};
auto float_op = [](auto x, auto y) {
return std::make_pair(std::trunc(x / y), std::fmod(x, y));
};
switch (out_a.dtype()) {
case bool_:
binary_op<bool>(a, b, out_a, out_b, integral_op, bopt);
case uint8:
binary_op<uint8_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case uint16:
binary_op<uint16_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case uint32:
binary_op<uint32_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case uint64:
binary_op<uint64_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case int8:
binary_op<int8_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case int16:
binary_op<int16_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case int32:
binary_op<int32_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case int64:
binary_op<int64_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case float16:
binary_op<float16_t>(a, b, out_a, out_b, float_op, bopt);
break;
case float32:
binary_op<float>(a, b, out_a, out_b, float_op, bopt);
break;
case float64:
binary_op<double>(a, b, out_a, out_b, float_op, bopt);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out_a, out_b, float_op, bopt);
break;
case complex64:
// Should never get here
throw std::runtime_error("[DivMod] Complex type not supported");
break;
}
});
}
void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Divide());
binary(a, b, out, detail::Divide(), stream());
}
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Remainder());
binary(a, b, out, detail::Remainder(), stream());
}
void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -149,181 +339,143 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& a = inputs[0];
auto& b = inputs[1];
if (equal_nan_) {
switch (a.dtype()) {
case float16:
binary_op<float16_t, bool>(a, b, out, detail::NaNEqual());
break;
case float32:
binary_op<float, bool>(a, b, out, detail::NaNEqual());
break;
case float64:
binary_op<double, bool>(a, b, out, detail::NaNEqual());
break;
case bfloat16:
binary_op<bfloat16_t, bool>(a, b, out, detail::NaNEqual());
break;
case complex64:
binary_op<complex64_t, bool>(a, b, out, detail::NaNEqual());
break;
default:
throw std::runtime_error(
"[NanEqual::eval_cpu] Only for floating point types.");
}
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (a.dtype()) {
case float16:
binary_op<float16_t, bool, detail::NaNEqual>(a, b, out, bopt);
break;
case float32:
binary_op<float, bool, detail::NaNEqual>(a, b, out, bopt);
break;
case float64:
binary_op<double, bool, detail::NaNEqual>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, bool, detail::NaNEqual>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, bool, detail::NaNEqual>(a, b, out, bopt);
break;
default:
throw std::runtime_error(
"[NanEqual::eval_cpu] Only for floating point types.");
}
});
} else {
comparison_op(a, b, out, detail::Equal());
comparison_op(a, b, out, detail::Equal(), stream());
}
}
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Greater());
comparison_op(inputs[0], inputs[1], out, detail::Greater(), stream());
}
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual());
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
}
void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Less());
comparison_op(inputs[0], inputs[1], out, detail::Less(), stream());
}
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::LessEqual());
comparison_op(inputs[0], inputs[1], out, detail::LessEqual(), stream());
}
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
switch (out.dtype()) {
case float16:
binary_op<float16_t>(a, b, out, detail::LogAddExp());
break;
case float32:
binary_op<float>(a, b, out, detail::LogAddExp());
break;
case float64:
binary_op<double>(a, b, out, detail::LogAddExp());
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
break;
default:
throw std::runtime_error(
"[LogAddExp::eval_cpu] Only supports non-complex floating point types.");
}
binary_float(a, b, out, detail::LogAddExp(), stream());
}
void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalAnd());
binary(in1, in2, out, detail::LogicalAnd(), stream());
}
void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalOr requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalOr());
binary(in1, in2, out, detail::LogicalOr(), stream());
}
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Maximum());
binary(a, b, out, detail::Maximum(), stream());
}
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Minimum());
binary(a, b, out, detail::Minimum(), stream());
}
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Multiply());
binary(a, b, out, detail::Multiply(), stream());
}
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::NotEqual());
comparison_op(inputs[0], inputs[1], out, detail::NotEqual(), stream());
}
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Power());
binary(a, b, out, detail::Power(), stream());
}
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Subtract());
binary(a, b, out, detail::Subtract(), stream());
}
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto dispatch_type = [&a, &b, &out](auto op) {
switch (out.dtype()) {
case bool_:
binary_op<bool>(a, b, out, op);
case uint8:
binary_op<uint8_t>(a, b, out, op);
break;
case uint16:
binary_op<uint16_t>(a, b, out, op);
break;
case uint32:
binary_op<uint32_t>(a, b, out, op);
break;
case uint64:
binary_op<uint64_t>(a, b, out, op);
break;
case int8:
binary_op<int8_t>(a, b, out, op);
break;
case int16:
binary_op<int16_t>(a, b, out, op);
break;
case int32:
binary_op<int32_t>(a, b, out, op);
break;
case int64:
binary_op<int64_t>(a, b, out, op);
break;
default:
throw std::runtime_error(
"[BitwiseBinary::eval_cpu] Type not supported");
break;
}
};
switch (op_) {
case BitwiseBinary::And:
dispatch_type(detail::BitwiseAnd());
binary_int(a, b, out, detail::BitwiseAnd(), stream());
break;
case BitwiseBinary::Or:
dispatch_type(detail::BitwiseOr());
binary_int(a, b, out, detail::BitwiseOr(), stream());
break;
case BitwiseBinary::Xor:
dispatch_type(detail::BitwiseXor());
binary_int(a, b, out, detail::BitwiseXor(), stream());
break;
case BitwiseBinary::LeftShift:
dispatch_type(detail::LeftShift());
binary_int(a, b, out, detail::LeftShift(), stream());
break;
case BitwiseBinary::RightShift:
dispatch_type(detail::RightShift());
binary_int(a, b, out, detail::RightShift(), stream());
break;
}
}
@@ -332,23 +484,7 @@ void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
const auto& a = inputs[0];
const auto& b = inputs[1];
switch (out.dtype()) {
case float16:
binary_op<float16_t>(a, b, out, detail::ArcTan2());
break;
case float32:
binary_op<float>(a, b, out, detail::ArcTan2());
break;
case float64:
binary_op<double>(a, b, out, detail::ArcTan2());
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
break;
default:
throw std::runtime_error(
"[ArcTan2::eval_cpu] Only supports non-complex floating point types.");
}
binary_float(a, b, out, detail::ArcTan2(), stream());
}
} // namespace mlx::core

View File

@@ -3,7 +3,6 @@
#pragma once
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/utils.h"
@@ -14,22 +13,18 @@ namespace mlx::core {
template <typename Op>
struct VectorScalar {
Op op;
VectorScalar(Op op_) : op(op_) {}
template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst, int size) {
T scalar = *b;
constexpr int N = simd::max_size<T>;
while (size >= N) {
simd::store(dst, op(simd::load<T, N>(a), simd::Simd<T, N>(scalar)));
simd::store(dst, Op{}(simd::load<T, N>(a), simd::Simd<T, N>(scalar)));
dst += N;
a += N;
size -= N;
}
while (size-- > 0) {
*dst = op(*a, scalar);
*dst = Op{}(*a, scalar);
dst++;
a++;
}
@@ -38,22 +33,18 @@ struct VectorScalar {
template <typename Op>
struct ScalarVector {
Op op;
ScalarVector(Op op_) : op(op_) {}
template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst, int size) {
T scalar = *a;
constexpr int N = simd::max_size<T>;
while (size >= N) {
simd::store(dst, op(simd::Simd<T, N>(scalar), simd::load<T, N>(b)));
simd::store(dst, Op{}(simd::Simd<T, N>(scalar), simd::load<T, N>(b)));
dst += N;
b += N;
size -= N;
}
while (size-- > 0) {
*dst = op(scalar, *b);
*dst = Op{}(scalar, *b);
dst++;
b++;
}
@@ -62,22 +53,18 @@ struct ScalarVector {
template <typename Op>
struct VectorVector {
Op op;
VectorVector(Op op_) : op(op_) {}
template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst, int size) {
constexpr int N = simd::max_size<T>;
while (size >= N) {
simd::store(dst, op(simd::load<T, N>(a), simd::load<T, N>(b)));
simd::store(dst, Op{}(simd::load<T, N>(a), simd::load<T, N>(b)));
dst += N;
a += N;
b += N;
size -= N;
}
while (size-- > 0) {
*dst = op(*a, *b);
*dst = Op{}(*a, *b);
dst++;
a++;
b++;
@@ -90,7 +77,6 @@ void binary_op_dims(
const T* a,
const T* b,
U* out,
Op op,
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
@@ -104,12 +90,12 @@ void binary_op_dims(
for (int i = 0; i < N; i++) {
if constexpr (D > 1) {
binary_op_dims<T, U, Op, D - 1, Strided>(
a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
a, b, out, shape, a_strides, b_strides, out_strides, axis + 1);
} else {
if constexpr (Strided) {
op(a, b, out, stride_out);
Op{}(a, b, out, stride_out);
} else {
*out = op(*a, *b);
*out = Op{}(*a, *b);
}
}
out += stride_out;
@@ -120,66 +106,38 @@ void binary_op_dims(
template <typename T, typename U, bool Strided, typename Op>
void binary_op_dispatch_dims(
const array& a,
const array& b,
array& out,
Op op,
const T* a,
const T* b,
U* out,
int dim,
int size,
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
const Strides& out_strides) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* out_ptr = out.data<U>();
switch (dim) {
case 1:
binary_op_dims<T, U, Op, 1, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
a, b, out, shape, a_strides, b_strides, out_strides, 0);
return;
case 2:
binary_op_dims<T, U, Op, 2, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
a, b, out, shape, a_strides, b_strides, out_strides, 0);
return;
case 3:
binary_op_dims<T, U, Op, 3, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
a, b, out, shape, a_strides, b_strides, out_strides, 0);
return;
}
ContiguousIterator a_it(shape, a_strides, dim - 3);
ContiguousIterator b_it(shape, b_strides, dim - 3);
auto stride = out_strides[dim - 4];
for (int64_t elem = 0; elem < a.size(); elem += stride) {
for (int64_t elem = 0; elem < size; elem += stride) {
binary_op_dims<T, U, Op, 3, Strided>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
out_ptr + elem,
op,
a + a_it.loc,
b + b_it.loc,
out + elem,
shape,
a_strides,
b_strides,
@@ -191,40 +149,41 @@ void binary_op_dispatch_dims(
}
template <typename T, typename U, typename Op>
void binary_op(const array& a, const array& b, array& out, Op op) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
// The full computation is scalar scalar so call the base op once
auto a_ptr = a.data<T>();
auto b_ptr = b.data<T>();
auto out_ptr = out.data<U>();
if (bopt == BinaryOpType::ScalarScalar) {
*(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
*out_ptr = Op{}(*a_ptr, *b_ptr);
return;
}
// The full computation is scalar vector so delegate to the op
if (bopt == BinaryOpType::ScalarVector) {
ScalarVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
ScalarVector<Op>{}(a_ptr, b_ptr, out_ptr, b.data_size());
return;
}
// The full computation is vector scalar so delegate to the op
if (bopt == BinaryOpType::VectorScalar) {
VectorScalar{op}(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
VectorScalar<Op>{}(a_ptr, b_ptr, out_ptr, a.data_size());
return;
}
// The full computation is vector vector so delegate to the op
if (bopt == BinaryOpType::VectorVector) {
VectorVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
VectorVector<Op>{}(a_ptr, b_ptr, out_ptr, a.size());
return;
}
// General computation so let's try to optimize
auto [new_shape, new_strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), out.strides()});
const auto& a_strides = new_strides[0];
const auto& b_strides = new_strides[1];
const auto& strides = new_strides[2];
auto& a_strides = new_strides[0];
auto& b_strides = new_strides[1];
auto& strides = new_strides[2];
// Get the left-most dim such that the array is row contiguous after
auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
@@ -248,7 +207,8 @@ void binary_op(const array& a, const array& b, array& out, Op op) {
auto ndim = new_shape.size();
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
// Case 1: LxM and FxM where L and F are broadcastable and M is row
// contiguous
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::VectorVector;
@@ -275,99 +235,59 @@ void binary_op(const array& a, const array& b, array& out, Op op) {
switch (bopt) {
case BinaryOpType::VectorVector:
binary_op_dispatch_dims<T, U, true>(
a,
b,
out,
VectorVector{op},
binary_op_dispatch_dims<T, U, true, VectorVector<Op>>(
a_ptr,
b_ptr,
out_ptr,
dim,
a.size(),
new_shape,
a_strides,
b_strides,
strides);
break;
case BinaryOpType::VectorScalar:
binary_op_dispatch_dims<T, U, true>(
a,
b,
out,
VectorScalar{op},
binary_op_dispatch_dims<T, U, true, VectorScalar<Op>>(
a_ptr,
b_ptr,
out_ptr,
dim,
a.size(),
new_shape,
a_strides,
b_strides,
strides);
break;
case BinaryOpType::ScalarVector:
binary_op_dispatch_dims<T, U, true>(
a,
b,
out,
ScalarVector{op},
binary_op_dispatch_dims<T, U, true, ScalarVector<Op>>(
a_ptr,
b_ptr,
out_ptr,
dim,
a.size(),
new_shape,
a_strides,
b_strides,
strides);
break;
default:
binary_op_dispatch_dims<T, U, false>(
a, b, out, op, dim, new_shape, a_strides, b_strides, strides);
binary_op_dispatch_dims<T, U, false, Op>(
a_ptr,
b_ptr,
out_ptr,
dim,
a.size(),
new_shape,
a_strides,
b_strides,
strides);
break;
}
}
template <typename T, typename Op>
void binary_op(const array& a, const array& b, array& out, Op op) {
binary_op<T, T>(a, b, out, op);
}
template <typename Op>
void binary(const array& a, const array& b, array& out, Op op) {
switch (out.dtype()) {
case bool_:
binary_op<bool>(a, b, out, op);
break;
case uint8:
binary_op<uint8_t>(a, b, out, op);
break;
case uint16:
binary_op<uint16_t>(a, b, out, op);
break;
case uint32:
binary_op<uint32_t>(a, b, out, op);
break;
case uint64:
binary_op<uint64_t>(a, b, out, op);
break;
case int8:
binary_op<int8_t>(a, b, out, op);
break;
case int16:
binary_op<int16_t>(a, b, out, op);
break;
case int32:
binary_op<int32_t>(a, b, out, op);
break;
case int64:
binary_op<int64_t>(a, b, out, op);
break;
case float16:
binary_op<float16_t>(a, b, out, op);
break;
case float32:
binary_op<float>(a, b, out, op);
break;
case float64:
binary_op<double>(a, b, out, op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out, op);
break;
case complex64:
binary_op<complex64_t>(a, b, out, op);
break;
}
void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
binary_op<T, T, Op>(a, b, out, bopt);
}
} // namespace mlx::core

View File

@@ -58,14 +58,14 @@ void binary_op_dispatch_dims(
Op op) {
auto [shape, strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), out_a.strides()});
const auto& a_strides = strides[0];
const auto& b_strides = strides[1];
const auto& out_strides = strides[2];
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* out_a_ptr = out_a.data<U>();
U* out_b_ptr = out_b.data<U>();
const auto& a_strides = strides[0];
const auto& b_strides = strides[1];
const auto& out_strides = strides[2];
int ndim = shape.size();
switch (ndim) {
case 1:
@@ -120,14 +120,10 @@ template <typename T, typename U = T, typename Op>
void binary_op(
const array& a,
const array& b,
std::vector<array>& outputs,
Op op) {
auto bopt = get_binary_op_type(a, b);
auto& out_a = outputs[0];
auto& out_b = outputs[1];
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
array& out_a,
array& out_b,
Op op,
BinaryOpType bopt) {
// The full computation is scalar scalar so call the base op once
if (bopt == BinaryOpType::General) {
binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op);
@@ -141,14 +137,14 @@ void binary_op(
if (bopt == BinaryOpType::ScalarScalar) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
} else if (bopt == BinaryOpType::ScalarVector) {
for (size_t i = 0; i < b.size(); ++i) {
for (size_t i = 0; i < b.data_size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
b_ptr++;
}
} else if (bopt == BinaryOpType::VectorScalar) {
for (size_t i = 0; i < a.size(); ++i) {
for (size_t i = 0; i < a.data_size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
@@ -165,58 +161,6 @@ void binary_op(
}
}
template <typename Op>
void binary(
const array& a,
const array& b,
std::vector<array>& outputs,
Op op) {
switch (outputs[0].dtype()) {
case bool_:
binary_op<bool>(a, b, outputs, op);
break;
case uint8:
binary_op<uint8_t>(a, b, outputs, op);
break;
case uint16:
binary_op<uint16_t>(a, b, outputs, op);
break;
case uint32:
binary_op<uint32_t>(a, b, outputs, op);
break;
case uint64:
binary_op<uint64_t>(a, b, outputs, op);
break;
case int8:
binary_op<int8_t>(a, b, outputs, op);
break;
case int16:
binary_op<int16_t>(a, b, outputs, op);
break;
case int32:
binary_op<int32_t>(a, b, outputs, op);
break;
case int64:
binary_op<int64_t>(a, b, outputs, op);
break;
case float16:
binary_op<float16_t>(a, b, outputs, op);
break;
case float32:
binary_op<float>(a, b, outputs, op);
break;
case float64:
binary_op<double>(a, b, outputs, op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, op);
break;
case complex64:
binary_op<complex64_t>(a, b, outputs, op);
break;
}
}
} // namespace
} // namespace mlx::core

View File

@@ -2,13 +2,15 @@
#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
namespace mlx::core {
void cholesky_impl(const array& a, array& factor, bool upper) {
template <typename T>
void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
// Lapack uses the column-major convention. We take advantage of the fact that
// the matrix should be symmetric:
// (A)ᵀ = A
@@ -16,59 +18,68 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
// triangular matrix, so uplo is the opposite of what we would expect from
// upper
char uplo = (upper) ? 'L' : 'U';
// The decomposition is computed in place, so just copy the input to the
// output.
copy(
copy_cpu(
a,
factor,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream);
const int N = a.shape(-1);
const size_t num_matrices = a.size() / (N * N);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(factor);
encoder.dispatch([matrix = factor.data<T>(),
upper,
N = a.shape(-1),
size = a.size()]() mutable {
char uplo = (upper) ? 'L' : 'U';
size_t num_matrices = size / (N * N);
for (int i = 0; i < num_matrices; i++) {
// Compute Cholesky factorization.
int info;
potrf<T>(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
float* matrix = factor.data<float>();
for (int i = 0; i < num_matrices; i++) {
// Compute Cholesky factorization.
int info;
MLX_LAPACK_FUNC(spotrf)
(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
// TODO: We do nothing when the matrix is not positive semi-definite
// because throwing an error would result in a crash. If we figure out how
// to catch errors from the implementation we should throw.
if (info < 0) {
std::stringstream msg;
msg << "[cholesky] Cholesky decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
// Zero out the upper/lower triangle while advancing the pointer to the
// next matrix at the same time.
for (int row = 0; row < N; row++) {
if (upper) {
std::fill(matrix, matrix + row, 0);
} else {
std::fill(matrix + row + 1, matrix + N, 0);
// TODO: We do nothing when the matrix is not positive semi-definite
// because throwing an error would result in a crash. If we figure out how
// to catch errors from the implementation we should throw.
if (info < 0) {
std::stringstream msg;
msg << "[Cholesky::eval_cpu] Cholesky decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
// Zero out the upper/lower triangle while advancing the pointer to the
// next matrix at the same time.
for (int row = 0; row < N; row++) {
if (upper) {
std::fill(matrix, matrix + row, 0);
} else {
std::fill(matrix + row + 1, matrix + N, 0);
}
matrix += N;
}
matrix += N;
}
}
});
}
void Cholesky::eval_cpu(const std::vector<array>& inputs, array& output) {
if (inputs[0].dtype() != float32) {
throw std::runtime_error("[Cholesky::eval] only supports float32.");
switch (inputs[0].dtype()) {
case float32:
cholesky_impl<float>(inputs[0], output, upper_, stream());
break;
case float64:
cholesky_impl<double>(inputs[0], output, upper_, stream());
break;
default:
throw std::runtime_error(
"[Cholesky::eval_cpu] only supports float32 or float64.");
}
cholesky_impl(inputs[0], output, upper_);
}
} // namespace mlx::core

View File

@@ -11,9 +11,11 @@
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/cpu/compiled_preamble.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/jit_compiler.h"
#include "mlx/device.h"
#include "mlx/graph_utils.h"
#include "mlx/version.h"
namespace mlx::core {
@@ -39,7 +41,10 @@ struct CompilerCache {
std::shared_mutex mtx;
};
static CompilerCache cache{};
static CompilerCache& cache() {
static CompilerCache cache_;
return cache_;
};
// GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is also available.
@@ -55,14 +60,16 @@ void* compile(
const std::string& kernel_name,
const std::function<std::string(void)>& source_builder) {
{
std::shared_lock lock(cache.mtx);
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
std::shared_lock lock(cache().mtx);
if (auto it = cache().kernels.find(kernel_name);
it != cache().kernels.end()) {
return it->second;
}
}
std::unique_lock lock(cache.mtx);
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
std::unique_lock lock(cache().mtx);
if (auto it = cache().kernels.find(kernel_name);
it != cache().kernels.end()) {
return it->second;
}
std::string source_code = source_builder();
@@ -88,7 +95,11 @@ void* compile(
kernel_file_name = kernel_name;
}
auto output_dir = std::filesystem::temp_directory_path();
auto output_dir =
std::filesystem::temp_directory_path() / "mlx" / version() / "cpu";
if (!std::filesystem::exists(output_dir)) {
std::filesystem::create_directories(output_dir);
}
std::string shared_lib_name = "lib" + kernel_file_name + ".so";
auto shared_lib_path = (output_dir / shared_lib_name).string();
@@ -119,10 +130,10 @@ void* compile(
}
// load library
cache.libs.emplace_back(shared_lib_path);
cache().libs.emplace_back(shared_lib_path);
// Load function
void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
void* fun = dlsym(cache().libs.back().lib, kernel_name.c_str());
if (!fun) {
std::ostringstream msg;
msg << "[Compile::eval_cpu] Failed to load compiled function "
@@ -130,7 +141,7 @@ void* compile(
<< dlerror();
throw std::runtime_error(msg.str());
}
cache.kernels.insert({kernel_name, fun});
cache().kernels.insert({kernel_name, fun});
return fun;
}
@@ -140,18 +151,9 @@ inline void build_kernel(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids,
const std::function<bool(size_t)>& is_constant,
bool contiguous,
int ndim) {
// All outputs should have the exact same shape and will be row contiguous
auto output_shape = outputs[0].shape();
auto output_strides = outputs[0].strides();
// Constants are scalars that are captured by value and cannot change
auto is_constant = [&constant_ids](const array& x) {
return constant_ids.find(x.id()) != constant_ids.end();
};
NodeNamer namer;
#ifdef _MSC_VER
@@ -160,25 +162,28 @@ inline void build_kernel(
#endif
// Start the kernel
os << "void " << kernel_name << "(void** args) {" << std::endl;
os << "void " << kernel_name
<< "(int* shape, int64_t** strides, void** args) {" << std::endl;
// Add the input arguments
int cnt = 0;
for (auto& x : inputs) {
auto& xname = namer.get_name(x);
int strides_index = 1;
for (size_t i = 0; i < inputs.size(); ++i) {
// Skip constants from the input list
if (is_constant(x)) {
if (is_constant(i)) {
continue;
}
const auto& x = inputs[i];
auto& xname = namer.get_name(x);
auto tstr = get_type_string(x.dtype());
os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
<< "];" << std::endl;
// Scalars and contiguous need no strides
if (!is_scalar(x) && !contiguous) {
os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
<< "];" << std::endl;
os << " const int64_t* " << xname << "_strides = strides["
<< strides_index++ << "];" << std::endl;
}
}
@@ -188,10 +193,8 @@ inline void build_kernel(
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
<< "*)args[" << cnt++ << "];" << std::endl;
}
// Add output strides and shape to extract the indices.
if (!contiguous) {
os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
} else {
// Add output size
if (contiguous) {
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
}
@@ -205,10 +208,11 @@ inline void build_kernel(
}
// Read the inputs in tmps
for (auto& x : inputs) {
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
auto& xname = namer.get_name(x);
if (is_constant(x)) {
if (is_constant(i)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
print_constant(os, x);
os << ";" << std::endl;
@@ -232,7 +236,7 @@ inline void build_kernel(
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
} else {
x.primitive().print(os);
os << x.primitive().name();
os << "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) {
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
@@ -258,8 +262,9 @@ inline void build_kernel(
} else {
for (int d = ndim - 1; d >= 0; --d) {
// Update pointers
for (auto& x : inputs) {
if (is_constant(x) || is_scalar(x)) {
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
if (is_constant(i) || is_scalar(x)) {
continue;
}
auto& xname = namer.get_name(x);
@@ -281,63 +286,33 @@ inline void build_kernel(
void Compiled::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
if (kernel_lib_.empty()) {
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
}
auto& encoder = cpu::get_command_encoder(stream());
// Figure out which kernel we are using
auto& shape = outputs[0].shape();
auto contiguous = compiled_check_contiguity(inputs, shape);
// Collapse contiguous dims to route to a faster kernel if possible. Also
// handle all broadcasting.
auto [contiguous, shape, strides] =
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
// Handle all broadcasting and collect function input arguments
// Collect function input arguments.
std::vector<void*> args;
std::vector<std::vector<size_t>> strides;
for (int i = 0; i < inputs.size(); i++) {
// Skip constants.
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
for (size_t i = 0; i < inputs.size(); ++i) {
if (is_constant_(i)) {
continue;
}
auto& x = inputs[i];
const auto& x = inputs[i];
encoder.set_input_array(x);
args.push_back((void*)x.data<void>());
if (contiguous || is_scalar(x)) {
continue;
}
// Broadcast the input to the output shape.
std::vector<size_t> xstrides;
int j = 0;
for (; j < shape.size() - x.ndim(); j++) {
if (shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
}
for (int i = 0; i < x.ndim(); i++, j++) {
if (x.shape(i) == 1) {
if (shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
} else {
xstrides.push_back(x.strides()[i]);
}
}
strides.push_back(std::move(xstrides));
args.push_back(strides.back().data());
}
// Get the kernel name from the lib
int ndim = shape.size();
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
if (!contiguous) {
kernel_name += std::to_string(shape.size());
kernel_name += std::to_string(ndim);
}
// Get the function
auto fn_ptr = compile(kernel_name, [&]() {
auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() {
std::ostringstream kernel;
kernel << get_kernel_preamble() << std::endl;
kernel << "extern \"C\" {" << std::endl;
@@ -347,7 +322,7 @@ void Compiled::eval_cpu(
inputs_,
outputs_,
tape_,
constant_ids_,
is_constant_,
contiguous,
ndim);
// Close extern "C"
@@ -355,19 +330,26 @@ void Compiled::eval_cpu(
return kernel.str();
});
compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous, false);
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
for (auto& x : outputs) {
args.push_back(x.data<void>());
encoder.set_output_array(x);
}
if (!contiguous) {
args.push_back((void*)outputs[0].shape().data());
} else {
if (contiguous) {
args.push_back((void*)outputs[0].data_size());
}
auto fun = (void (*)(void**))fn_ptr;
fun(args.data());
auto fun = reinterpret_cast<void (*)(int*, int64_t**, void**)>(fn_ptr);
encoder.dispatch([fun,
args = std::move(args),
strides = std::move(strides),
shape = std::move(shape)]() mutable {
SmallVector<int64_t*> strides_ptrs;
for (auto& s : strides) {
strides_ptrs.push_back(s.data());
}
fun(shape.data(), strides_ptrs.data(), args.data());
});
}
} // namespace mlx::core

File diff suppressed because it is too large Load Diff

View File

@@ -5,6 +5,7 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
namespace mlx::core {
@@ -13,19 +14,19 @@ namespace {
template <typename SrcT, typename DstT>
void copy_single(const array& src, array& dst) {
auto val = static_cast<DstT>(src.data<SrcT>()[0]);
auto src_ptr = src.data<SrcT>();
auto dst_ptr = dst.data<DstT>();
for (int i = 0; i < dst.size(); ++i) {
dst_ptr[i] = val;
}
auto size = dst.size();
auto val = static_cast<DstT>(src_ptr[0]);
std::fill_n(dst_ptr, size, val);
}
template <typename SrcT, typename DstT>
void copy_vector(const array& src, array& dst) {
auto src_ptr = src.data<SrcT>();
auto dst_ptr = dst.data<DstT>();
size_t size = src.data_size();
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
auto size = src.data_size();
std::copy(src_ptr, src_ptr + size, dst_ptr);
}
template <typename SrcT, typename DstT, int D>
@@ -60,36 +61,57 @@ void copy_general_general(
const Strides& i_strides,
const Strides& o_strides,
int64_t i_offset,
int64_t o_offset) {
int64_t o_offset,
const std::optional<array>& dynamic_i_offset,
const std::optional<array>& dynamic_o_offset) {
auto src_ptr = src.data<SrcT>() + i_offset;
auto dst_ptr = dst.data<DstT>() + o_offset;
auto i_offset_ptr =
dynamic_i_offset ? dynamic_i_offset->data<int64_t>() : nullptr;
auto o_offset_ptr =
dynamic_o_offset ? dynamic_o_offset->data<int64_t>() : nullptr;
auto size = src.size();
if (data_shape.empty()) {
auto val = static_cast<DstT>(*(src.data<SrcT>() + i_offset));
auto dst_ptr = dst.data<DstT>() + o_offset;
auto val = static_cast<DstT>(*src_ptr);
*dst_ptr = val;
return;
}
auto [shape, strides] =
collapse_contiguous_dims(data_shape, {i_strides, o_strides});
auto src_ptr = src.data<SrcT>() + i_offset;
auto dst_ptr = dst.data<DstT>() + o_offset;
int ndim = shape.size();
if (ndim == 1) {
copy_dims<SrcT, DstT, 1>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
} else if (ndim == 2) {
copy_dims<SrcT, DstT, 2>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
} else if (ndim == 3) {
copy_dims<SrcT, DstT, 3>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
if (ndim < 3) {
if (i_offset_ptr) {
src_ptr += i_offset_ptr[0];
}
if (o_offset_ptr) {
dst_ptr += o_offset_ptr[0];
}
if (ndim == 1) {
copy_dims<SrcT, DstT, 1>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
} else if (ndim == 2) {
copy_dims<SrcT, DstT, 2>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
} else if (ndim == 3) {
copy_dims<SrcT, DstT, 3>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
}
return;
}
if (i_offset_ptr) {
src_ptr += i_offset_ptr[0];
}
if (o_offset_ptr) {
dst_ptr += o_offset_ptr[0];
}
ContiguousIterator in(shape, strides[0], ndim - 3);
ContiguousIterator out(shape, strides[1], ndim - 3);
auto stride = std::accumulate(
shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());
for (int64_t elem = 0; elem < src.size(); elem += stride) {
for (int64_t elem = 0; elem < size; elem += stride) {
copy_dims<SrcT, DstT, 3>(
src_ptr + in.loc,
dst_ptr + out.loc,
@@ -105,7 +127,15 @@ void copy_general_general(
template <typename SrcT, typename DstT>
inline void copy_general_general(const array& src, array& dst) {
copy_general_general<SrcT, DstT>(
src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
src,
dst,
src.shape(),
src.strides(),
dst.strides(),
0,
0,
std::nullopt,
std::nullopt);
}
template <typename SrcT, typename DstT>
@@ -116,7 +146,9 @@ void copy_general(
const Strides& i_strides,
const Strides&,
int64_t i_offset,
int64_t o_offset) {
int64_t o_offset,
const std::optional<array>& dynamic_i_offset,
const std::optional<array>& dynamic_o_offset) {
copy_general_general<SrcT, DstT>(
src,
dst,
@@ -124,7 +156,9 @@ void copy_general(
i_strides,
make_contiguous_strides(data_shape),
i_offset,
o_offset);
o_offset,
dynamic_i_offset,
dynamic_o_offset);
}
template <typename SrcT, typename DstT>
@@ -136,7 +170,9 @@ inline void copy_general(const array& src, array& dst) {
src.strides(),
make_contiguous_strides(src.shape()),
0,
0);
0,
std::nullopt,
std::nullopt);
}
template <typename SrcT, typename DstT, typename... Args>
@@ -259,38 +295,34 @@ inline void copy_inplace_dispatch(
} // namespace
void copy_inplace(const array& src, array& dst, CopyType ctype) {
copy_inplace_dispatch(src, dst, ctype);
void copy_cpu_inplace(
const array& src,
array& dst,
CopyType ctype,
Stream stream) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_output_array(dst);
encoder.dispatch(
[src = array::unsafe_weak_copy(src),
dst = array::unsafe_weak_copy(dst),
ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });
}
void copy(const array& src, array& dst, CopyType ctype) {
// Allocate the output
switch (ctype) {
case CopyType::Vector:
if (src.is_donatable() && src.itemsize() == dst.itemsize()) {
dst.copy_shared_buffer(src);
} else {
auto size = src.data_size();
dst.set_data(
allocator::malloc_or_wait(size * dst.itemsize()),
size,
src.strides(),
src.flags());
}
break;
case CopyType::Scalar:
case CopyType::General:
case CopyType::GeneralGeneral:
dst.set_data(allocator::malloc_or_wait(dst.nbytes()));
break;
void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream) {
bool donated = set_copy_output_data(src, dst, ctype);
if (donated && src.dtype() == dst.dtype()) {
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
return;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_inplace(src, dst, ctype);
copy_cpu_inplace(src, dst, ctype, stream);
}
void copy_inplace(
void copy_cpu_inplace(
const array& src,
array& dst,
const Shape& data_shape,
@@ -298,24 +330,57 @@ void copy_inplace(
const Strides& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype) {
switch (ctype) {
case CopyType::General:
case CopyType::GeneralGeneral:
copy_inplace_dispatch(
src,
dst,
ctype,
data_shape,
i_strides,
o_strides,
i_offset,
o_offset);
break;
case CopyType::Scalar:
case CopyType::Vector:
copy_inplace_dispatch(src, dst, ctype);
}
CopyType ctype,
Stream stream,
const std::optional<array>& dynamic_i_offset, /* = std::nullopt */
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_output_array(dst);
auto weak_copy_if_set = [](auto x) -> std::optional<array> {
if (x) {
return array::unsafe_weak_copy(*x);
} else {
return std::nullopt;
}
};
encoder.dispatch(
[src = array::unsafe_weak_copy(src),
dst = array::unsafe_weak_copy(dst),
data_shape,
i_strides,
o_strides,
i_offset,
o_offset,
ctype,
dynamic_i_offset = weak_copy_if_set(dynamic_i_offset),
dynamic_o_offset = weak_copy_if_set(dynamic_o_offset)]() mutable {
switch (ctype) {
case CopyType::General:
case CopyType::GeneralGeneral:
copy_inplace_dispatch(
src,
dst,
ctype,
data_shape,
i_strides,
o_strides,
i_offset,
o_offset,
dynamic_i_offset,
dynamic_o_offset);
break;
case CopyType::Scalar:
case CopyType::Vector:
copy_inplace_dispatch(src, dst, ctype);
}
});
}
array contiguous_copy_cpu(const array& arr, Stream stream) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_copy, CopyType::General, stream);
return arr_copy;
}
} // namespace mlx::core

View File

@@ -2,16 +2,22 @@
#pragma once
#include <optional>
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
void copy(const array& src, array& dst, CopyType ctype);
void copy_inplace(const array& src, array& dst, CopyType ctype);
void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream);
void copy_cpu_inplace(
const array& src,
array& dst,
CopyType ctype,
Stream stream);
void copy_inplace(
void copy_cpu_inplace(
const array& src,
array& dst,
const Shape& data_shape,
@@ -19,6 +25,12 @@ void copy_inplace(
const Strides& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);
CopyType ctype,
Stream stream,
const std::optional<array>& dynamic_i_offset = std::nullopt,
const std::optional<array>& dynamic_o_offset = std::nullopt);
// Return a contiguous array with same shape that copies the data of |arr|.
array contiguous_copy_cpu(const array& arr, Stream stream);
} // namespace mlx::core

View File

@@ -0,0 +1,98 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/primitives.h"
namespace mlx::core::distributed {
std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
if (arr.flags().row_contiguous) {
return {arr, false};
} else {
return {contiguous_copy_cpu(arr, stream), true};
}
};
void AllReduce::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto donate_or_copy = [s = stream()](const array& in, array& out) {
if (in.flags().row_contiguous) {
if (in.is_donatable()) {
out.copy_shared_buffer(in);
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
return in;
} else {
array arr_copy = contiguous_copy_cpu(in, s);
out.copy_shared_buffer(arr_copy);
return arr_copy;
}
};
auto in = donate_or_copy(inputs[0], outputs[0]);
switch (reduce_type_) {
case Sum:
distributed::detail::all_sum(group(), in, outputs[0], stream());
break;
case Max:
distributed::detail::all_max(group(), in, outputs[0], stream());
break;
case Min:
distributed::detail::all_min(group(), in, outputs[0], stream());
break;
default:
throw std::runtime_error(
"Only all reduce sum, min and max are supported for now");
}
}
void AllGather::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto [in, copied] = ensure_row_contiguous(inputs[0], stream());
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
distributed::detail::all_gather(group(), in, outputs[0], stream());
if (copied) {
auto& enc = cpu::get_command_encoder(stream());
enc.add_temporary(in);
}
}
void Send::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto [in, copied] = ensure_row_contiguous(inputs[0], stream());
distributed::detail::send(group(), in, dst_, stream());
outputs[0].copy_shared_buffer(inputs[0]);
if (copied) {
auto& enc = cpu::get_command_encoder(stream());
enc.add_temporary(in);
}
}
void Recv::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 0);
assert(outputs.size() == 1);
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
distributed::detail::recv(group(), outputs[0], src_, stream());
}
} // namespace mlx::core::distributed

174
mlx/backend/cpu/eig.cpp Normal file
View File

@@ -0,0 +1,174 @@
// Copyright © 2025 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename T>
void eig_impl(
array& a,
array& vectors,
array& values,
bool compute_eigenvectors,
Stream stream) {
using OT = std::complex<T>;
auto a_ptr = a.data<T>();
auto eig_ptr = values.data<OT>();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(values);
OT* vec_ptr = nullptr;
if (compute_eigenvectors) {
encoder.set_output_array(vectors);
vec_ptr = vectors.data<OT>();
}
encoder.dispatch([a_ptr,
vec_ptr,
eig_ptr,
compute_eigenvectors,
N = vectors.shape(-1),
size = vectors.size()]() mutable {
// Work query
char jobr = 'N';
char jobl = compute_eigenvectors ? 'V' : 'N';
int n_vecs_r = 1;
int n_vecs_l = compute_eigenvectors ? N : 1;
int lwork = -1;
int info;
{
T work;
int iwork;
geev<T>(
&jobl,
&jobr,
&N,
nullptr,
&N,
nullptr,
nullptr,
nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
&work,
&lwork,
&info);
lwork = static_cast<int>(work);
}
auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)};
auto vec_tmp_data =
array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)};
auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
for (size_t i = 0; i < size / (N * N); ++i) {
geev<T>(
&jobl,
&jobr,
&N,
a_ptr,
&N,
eig_tmp,
eig_tmp + N,
vec_tmp,
&n_vecs_l,
nullptr,
&n_vecs_r,
static_cast<T*>(work_buf.buffer.raw_ptr()),
&lwork,
&info);
for (int i = 0; i < N; ++i) {
eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]};
}
if (vec_ptr) {
for (int i = 0; i < N; ++i) {
if (eig_ptr[i].imag() != 0) {
// This vector and the next are a pair
for (int j = 0; j < N; ++j) {
vec_ptr[i * N + j] = {
vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]};
vec_ptr[(i + 1) * N + j] = {
vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]};
}
i += 1;
} else {
for (int j = 0; j < N; ++j) {
vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0};
}
}
}
vec_ptr += N * N;
}
a_ptr += N * N;
eig_ptr += N;
if (info != 0) {
std::stringstream msg;
msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
}
});
encoder.add_temporary(a);
}
} // namespace
void Eig::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
const auto& a = inputs[0];
auto& values = outputs[0];
auto vectors = compute_eigenvectors_
? outputs[1]
: array(a.shape(), complex64, nullptr, {});
auto a_copy = array(a.shape(), a.dtype(), nullptr, {});
copy_cpu(
a,
a_copy,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream());
values.set_data(allocator::malloc(values.nbytes()));
if (compute_eigenvectors_) {
// Set the strides and flags so the eigenvectors
// are in the columns of the output
auto flags = vectors.flags();
auto strides = vectors.strides();
auto ndim = a.ndim();
std::swap(strides[ndim - 1], strides[ndim - 2]);
if (a.size() > 1) {
flags.row_contiguous = false;
if (ndim > 2) {
flags.col_contiguous = false;
} else {
flags.col_contiguous = true;
}
}
vectors.set_data(
allocator::malloc(vectors.nbytes()), vectors.size(), strides, flags);
}
switch (a.dtype()) {
case float32:
eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
break;
default:
throw std::runtime_error("[Eig::eval_cpu] only supports float32.");
}
}
} // namespace mlx::core

View File

@@ -3,6 +3,7 @@
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
@@ -11,35 +12,173 @@ namespace mlx::core {
namespace {
void ssyevd(
char jobz,
char uplo,
float* a,
int N,
float* w,
float* work,
int lwork,
int* iwork,
int liwork) {
template <typename T, class Enable = void>
struct EighWork {};
template <typename T>
struct EighWork<
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
using R = T;
char jobz;
char uplo;
int N;
int lwork;
int liwork;
int info;
MLX_LAPACK_FUNC(ssyevd)
(
/* jobz = */ &jobz,
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ a,
/* lda = */ &N,
/* w = */ w,
/* work = */ work,
/* lwork = */ &lwork,
/* iwork = */ iwork,
/* liwork = */ &liwork,
/* info = */ &info);
if (info != 0) {
std::stringstream msg;
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
std::vector<array::Data> buffers;
EighWork(char jobz_, char uplo_, int N_)
: jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), liwork(-1) {
T work;
int iwork;
syevd<T>(
&jobz,
&uplo,
&N,
nullptr,
&N,
nullptr,
&work,
&lwork,
&iwork,
&liwork,
&info);
lwork = static_cast<int>(work);
liwork = iwork;
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
}
void run(T* vectors, T* values) {
syevd<T>(
&jobz,
&uplo,
&N,
vectors,
&N,
values,
static_cast<T*>(buffers[0].buffer.raw_ptr()),
&lwork,
static_cast<int*>(buffers[1].buffer.raw_ptr()),
&liwork,
&info);
}
};
template <>
struct EighWork<std::complex<float>> {
using T = std::complex<float>;
using R = float;
char jobz;
char uplo;
int N;
int lwork;
int lrwork;
int liwork;
int info;
std::vector<array::Data> buffers;
EighWork(char jobz_, char uplo_, int N_)
: jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), lrwork(-1), liwork(-1) {
T work;
R rwork;
int iwork;
heevd<T>(
&jobz,
&uplo,
&N,
nullptr,
&N,
nullptr,
&work,
&lwork,
&rwork,
&lrwork,
&iwork,
&liwork,
&info);
lwork = static_cast<int>(work.real());
lrwork = static_cast<int>(rwork);
liwork = iwork;
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
}
void run(T* vectors, R* values) {
heevd<T>(
&jobz,
&uplo,
&N,
vectors,
&N,
values,
static_cast<T*>(buffers[0].buffer.raw_ptr()),
&lwork,
static_cast<R*>(buffers[1].buffer.raw_ptr()),
&lrwork,
static_cast<int*>(buffers[2].buffer.raw_ptr()),
&liwork,
&info);
if (jobz == 'V') {
// We have pre-transposed the vectors but we also must conjugate them
// when they are complex.
//
// We could vectorize this but it is so fast in comparison to heevd that
// it doesn't really matter.
for (int i = 0; i < N; i++) {
for (int j = 0; j < N; j++) {
*vectors = std::conj(*vectors);
vectors++;
}
}
}
}
};
template <typename T>
void eigh_impl(
array& vectors,
array& values,
const std::string& uplo,
bool compute_eigenvectors,
Stream stream) {
using R = typename EighWork<T>::R;
auto vec_ptr = vectors.data<T>();
auto eig_ptr = values.data<R>();
char jobz = compute_eigenvectors ? 'V' : 'N';
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(vectors);
encoder.set_output_array(values);
encoder.dispatch([vec_ptr,
eig_ptr,
jobz,
uplo = uplo[0],
N = vectors.shape(-1),
size = vectors.size()]() mutable {
// Work query
EighWork<T> work(jobz, uplo, N);
// Work loop
for (size_t i = 0; i < size / (N * N); ++i) {
work.run(vec_ptr, eig_ptr);
vec_ptr += N * N;
eig_ptr += N;
if (work.info != 0) {
std::stringstream msg;
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
<< work.info;
throw std::runtime_error(msg.str());
}
}
});
if (!compute_eigenvectors) {
encoder.add_temporary(vectors);
}
}
@@ -55,12 +194,13 @@ void Eigh::eval_cpu(
? outputs[1]
: array(a.shape(), a.dtype(), nullptr, {});
values.set_data(allocator::malloc_or_wait(values.nbytes()));
values.set_data(allocator::malloc(values.nbytes()));
copy(
copy_cpu(
a,
vectors,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream());
if (compute_eigenvectors_) {
// Set the strides and flags so the eigenvectors
@@ -78,41 +218,23 @@ void Eigh::eval_cpu(
flags.col_contiguous = true;
}
}
vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
vectors.copy_shared_buffer(vectors, strides, flags, vectors.data_size());
}
auto vec_ptr = vectors.data<float>();
auto eig_ptr = values.data<float>();
char jobz = compute_eigenvectors_ ? 'V' : 'N';
auto N = a.shape(-1);
// Work query
int lwork;
int liwork;
{
float work;
int iwork;
ssyevd(jobz, uplo_[0], nullptr, N, nullptr, &work, -1, &iwork, -1);
lwork = static_cast<int>(work);
liwork = iwork;
}
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
for (size_t i = 0; i < a.size() / (N * N); ++i) {
ssyevd(
jobz,
uplo_[0],
vec_ptr,
N,
eig_ptr,
static_cast<float*>(work_buf.buffer.raw_ptr()),
lwork,
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
liwork);
vec_ptr += N * N;
eig_ptr += N;
switch (a.dtype()) {
case float32:
eigh_impl<float>(vectors, values, uplo_, compute_eigenvectors_, stream());
break;
case float64:
eigh_impl<double>(
vectors, values, uplo_, compute_eigenvectors_, stream());
break;
case complex64:
eigh_impl<std::complex<float>>(
vectors, values, uplo_, compute_eigenvectors_, stream());
break;
default:
throw std::runtime_error(
"[Eigh::eval_cpu] only supports float32 or float64.");
}
}

View File

@@ -0,0 +1,16 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/encoder.h"
namespace mlx::core::cpu {
CommandEncoder& get_command_encoder(Stream stream) {
static std::unordered_map<int, CommandEncoder> encoder_map;
auto it = encoder_map.find(stream.index);
if (it == encoder_map.end()) {
it = encoder_map.emplace(stream.index, stream).first;
}
return it->second;
}
} // namespace mlx::core::cpu

67
mlx/backend/cpu/encoder.h Normal file
View File

@@ -0,0 +1,67 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <unordered_map>
#include "mlx/array.h"
#include "mlx/scheduler.h"
namespace mlx::core::cpu {
// Number of dispatches per scheduler task
constexpr int DISPATCHES_PER_TASK = 10;
struct CommandEncoder {
CommandEncoder(Stream stream) : stream_(stream) {}
CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;
CommandEncoder(CommandEncoder&&) = delete;
CommandEncoder& operator=(CommandEncoder&&) = delete;
void set_input_array(const array& a) {}
void set_output_array(array& a) {}
// Hold onto a temporary until any already scheduled tasks which use it as
// an input are complete.
void add_temporary(array arr) {
temporaries_.push_back(std::move(arr));
}
void add_temporaries(std::vector<array> arrays) {
temporaries_.insert(
temporaries_.end(),
std::make_move_iterator(arrays.begin()),
std::make_move_iterator(arrays.end()));
}
std::vector<array>& temporaries() {
return temporaries_;
}
template <class F, class... Args>
void dispatch(F&& f, Args&&... args) {
num_ops_ = (num_ops_ + 1) % DISPATCHES_PER_TASK;
auto task = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
if (num_ops_ == 0) {
scheduler::notify_new_task(stream_);
auto task_wrap = [s = stream_, task = std::move(task)]() mutable {
task();
scheduler::notify_task_completion(s);
};
scheduler::enqueue(stream_, std::move(task_wrap));
} else {
scheduler::enqueue(stream_, std::move(task));
}
}
private:
Stream stream_;
std::vector<array> temporaries_;
int num_ops_{0};
};
CommandEncoder& get_command_encoder(Stream stream);
} // namespace mlx::core::cpu

40
mlx/backend/cpu/eval.cpp Normal file
View File

@@ -0,0 +1,40 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/eval.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
#include "mlx/utils.h"
namespace mlx::core::cpu {
void eval(array& arr) {
auto s = arr.primitive().stream();
auto outputs = arr.outputs();
{
// If the array is a tracer hold a reference
// to its inputs so they don't get donated
std::vector<array> inputs;
if (arr.is_tracer()) {
inputs = arr.inputs();
}
arr.primitive().eval_cpu(arr.inputs(), outputs);
}
std::unordered_set<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.insert(in.data_shared_ptr());
}
for (auto& s : arr.siblings()) {
buffers.insert(s.data_shared_ptr());
}
// Remove the output if it was donated to by an input
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
buffers.erase(it);
}
auto& encoder = cpu::get_command_encoder(s);
encoder.dispatch([buffers = std::move(buffers),
temps = std::move(encoder.temporaries())]() {});
}
} // namespace mlx::core::cpu

12
mlx/backend/cpu/eval.h Normal file
View File

@@ -0,0 +1,12 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/array.h"
#include "mlx/stream.h"
namespace mlx::core::cpu {
void eval(array& arr);
} // namespace mlx::core::cpu

View File

@@ -4,6 +4,7 @@
#include "mlx/3rdparty/pocketfft.h"
#include "mlx/allocator.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
namespace mlx::core {
@@ -21,7 +22,7 @@ void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
s *= out.itemsize();
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
std::vector<size_t> shape;
if (out.dtype() == float32) {
@@ -38,46 +39,78 @@ void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
});
scale /= nelem;
}
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
if (in.dtype() == complex64 && out.dtype() == complex64) {
auto in_ptr =
reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
auto out_ptr =
reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
pocketfft::c2c(
shape,
strides_in,
strides_out,
axes_,
!inverse_,
in_ptr,
out_ptr,
scale);
encoder.dispatch([shape = std::move(shape),
strides_in = std::move(strides_in),
strides_out = std::move(strides_out),
axes = axes_,
inverse = inverse_,
in_ptr,
out_ptr,
scale]() {
pocketfft::c2c(
shape,
strides_in,
strides_out,
axes,
!inverse,
in_ptr,
out_ptr,
scale);
});
} else if (in.dtype() == float32 && out.dtype() == complex64) {
auto in_ptr = in.data<float>();
auto out_ptr =
reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
pocketfft::r2c(
shape,
strides_in,
strides_out,
axes_,
!inverse_,
in_ptr,
out_ptr,
scale);
encoder.dispatch([shape = std::move(shape),
strides_in = std::move(strides_in),
strides_out = std::move(strides_out),
axes = axes_,
inverse = inverse_,
in_ptr,
out_ptr,
scale]() {
pocketfft::r2c(
shape,
strides_in,
strides_out,
axes,
!inverse,
in_ptr,
out_ptr,
scale);
});
} else if (in.dtype() == complex64 && out.dtype() == float32) {
auto in_ptr =
reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
auto out_ptr = out.data<float>();
pocketfft::c2r(
shape,
strides_in,
strides_out,
axes_,
!inverse_,
in_ptr,
out_ptr,
scale);
encoder.dispatch([shape = std::move(shape),
strides_in = std::move(strides_in),
strides_out = std::move(strides_out),
axes = axes_,
inverse = inverse_,
in_ptr,
out_ptr,
scale]() {
pocketfft::c2r(
shape,
strides_in,
strides_out,
axes,
!inverse,
in_ptr,
out_ptr,
scale);
});
} else {
throw std::runtime_error(
"[FFT] Received unexpected input and output type combination.");

View File

@@ -7,14 +7,20 @@ namespace mlx::core {
template <typename T>
void matmul(
const array& a,
const array& b,
array& out,
const T* a,
const T* b,
T* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta);
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides);
} // namespace mlx::core

View File

@@ -9,39 +9,46 @@
namespace mlx::core {
BNNSDataType to_bnns_dtype(Dtype mlx_dtype) {
uint32_t size_bits = size_of(mlx_dtype) * 8;
switch (kindof(mlx_dtype)) {
case Dtype::Kind::b:
return BNNSDataTypeBoolean;
case Dtype::Kind::u:
return BNNSDataType(BNNSDataTypeUIntBit | size_bits);
case Dtype::Kind::i:
return BNNSDataType(BNNSDataTypeIntBit | size_bits);
case Dtype::Kind::f:
return BNNSDataType(BNNSDataTypeFloatBit | size_bits);
case Dtype::Kind::V:
return BNNSDataTypeBFloat16;
case Dtype::Kind::c:
throw std::invalid_argument("BNNS does not support complex types");
}
template <typename T>
constexpr BNNSDataType to_bnns_dtype();
template <>
constexpr BNNSDataType to_bnns_dtype<float>() {
return BNNSDataType(BNNSDataTypeFloatBit | 32);
}
template <>
constexpr BNNSDataType to_bnns_dtype<float16_t>() {
return BNNSDataType(BNNSDataTypeFloatBit | 16);
}
template <>
constexpr BNNSDataType to_bnns_dtype<bfloat16_t>() {
return BNNSDataTypeBFloat16;
}
template <typename T>
void matmul_bnns(
const array& a,
const array& b,
array& out,
const T* a,
const T* b,
T* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta) {
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
BNNSDataType bnns_dtype = to_bnns_dtype<T>();
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
@@ -115,14 +122,14 @@ void matmul_bnns(
auto bnns_filter =
BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);
for (int i = 0; i < (a.size() / (M * K)); ++i) {
for (int i = 0; i < batch_size; ++i) {
BNNSFilterApplyTwoInput(
bnns_filter,
a.data<uint8_t>() +
elem_to_loc(M * K * i, a.shape(), a.strides()) * a.itemsize(),
b.data<uint8_t>() +
elem_to_loc(K * N * i, b.shape(), b.strides()) * b.itemsize(),
out.data<uint8_t>() + M * N * i * out.itemsize());
reinterpret_cast<const uint8_t*>(
a + elem_to_loc(M * K * i, a_shape, a_strides)),
reinterpret_cast<const uint8_t*>(
b + elem_to_loc(K * N * i, b_shape, b_strides)),
reinterpret_cast<uint8_t*>(out + M * N * i));
}
BNNSFilterDestroy(bnns_filter);
@@ -131,30 +138,72 @@ void matmul_bnns(
template <>
void matmul<float16_t>(
const array& a,
const array& b,
array& out,
const float16_t* a,
const float16_t* b,
float16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta) {
matmul_bnns(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
matmul_bnns(
a,
b,
out,
a_transposed,
b_transposed,
lda,
ldb,
ldc,
alpha,
beta,
batch_size,
a_shape,
a_strides,
b_shape,
b_strides);
}
template <>
void matmul<bfloat16_t>(
const array& a,
const array& b,
array& out,
const bfloat16_t* a,
const bfloat16_t* b,
bfloat16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta) {
matmul_bnns(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
matmul_bnns(
a,
b,
out,
a_transposed,
b_transposed,
lda,
ldb,
ldc,
alpha,
beta,
batch_size,
a_shape,
a_strides,
b_shape,
b_strides);
}
} // namespace mlx::core

View File

@@ -8,20 +8,27 @@ namespace mlx::core {
template <>
void matmul<float>(
const array& a,
const array& b,
array& out,
const float* a,
const float* b,
float* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta) {
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < (a.size() / (M * K)); ++i) {
for (int i = 0; i < batch_size; ++i) {
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
@@ -29,34 +36,40 @@ void matmul<float>(
M,
N,
K,
alpha, // alpha
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
alpha,
a + elem_to_loc(M * K * i, a_shape, a_strides),
lda,
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
b + elem_to_loc(K * N * i, b_shape, b_strides),
ldb,
beta, // beta
out.data<float>() + M * N * i,
out.shape(-1) // ldc
);
beta,
out + M * N * i,
ldc);
}
}
template <>
void matmul<double>(
const array& a,
const array& b,
array& out,
const double* a,
const double* b,
double* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta) {
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < (a.size() / (M * K)); ++i) {
for (int i = 0; i < batch_size; ++i) {
cblas_dgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
@@ -64,15 +77,14 @@ void matmul<double>(
M,
N,
K,
alpha, // alpha
a.data<double>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
alpha,
a + elem_to_loc(M * K * i, a_shape, a_strides),
lda,
b.data<double>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
b + elem_to_loc(K * N * i, b_shape, b_strides),
ldb,
beta, // beta
out.data<double>() + M * N * i,
out.shape(-1) // ldc
);
beta,
out + M * N * i,
ldc);
}
}

View File

@@ -1,21 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/gemm.h"
namespace mlx::core {
template <>
void matmul<bfloat16_t>(
const array&,
const array&,
array&,
bool,
bool,
size_t,
size_t,
float,
float) {
throw std::runtime_error("[Matmul::eval_cpu] bfloat16 not supported.");
}
} // namespace mlx::core

Some files were not shown because too many files have changed in this diff Show More