Compare commits

...

138 Commits

Author SHA1 Message Date
Awni Hannun
02a9fc7bfa Patch bump (#1067)
* version

* use 0.12.2
2024-05-02 16:37:31 -07:00
Jagrit Digani
f390957685 Block sparse mm (#1058) 2024-05-02 14:03:58 -07:00
Angelos Katharopoulos
17f57df797 Improvements in the quantizer and dequantization kernel (#1061) 2024-05-01 18:19:11 -07:00
Awni Hannun
7f7b9662ea Fix leak for multi-output primitives which are never detached (#1059)
* fix multi output leak

* ignore arrays that will be detached

* add some comments

* stray print
2024-05-01 07:31:45 -07:00
Awni Hannun
19bef39f5c Add a mx.metal.device_info (#1060)
* device inof

* add variant

* fix linux

* fix doc
2024-04-30 15:47:27 -07:00
Nripesh Niketan
a30e7ed2da feat: metal formatting and pre-commit bump (#1038)
* feat: metal formatting and pre-commit bump

* add guards

* update

* more guards

* more guards

* smakk fix

* Refactor instantiation of ternary types in ternary.metal

* fix scan.metal
2024-04-30 07:18:09 -07:00
Angelos Katharopoulos
8db7161c94 Bug fix in quantize (#1054) 2024-04-29 20:55:04 -07:00
Awni Hannun
09f1777896 fix slice update indexing (#1053) 2024-04-29 12:17:40 -07:00
Jacket
490c0c4fdc [Fix] expand axes for dimension with integer indices in mlx_slice_update (#1035)
* Not sure if this is correct

* Format

* Edit tests

* Add negative test

* Format

* add one more test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-29 07:57:28 -07:00
Rifur13
c4a471c99d Add groups to Conv1d (#948)
* Add conv1d grouped convs on CPU

* Add GPU support

* Parallelize inside metal kernel

* clenaup

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* New unfold kernel + remove unused code

* Remove copy and refactor

* Update vjp and reuse steel gemm

* Fixed groups on cpu

* Fix metal validation

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-04-27 06:24:57 -07:00
Awni Hannun
86f495985b Add bitwise ops (#1037)
* bitwise ops

* fix tests
2024-04-26 22:03:42 -07:00
Awni Hannun
67d1894759 fix order device -> scheduler (#1039) 2024-04-26 13:46:41 -07:00
Awni Hannun
5bfe89bdb1 Cpp docs (#1036)
* start of C++ docs

* fix stream doc

* only include ops for now
2024-04-26 12:56:05 -07:00
Angelos Katharopoulos
82463e9938 Bump the version to 0.12 (#1034) 2024-04-25 14:18:08 -07:00
Awni Hannun
771575d27b Expose function to clear memory cache (#1032)
* expose function to clear memory cache

* fix linux build

* fix metal tests
2024-04-24 16:48:51 -07:00
Angelos Katharopoulos
20a01bbd9f Simplifying and improving qmm (#1030) 2024-04-24 13:07:45 -07:00
Angelos Katharopoulos
ec8578d41a Fix quantization of all 0s (#1028) 2024-04-24 00:40:42 -07:00
Aneesh Shetty
d0dbfe0b97 Adds radians and degrees (#1011) 2024-04-22 11:17:49 -07:00
Awni Hannun
3d405fb3b1 Add synchronize function (#1006)
* add synchronize function

* fix linux

* fix linux

* fix and fix docs

* fix test

* try synchronize in stream destroy

* synchronize works for both cpu and gpu
2024-04-22 08:25:46 -07:00
Angelos Katharopoulos
b0012cdd0f Bump the patch version for the quants (#1018) 2024-04-19 20:28:34 -07:00
Angelos Katharopoulos
84d61d27aa Make sure 0 is represented in the quantization (#1016) 2024-04-19 19:47:26 -07:00
Awni Hannun
ed83908931 fix gguf loading quants (#1014)
* fix gguf loading quants

* fix nanobind install

* actual fix
2024-04-19 12:24:07 -07:00
Angelos Katharopoulos
ef5f7d1aea Fix buffer protocol buffer size designation (#1010) 2024-04-19 06:06:13 -07:00
Awni Hannun
090ff659dc bump (#1007) 2024-04-18 13:18:43 -07:00
Jagrit Digani
85c8a91a27 Fix mask broadcasting bug and add relevant test (#1003) 2024-04-17 17:33:48 -07:00
Piotr Rybiec
581b699ac9 avgpool, not maxpool (#1002) 2024-04-17 08:26:22 -07:00
Awni Hannun
8a0677d56d Shared events for synchronization + async eval (#998)
* more async eval

* fix rebase

* try correct async eval

* fix async

* more tests for async eval

* use shared events for synchronization

* comment + cleanup

* with autorelease pool

* fix no metal build

* fix compile

* fix patch

* don't eval if asyn evale'd

* don't use is_evaled

* comments

* more multi stream tests

* try and cleanup use of is_evaled

* use a status flag
2024-04-17 06:16:02 -07:00
Jagrit Digani
b18468bf81 Masked mm (#978)
* Add block masked matmul op and primitive
2024-04-16 14:45:39 -07:00
Shiyu
107ba2891a gelu tanh approx (#989)
* gelu tanh approx

* gelu tanh approx

* replace gelu approx with tanh approach

* fix comments

* fix comment
2024-04-15 19:49:00 -07:00
Awni Hannun
cd9e184529 Quantize embedding (#994)
* quantize embedding

* rename as_linear + comment

* consistency in docs

* fix test
2024-04-15 16:42:10 -07:00
Alex Barron
2e7c02d5cd Metal FFT for powers of 2 up to 2048 (#915)
* add Metal FFT for powers of 2

* skip GPU test on linux

* fix contiguity bug

* address comments

* Update mlx/backend/metal/fft.cpp

* Update mlx/backend/metal/fft.cpp

* fix bug in synch

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-11 21:40:06 -07:00
Awni Hannun
ae18326533 No copy command encoder (#986)
* no copy command encoder

* up layer norm test tolerances
2024-04-11 21:15:36 -07:00
Alex Shepard
91eba8e485 fix for grammatical typo in docs (#988)
thanks for mlx!
2024-04-11 17:02:06 -07:00
Awni Hannun
d07e295c62 bumpity bump (#987) 2024-04-11 12:48:52 -07:00
Angelos Katharopoulos
dce4bd74a4 Add ArrayDesc destructor to avoid possible stack overflow (#982) 2024-04-11 11:37:02 -07:00
Nripesh Niketan
ffff671273 Update pre-commit hooks (#984) 2024-04-11 07:27:53 -07:00
Awni Hannun
12d4507ee3 Explicit barriers with concurrent dispatch (#977) 2024-04-10 21:45:31 -07:00
Awni Hannun
8580d997ff Try a stack-based DFS for eval (#980)
* rebase

* nit

* fix eval in vmap
2024-04-10 17:05:13 -07:00
Shiyu
061cf9a4ce Upsample with bicubic interpolation (#967) 2024-04-10 15:47:22 -07:00
Awni Hannun
99abb9eff4 Async eval (#972) 2024-04-09 18:34:00 -07:00
Luca Arnaboldi
fffe072028 Implementation of mlx.random.multivariate_normal (#502) (#877)
* Implementation of mlx.random.multivariate_normal (#502)

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Updated typo in docstring

* Restricted multivariate_normal to  float32

* Generic mean and variance shapes

* Review edits

* Update mlx/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Test for ndim of mean and cov

* nits

* smaller size for test

* fix broadcasted sampling

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-09 13:50:12 -07:00
Abe Leininger
a1a31eed27 Add mx.meshgrid (#961) 2024-04-09 11:43:08 -07:00
Awni Hannun
ae812350f9 use string (#976) 2024-04-09 11:22:00 -07:00
Awni Hannun
b63ef10a7f Extensions (#962)
* start to fix extensions

* mostly fixed extensions

* fix extension build

* couple more nits
2024-04-09 08:50:36 -07:00
Awni Hannun
42afe27e12 std and expm1 (#973)
* std and expm1

* actually add expm1

* fix linux

* fix vjp

* relax tol for linux test

* Add it to the compilable primitives

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-04-08 14:26:01 -07:00
Awni Hannun
76e63212ff Enable bfloat scan (#974)
* enable bfloat scan
* fix tests
2024-04-08 12:29:19 -07:00
Awni Hannun
aac2f9fb61 Improve profiling with gpu tracing (#969)
* improve profiling with gpu tracing

* fix for linux

* nit

* doc fix

* fix example
2024-04-07 21:47:43 -07:00
Awni Hannun
bddf23f175 patch bump (#956) 2024-04-04 11:56:37 -07:00
Awni Hannun
039da779d1 No quant reshape (#957)
* precise option on cpu

* remove print

* remove reshape in quant matmul

* no quant reshape
2024-04-04 11:52:12 -07:00
Awni Hannun
d88d2124b5 segfaut layer norm grad (#955) 2024-04-04 10:59:15 -07:00
Awni Hannun
e142aaf8a1 Option for precise softmax (#953)
* precise softmax

* Add an equivalency check

* Make the threadgroup memory definition fixed

* precise cpu softmax

* precise option on cpu

* remove print

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-04-04 08:32:35 -07:00
AmirHossein_Razlighi
0caf35f4b8 Better exceptions in case of invalid operations on mlx.core.array (#910) (#926)
* Nicer exceptions for ops on non-arrays
2024-04-02 21:11:24 -07:00
Angelos Katharopoulos
3fc993f82d Properly handle negative axes in python vmap (#944) 2024-04-02 18:07:23 -07:00
Awni Hannun
741eb28443 fix a couple bugs (#952) 2024-04-02 12:07:41 -07:00
Angelos Katharopoulos
1a87dc5ea8 Fix compile fusion for multi-output edge cases (#950)
* Fix compile fusion for multi-output edge cases

* Add a test for multi-output compile
2024-04-02 08:42:31 -07:00
Awni Hannun
2427fa171e Fix cpu compile (#934)
* fix one cpu bug, test for another

* format hooks

* simplify contiguity check for cpu compile

* fix

* add back donation

* comment
2024-04-01 17:37:12 -07:00
Jagrit Digani
639e06e1f3 Indexing bug fix (#947)
* Fix axes accounting

* Add tests
2024-04-01 12:18:50 -07:00
Angelos Katharopoulos
02fedbf1da Fix array initialization from list (#942)
* Fix array initialization from list

* Change the error message in the test
2024-04-01 06:27:52 -07:00
Angelos Katharopoulos
110d9b149d Layer norm grad fix donation bug (#941)
* add layer norm grad test

* Fix donation bug in layernorm vjp

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-01 06:15:50 -07:00
Angelos Katharopoulos
9cbff5ec1d Fix typo in qmm check (#940) 2024-03-31 19:15:44 -07:00
Suvan Kumar
433c0206b0 Update saving_and_loading.rst (#929)
Update saving / load docs.
2024-03-30 14:30:06 -07:00
Awni Hannun
8915901966 Donation bug (#933)
* donation

* buf

* fix bug in softmax

* comment

* remove print
2024-03-30 10:08:54 -07:00
AmirHossein_Razlighi
f48bc496c7 Comparing python objects (such as list/tuple) with mlx.core.array (#920)
* add implicit conversion of list to array for equality constraint

* add tests for array equality

* add test for tuple and array equality

* return False if __eq__ arg is list or tuple

* write tests for equality

* update the rule of comparison for __ge__/__gt__/__lt__/__le__

* add a helper function for detecting mlx.core.array

* return true in case fo inequality

* debug minor issue regarding detecting mlx array

* add tests for inequality comparisons

* add name for contribution

* reformat files using pre-commit

* update tests for float

* update tests for inequality

* raise exception in case of invalid comparisons

* use isinstance instead of string comparison

* replace "is_convirtable_to_array" with previous logic

* remove throwing exceptions for other operations

* just a comment

* minor changes for efficiency

* optimize a utils function

* change the function name

* Update ACKNOWLEDGMENTS.md

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-29 06:52:30 -07:00
Cheng
913b19329c Add missing && when forwarding args (#925)
Without the && args would be copied and perfect forwarding won't work.
2024-03-29 06:48:29 -07:00
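For context, a minimal sketch of the perfect-forwarding idiom the note above refers to; the names `consume` and `forward_to_consume` are illustrative, not MLX APIs:

    // Why the && matters: T&& is a forwarding reference, and std::forward
    // preserves the caller's value category, so rvalues are moved rather than copied.
    #include <iostream>
    #include <string>
    #include <utility>

    void consume(const std::string& s) { std::cout << "copied: " << s << "\n"; }
    void consume(std::string&& s) { std::cout << "moved: " << s << "\n"; }

    template <typename T>
    void forward_to_consume(T&& arg) {
      consume(std::forward<T>(arg));  // taking the parameter by value instead would copy lvalue arguments
    }

    int main() {
      std::string s = "lvalue";
      forward_to_consume(s);                      // prints "copied: lvalue"
      forward_to_consume(std::string("rvalue"));  // prints "moved: rvalue"
    }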
Awni Hannun
d8cb3128f6 bump (#924)
* bump

* fix version
2024-03-28 16:14:55 -07:00
Angelos Katharopoulos
5f9ba3019f Fix qmm_t for unaligned cases (#923) 2024-03-28 15:34:57 -07:00
Cheng
46caf0bef0 Remove unnecessary string copies (#891)
1. Use string_view instead of string when there is no need for copy.
2. Otherwise move string when possible.
2024-03-28 13:14:59 -07:00
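A small sketch of the two points above (illustrative names, not MLX code):

    #include <iostream>
    #include <string>
    #include <string_view>
    #include <utility>

    // 1. Read-only access: string_view avoids constructing a std::string copy.
    bool has_prefix(std::string_view name, std::string_view prefix) {
      return name.size() >= prefix.size() && name.compare(0, prefix.size(), prefix) == 0;
    }

    // 2. When the callee keeps the string, take it by value and move it into place.
    struct Named {
      std::string name;
      explicit Named(std::string n) : name(std::move(n)) {}
    };

    int main() {
      std::cout << has_prefix("mlx::core::add", "mlx::") << "\n";  // prints 1
      Named layer("temporary string is moved, not copied");
      (void)layer;
    }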
Jack Mousseau
45f636e759 Add Metal debug option and capture functions (#707)
* Add Metal debug option and capture functions

* Add brief Metal debugger documentation

* doc nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-28 09:40:31 -07:00
Cheng
a7b404ff53 Use uintptr_t instead of size_t to store funtion id (#916)
Also does some small cleanup of the compile cache code.
2024-03-28 06:37:59 -07:00
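The distinction matters because `uintptr_t` is specified to hold a pointer value round-trip, while `size_t` is only required to represent object sizes. A minimal sketch (illustrative, not the actual compile-cache code):

    #include <cstdint>
    #include <iostream>

    int main() {
      int x = 0;
      // Storing a pointer-derived identity as an integer key: uintptr_t is the
      // type guaranteed (where provided) to be wide enough for this.
      std::uintptr_t id = reinterpret_cast<std::uintptr_t>(&x);
      std::cout << id << "\n";
    }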
Angelos Katharopoulos
c4fd0e5ede Fixes #918 bug in compile_tests (#919) 2024-03-27 22:37:37 -07:00
Cheng
bab5386306 Make ops aware of rvalues: astype/as_strided/copy/full (#895)
When composing transforms, lots of temporary arrays are created and passed
to the next primitive; by making ops accept args by value we can avoid many
copies of those temporaries.
2024-03-27 22:35:55 -07:00
Angelos Katharopoulos
aca7584635 Fix OOB read in qmv when non-divisible by blocksize (#917) 2024-03-27 22:18:35 -07:00
AmirHossein_Razlighi
d611251502 Support Chaining for some of functionalities of nn.Module (#885) (#897)
* add chaining support for some of the functionalities of "nn.Module"

* reformat

* change the return types

* remove return types

* add return type with forward referencing

* add tests for chaining

* add name to contributors

* Update python/mlx/nn/layers/base.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/nn/layers/base.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* update docstring

* update docstrings

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-27 19:58:29 -07:00
Cheng
f30b659291 Make MLX build on x64 macOS (#901)
The arm64 MacBook Pros are heavy and I usually carry my Intel one when
mobile; it would be nice if I could play with MLX on it.

To build with x64, user must pass `MLX_ENABLE_X64_MAC` to cmake:
CMAKE_ARGS='-DMLX_ENABLE_X64_MAC=ON' python setup.py
2024-03-27 06:14:29 -07:00
Cheng
90dfa43ff1 Don't use make_unique to create shared_ptr (#902)
The code compiled because shared_ptr's constructor actually accepts
unique_ptr.
2024-03-27 06:13:29 -07:00
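This works because `std::shared_ptr` has a converting constructor from `std::unique_ptr&&`, but constructing the `shared_ptr` directly is clearer and lets `make_shared` place the object and its control block in a single allocation. A small sketch with an illustrative type:

    #include <memory>

    struct Node { int value = 0; };

    int main() {
      // Compiles: the unique_ptr is converted, but the object and the control
      // block end up in two separate allocations.
      std::shared_ptr<Node> a = std::make_unique<Node>();

      // Preferred: one allocation for object + control block.
      auto b = std::make_shared<Node>();
      return a->value + b->value;
    }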
Awni Hannun
dc175f08d3 Fix race in multi-stream eval (#911)
* maybe fix race

* comment
2024-03-26 16:36:36 -07:00
Angelos Katharopoulos
29221fa238 Implement vjps for some primitives in the fast namespace (#883)
* Implement rope vjp in terms of rope
* RMSNormVJP primitive and kernel
* Add LayerNormVJP primitive and kernel
2024-03-26 16:35:34 -07:00
Cheng
a789685c63 Remove duplicate defines of StreamOrDevice and is_big_endian (#892) 2024-03-26 15:15:11 -07:00
Jagrit Digani
240d10699c Implement negative padding in conv with slicing (#907)
* Implement negative padding with slicing

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni@apple.com>

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-26 14:59:19 -07:00
Jagrit Digani
925014b661 Fix multiblock sort limits (#906)
* Fix multiblock sort limits

* Fix metal validation error
2024-03-26 14:00:00 -07:00
Abdussamet Türker
5611e1a95e Fix unsqueeze with None (#899)
* Fix unsqueeze with None

* Clean unnecessary files
2024-03-26 13:59:44 -07:00
Awni Hannun
570f2bf29e pick up preivously set attributes (#905) 2024-03-26 11:19:59 -07:00
Angelos Katharopoulos
9948eddf11 Fix nan and improve speed for qvm (#903) 2024-03-26 10:41:45 -07:00
Luca Arnaboldi
a3ee03da01 Fixing random.normal for half-precision dtype #642 (#904)
* Fixing random.normal for half-precision dtype #642

* Update python/tests/test_random.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-26 09:58:27 -07:00
Cheng
28fcd2b519 Add missing && when forwarding args (#894)
Without the && args would be copied and perfect forwarding won't work.

Also add template utils to make sure the function only forwards array
and not vector<array>.
2024-03-25 14:55:54 -07:00
Jack Mousseau
8e686764ac Ensure shape dimensions are within supported integer range (#566) (#704)
* Ensure shape dimensions are within supported integer range (#566)

* fix build

* fix rebase bug

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-25 13:29:45 -07:00
Daniel Strobusch
479051ce1c add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)
* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate

numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)).

Closes #285.

* nits in docs

* unify type category checking

* nits in docs

* nits in docs

* more docs nits

* fix callable type

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-25 12:32:59 -07:00
Awni Hannun
bfb5bad4f0 patch (#893) 2024-03-24 21:03:59 -07:00
Awni Hannun
1e16331d9c post nanobind docs fixes and some updates (#889)
* post nanobind docs fixes and some updates

* one more doc nit

* fix for stubs and latex
2024-03-24 15:03:27 -07:00
Awni Hannun
be98f4ab6b Reduce a little overhead (#871)
* some small overhead improvements

* use result_type in rms_norm

* remove release force

* fix + use non-vector version

* revert compile change

* fix ops

* a little more overhead

* a little more cleanup and overhead
2024-03-22 17:29:36 -07:00
Angelos Katharopoulos
6ee1112f30 Fix copy donation and add partial rope (#881) 2024-03-22 17:28:26 -07:00
Jagrit Digani
8e5a5a1ccd Set item bug fix (#879)
* set item shaping bug fix

* Add extra tests
2024-03-22 12:11:17 -07:00
Angelos Katharopoulos
fcda3a0e66 Increase test tolerance for fast.layer_norm (#880) 2024-03-22 12:10:27 -07:00
Cheng
9663c22fe9 Do not store iostream in shared_ptr (#872)
There is no need to store iostream in shared_ptr, doing so adds the cost
of a heap allocation.
2024-03-22 06:54:45 -07:00
Cheng
f0ae00da12 Reduce implicit copies in make_array (#874)
1. Move shapes into outputs instead of copying them.
2. Pass primitive by const ref as it is always copied into outputs, which
   removes a copy when calling make_array.
2024-03-22 06:29:16 -07:00
Awni Hannun
44390bd3d0 Bump (#869)
* bump

* fix none in a few ops
2024-03-21 13:56:56 -07:00
Angelos Katharopoulos
2225374060 Adds mx.fast.layer_norm (#870) 2024-03-21 13:55:51 -07:00
nicolov
105d236889 Add vmap for SVD and inverse (#849) 2024-03-21 13:18:27 -07:00
Angelos Katharopoulos
53e6a9367c Use reshape and transpose for non-overlapping pooling windows (#867) 2024-03-21 10:21:03 -07:00
Chime Ogbuji
f5a1582fe8 Add minimum for cosine decay function (#859)
* Add minimum for cosine decay function

* Update python/mlx/optimizers/schedulers.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-21 07:33:29 -07:00
Awni Hannun
a54f06b16f Fast RMS Norm (#862)
* fast rmsnorm

* no rms gpu

* kernel

* fix shared mem

* looped rms and donation in softmax

* Make the squaring in float32 to avoid underflow

* Fix the default StreamOrDevice for rope and rms_norm in fast

* nits

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-21 07:20:54 -07:00
Cheng
4650d94d98 Add missing && in eval (#864)
Without the && args would be copied and perfect forwarding won't work.

To avoid eval calling itself recursively, the vector version of eval is
changed to take by value instead, which will save a copy of array when a
rvalue is passed.
2024-03-21 06:15:48 -07:00
Jagrit Digani
a5681ebc52 Update set item (#861)
* Update mlx_set_item to handle regular slices without expanding

* Refactor ellipsis handling

* Route mlx_set_item to slice_update where possible

* Update mlx_scatter_args_slice

* Don't route to gather if no array indices
2024-03-21 02:48:13 -07:00
Cheng
e849b3424a Do not use static constexpr in header (#863)
Doing so results in each compilation unit (.cpp file) having its own
copy of the variable, while inline constexpr makes sure there is only
one copy.
2024-03-20 21:28:05 -07:00
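In other words, at namespace scope in a header `static constexpr` gives the variable internal linkage and therefore one copy per translation unit, whereas C++17 `inline constexpr` yields a single shared definition. A header sketch (the constant name is illustrative):

    // some_header.h
    // One copy per .cpp that includes the header:
    //   static constexpr int kBlockSize = 32;
    // Single definition shared across all translation units (C++17):
    inline constexpr int kBlockSize = 32;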
Jagrit Digani
b219d12a6b Check edge case handling in row reduce med kernel (#858) 2024-03-20 11:37:58 -07:00
Jagrit Digani
cec8661113 Add a SliceUpdate op and primitive (#850)
* Enable copy to work with int64 strides
* Fix uniform buffer indices or copy kernel arguments
* Update utils.h
* Remove manual unrolling of elem to loc loop
* GPU copy updated to handle negative strides
* Add slice update primitive
2024-03-20 10:39:25 -07:00
Cheng
73a8c090e0 Pass shape and inputs by value in array's constructor (#853)
Since the shape and inputs are always saved as copy in ArrayDesc, we can
unify array's constructors to just take the arguments by value.

There are 2 cases:
1. When shape is a lvalue, it will be copied into array's constructor and
   then moved into ArrayDesc's member. So only 1 copy happens.
2. When shape is a rvalue, it will be moved into array's constructor and
   then moved into ArrayDesc's member. So no copy happens.

So having one constructor that takes by value is equivalent to having two
constructors that take a const reference and an rvalue separately.
2024-03-20 07:54:30 -07:00
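A compact sketch of the "take by value, then move" pattern described above (the `Desc` type is illustrative; the real `array`/`ArrayDesc` constructors differ):

    #include <utility>
    #include <vector>

    class Desc {
     public:
      // Lvalue argument: copied into `shape`, then moved into the member (1 copy).
      // Rvalue argument: moved into `shape`, then moved into the member (0 copies).
      explicit Desc(std::vector<int> shape) : shape_(std::move(shape)) {}

     private:
      std::vector<int> shape_;
    };

    int main() {
      std::vector<int> s = {2, 3, 4};
      Desc a(s);                          // lvalue: one copy
      Desc b(std::vector<int>{2, 3, 4});  // rvalue: moves only
    }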
Md. Rasel Mandol
db6796ac61 simple typo fille (#848) 2024-03-19 06:15:17 -07:00
Awni Hannun
9a8ee00246 Switch to nanobind (#839)
* mostly builds

* most tests pass

* fix circle build

* add back buffer protocol

* includes

* fix for py38

* limit to cpu device

* include

* fix stubs

* move signatures for docs

* stubgen + docs fix

* doc for compiled function, comments
2024-03-18 20:12:25 -07:00
Cheng
d39ed54f8e Some C++ code are not needed (#841)
1. Anonymous namespace means internal linkage, static keyword is not needed.
2. The default constructor of std::shared_ptr initializes the pointer to
   nullptr, you don't need to explicitly set it.
2024-03-18 17:04:10 -07:00
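Both points boil down to removing redundant code; a tiny illustrative example:

    #include <memory>

    namespace {           // an anonymous namespace already implies internal linkage,
    int local_helper() {  // so marking this function `static` as well is redundant
      return 42;
    }
    }  // namespace

    int main() {
      std::shared_ptr<int> p;  // a default-constructed shared_ptr is already nullptr,
                               // no explicit `p = nullptr;` needed
      return p ? *p : local_helper();
    }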
Awni Hannun
16546c70d8 No reshape rope (#838)
* no reshape rope

* no reshape rope
2024-03-18 17:03:07 -07:00
nicolov
eaba55c9bf Add matrix inversion primitive (#822) 2024-03-15 06:34:36 -07:00
Awni Hannun
19ec023256 vmap matmul and admm (#836) 2024-03-14 14:38:22 -07:00
Awni Hannun
63ab0ab580 version (#835) 2024-03-14 12:20:40 -07:00
Jagrit Digani
8dfc376c00 Strided reduce specialization for small reductions (#826)
* Add small column / general reduction specialization
2024-03-14 09:16:53 -07:00
Angelos Katharopoulos
1efee9db09 Add types and order in kernel name (#831) 2024-03-13 20:34:06 -07:00
Awni Hannun
43abc402d8 route to fallback (#828) 2024-03-13 19:56:04 -07:00
Angelos Katharopoulos
3f8b1668c4 Make reshape faster for row_contiguous cases (#829) 2024-03-13 16:22:03 -07:00
Angelos Katharopoulos
76c919b4ec NumberOfElements for shapeless compile and vmap fixes (#802) 2024-03-13 10:34:14 -07:00
Angelos Katharopoulos
29d0c10ee5 Reshape improvement (#818) 2024-03-12 17:54:31 -07:00
Jagrit Digani
5ad133f8bb No copy gems (#801)
* Enable collapsing batch dims in gemm
* Update gemm to only make copies when neither of the last 2 axes are contiguous
* Update addmm to support gemv shapes
* Update addmm to support irregular batch strides
* Update tests
2024-03-12 13:13:41 -07:00
nicolov
d0c544a868 Add SVD primitive (#809)
Add SVD op using Accelerate's LAPACK following
https://developer.apple.com/documentation/accelerate/
compressing_an_image_using_linear_algebra

Co-authored-by: Nicolo Valigi <nvaligi@apple.com>
2024-03-12 12:30:11 -07:00
Daniel Falbel
ffb19df3c0 Fix docstring for correctly rendering (#820) 2024-03-12 11:46:44 -07:00
Awni Hannun
8b7532b9ab fix scatter (#821) 2024-03-12 11:42:07 -07:00
Awni Hannun
366478c560 fix modules with dict (#819) 2024-03-12 08:54:06 -07:00
Justin Deschenaux
8e5600022a Implement RNN, GRU, LSTM (#268)
* RNN base implementation

* Address comments+format

* nits in docs

* add tests for prb

* fix test

* add a couple tests

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-11 21:14:44 -07:00
Awni Hannun
0e95b64942 Fix bug in tape order during simplify (#816)
* fix bug in tape order during simplify

* properly fix compile

* last bug
2024-03-11 17:29:05 -07:00
nicolov
0ae22b915b Remove code duplication in reduce ops (#793)
* Remove code duplication in reduce ops

* Remove the unnecessary lambda

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-11 10:57:07 -07:00
Awni Hannun
7c441600fe Compile stride bug (#812)
* fix compile stride bug

* revert sdpa fix

* fix cpu

* fix bug with simplifying outputs
2024-03-11 06:31:31 -07:00
Awni Hannun
a4d290adb9 Remove depth traversal (#813)
* no depth traversal

* counter outside loop
2024-03-09 20:21:32 -08:00
Awni Hannun
28301807c2 Version bump and os error (#807) 2024-03-07 13:57:58 -08:00
Awni Hannun
74ed0974b3 Support 13.0+ with xcode 14.3 (#806)
* Support 13.0+ with xcode 14.3

* revert revert
2024-03-07 13:27:57 -08:00
Jagrit Digani
ec8a4864fa Fix SDPA kernel bug on Mac OS 13.3 SDK (#805)
* Move sdpa kernel to allocate tgp mem statically and allow macOS 13.3 SDK builds

* Style
2024-03-07 10:18:09 -08:00
Awni Hannun
b7588fd5d7 fix inplace to not make a shallow copy (#804) 2024-03-07 09:34:11 -08:00
Awni Hannun
f512b905c7 Minimum xcode / sdk (#800)
* minimum xcode /sdk

* try multiple xcode versions in CI

* update python

* metal validation for python tests
2024-03-07 08:19:43 -08:00
Awni Hannun
afd5274049 route to fallback for bfloat (#794) 2024-03-06 15:39:12 -08:00
Awni Hannun
1074674e32 Add a maximum graph depth (#797)
* add a maximum graph depth

* remember how to use C++
2024-03-06 15:39:00 -08:00
AlexCheema
7762e07fde Update function_transforms.rst (#796)
Fix typo in function_transforms.rst
2024-03-06 12:03:37 -08:00
262 changed files with 22178 additions and 8820 deletions

View File

@@ -31,8 +31,7 @@ jobs:
name: Install dependencies name: Install dependencies
command: | command: |
pip install --upgrade cmake pip install --upgrade cmake
pip install --upgrade pybind11[global] pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
pip install pybind11-stubgen
pip install numpy pip install numpy
sudo apt-get update sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
@@ -44,7 +43,8 @@ jobs:
- run: - run:
name: Generate package stubs name: Generate package stubs
command: | command: |
python3 setup.py generate_stubs echo "stubs"
python setup.py generate_stubs
- run: - run:
name: Run Python tests name: Run Python tests
command: | command: |
@@ -63,21 +63,24 @@ jobs:
command: ./build/tests/tests command: ./build/tests/tests
mac_build_and_test: mac_build_and_test:
parameters:
xcode_version:
type: string
default: "15.2.0"
macos: macos:
xcode: "15.2.0" xcode: << parameters.xcode_version >>
resource_class: macos.m1.large.gen1 resource_class: macos.m1.large.gen1
steps: steps:
- checkout - checkout
- run: - run:
name: Install dependencies name: Install dependencies
command: | command: |
brew install python@3.9 brew install python@3.8
python3.9 -m venv env python3.8 -m venv env
source env/bin/activate source env/bin/activate
pip install --upgrade pip pip install --upgrade pip
pip install --upgrade cmake pip install --upgrade cmake
pip install --upgrade pybind11[global] pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
pip install pybind11-stubgen
pip install numpy pip install numpy
pip install torch pip install torch
pip install tensorflow pip install tensorflow
@@ -91,13 +94,13 @@ jobs:
name: Generate package stubs name: Generate package stubs
command: | command: |
source env/bin/activate source env/bin/activate
python setup.py generate_stubs python setup.py generate_stubs
- run: - run:
name: Run Python tests name: Run Python tests
command: | command: |
source env/bin/activate source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu python3.9 -m xmlrunner discover -v python/tests -o test-results/gpu LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
# TODO: Reenable when extension api becomes stable # TODO: Reenable when extension api becomes stable
# - run: # - run:
# name: Build example extension # name: Build example extension
@@ -140,9 +143,8 @@ jobs:
source env/bin/activate source env/bin/activate
pip install --upgrade pip pip install --upgrade pip
pip install --upgrade cmake pip install --upgrade cmake
pip install --upgrade pybind11[global] pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
pip install --upgrade setuptools pip install --upgrade setuptools
pip install pybind11-stubgen
pip install numpy pip install numpy
pip install twine pip install twine
pip install build pip install build
@@ -157,7 +159,7 @@ jobs:
name: Generate package stubs name: Generate package stubs
command: | command: |
source env/bin/activate source env/bin/activate
python setup.py generate_stubs python setup.py generate_stubs
- run: - run:
name: Build Python package name: Build Python package
command: | command: |
@@ -205,9 +207,8 @@ jobs:
source env/bin/activate source env/bin/activate
pip install --upgrade pip pip install --upgrade pip
pip install --upgrade cmake pip install --upgrade cmake
pip install --upgrade pybind11[global] pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
pip install --upgrade setuptools pip install --upgrade setuptools
pip install pybind11-stubgen
pip install numpy pip install numpy
pip install auditwheel pip install auditwheel
pip install patchelf pip install patchelf
@@ -215,7 +216,7 @@ jobs:
<< parameters.extra_env >> \ << parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL="" \ CMAKE_BUILD_PARALLEL_LEVEL="" \
pip install . -v pip install . -v
python setup.py generate_stubs python setup.py generate_stubs
<< parameters.extra_env >> \ << parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL="" \ CMAKE_BUILD_PARALLEL_LEVEL="" \
python -m build --wheel python -m build --wheel
@@ -235,7 +236,10 @@ workflows:
- not: << pipeline.parameters.weekly_build >> - not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >> - not: << pipeline.parameters.test_release >>
jobs: jobs:
- mac_build_and_test - mac_build_and_test:
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0"]
- linux_build_and_test - linux_build_and_test
build_pypi_release: build_pypi_release:
@@ -254,7 +258,7 @@ workflows:
matrix: matrix:
parameters: parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"] python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["14.3.1", "15.2.0"] xcode_version: ["15.0.0", "15.2.0"]
build_env: ["PYPI_RELEASE=1"] build_env: ["PYPI_RELEASE=1"]
prb: prb:
when: when:
@@ -268,6 +272,9 @@ workflows:
context: pr-approval context: pr-approval
- mac_build_and_test: - mac_build_and_test:
requires: [ hold ] requires: [ hold ]
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0"]
- linux_build_and_test: - linux_build_and_test:
requires: [ hold ] requires: [ hold ]
nightly_build: nightly_build:
@@ -280,7 +287,7 @@ workflows:
matrix: matrix:
parameters: parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"] python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["14.3.1", "15.2.0"] xcode_version: ["15.0.0", "15.2.0"]
weekly_build: weekly_build:
when: when:
and: and:
@@ -291,7 +298,7 @@ workflows:
matrix: matrix:
parameters: parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"] python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["14.3.1", "15.2.0"] xcode_version: ["15.0.0", "15.2.0"]
build_env: ["DEV_RELEASE=1"] build_env: ["DEV_RELEASE=1"]
linux_test_release: linux_test_release:
when: when:

View File

@@ -1,11 +1,11 @@
 repos:
 - repo: https://github.com/pre-commit/mirrors-clang-format
-  rev: v17.0.6
+  rev: v18.1.4
   hooks:
   - id: clang-format
 # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
 - repo: https://github.com/psf/black-pre-commit-mirror
-  rev: 24.2.0
+  rev: 24.4.2
   hooks:
   - id: black
 - repo: https://github.com/pycqa/isort

View File

@@ -15,6 +15,8 @@ MLX was developed with contributions from the following individuals:
 - Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
 - Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
 - Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
+- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.

 <a href="https://github.com/ml-explore/mlx/graphs/contributors">
 <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
 </a>

View File

@@ -15,31 +15,33 @@ option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF) option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF) option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON) option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF) option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION) if(NOT MLX_VERSION)
set(MLX_VERSION 0.5.1) set(MLX_VERSION 0.12.2)
endif() endif()
# --------------------- Processor tests ------------------------- # --------------------- Processor tests -------------------------
message(STATUS "Building MLX for ${CMAKE_HOST_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}") message(STATUS "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
set(MLX_BUILD_ARM OFF) set(MLX_BUILD_ARM OFF)
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin") if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64" AND ${CMAKE_HOST_APPLE}) if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
message(FATAL_ERROR if(NOT MLX_ENABLE_X64_MAC)
"Building for x86_64 on macOS is not supported." message(FATAL_ERROR
" If you are on an Apple silicon system, check the build" "Building for x86_64 on macOS is not supported."
" documentation for possible fixes: " " If you are on an Apple silicon system, check the build"
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source") " documentation for possible fixes: "
elseif (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64") "https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
message(WARNING else()
"Building for x86_64 on macOS is not supported." message(WARNING "Building for x86_64 arch is not officially supported.")
" If you are on an Apple silicon system, " endif()
" make sure you are building for arm64.") set(MLX_BUILD_METAL OFF)
elseif(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "arm64") elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
set(MLX_BUILD_ARM ON) set(MLX_BUILD_ARM ON)
endif() endif()
@@ -64,8 +66,14 @@ endif()
if (MLX_BUILD_METAL AND NOT METAL_LIB) if (MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU") message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF) set(MLX_BUILD_METAL OFF)
set(MLX_METAL_DEBUG OFF)
elseif (MLX_BUILD_METAL) elseif (MLX_BUILD_METAL)
message(STATUS "Building METAL sources") message(STATUS "Building METAL sources")
if (MLX_METAL_DEBUG)
add_compile_definitions(MLX_METAL_DEBUG)
endif()
# Throw an error if xcrun not found # Throw an error if xcrun not found
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version" execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_VERSION OUTPUT_VARIABLE MACOS_VERSION
@@ -74,18 +82,19 @@ elseif (MLX_BUILD_METAL)
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}") message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
if (${MACOS_VERSION} GREATER_EQUAL 14.2) if (${MACOS_VERSION} GREATER_EQUAL 14.2)
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.2.diff)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip) set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0) elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.0.diff)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip) set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
elseif (${MACOS_VERSION} GREATER_EQUAL 13.3)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13.3_iOS16.4.zip)
else() else()
message(FATAL_ERROR "MLX requires macOS >= 13.4 to be built with MLX_BUILD_METAL=ON" ) message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
endif() endif()
FetchContent_Declare( FetchContent_Declare(
metal_cpp metal_cpp
URL ${METAL_CPP_URL} URL ${METAL_CPP_URL}
PATCH_COMMAND patch -N -i ${METAL_CPP_PATCH} || true
) )
FetchContent_MakeAvailable(metal_cpp) FetchContent_MakeAvailable(metal_cpp)
@@ -110,7 +119,27 @@ if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
else() else()
message(STATUS "Accelerate or arm neon not found, using default backend.") message(STATUS "Accelerate or arm neon not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF) set(MLX_BUILD_ACCELERATE OFF)
#set(BLA_VENDOR Generic) if(${CMAKE_HOST_APPLE})
# The blas shipped in macOS SDK is not supported, search homebrew for
# openblas instead.
set(BLA_VENDOR OpenBLAS)
set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
endif()
# Search and link with lapack.
find_package(LAPACK REQUIRED)
if (NOT LAPACK_FOUND)
message(FATAL_ERROR "Must have LAPACK installed")
endif()
find_path(LAPACK_INCLUDE_DIRS lapacke.h
/usr/include
/usr/local/include
/usr/local/opt/openblas/include)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx ${LAPACK_LIBRARIES})
# List blas after lapack otherwise we may accidentally incldue an old version
# of lapack.h from the include dirs of blas.
find_package(BLAS REQUIRED) find_package(BLAS REQUIRED)
if (NOT BLAS_FOUND) if (NOT BLAS_FOUND)
message(FATAL_ERROR "Must have BLAS installed") message(FATAL_ERROR "Must have BLAS installed")
@@ -124,17 +153,6 @@ else()
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS}) message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS}) target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx ${BLAS_LIBRARIES}) target_link_libraries(mlx ${BLAS_LIBRARIES})
find_package(LAPACK REQUIRED)
if (NOT LAPACK_FOUND)
message(FATAL_ERROR "Must have LAPACK installed")
endif()
find_path(LAPACK_INCLUDE_DIRS lapacke.h
/usr/include
/usr/local/include)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx ${LAPACK_LIBRARIES})
endif() endif()
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
@@ -148,8 +166,12 @@ target_include_directories(
if (MLX_BUILD_PYTHON_BINDINGS) if (MLX_BUILD_PYTHON_BINDINGS)
message(STATUS "Building Python bindings.") message(STATUS "Building Python bindings.")
find_package(Python COMPONENTS Interpreter Development) find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
find_package(pybind11 CONFIG REQUIRED) execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
endif() endif()

View File

@@ -17,14 +17,13 @@
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \ << std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
<< std::endl; << std::endl;
#define TIMEM(MSG, FUNC, ...) \ #define TIMEM(MSG, FUNC, ...) \
std::cout << "Timing " \ std::cout << "Timing " << "(" << MSG << ") " << #FUNC << " ... " \
<< "(" << MSG << ") " << #FUNC << " ... " << std::flush \ << std::flush << std::setprecision(5) \
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \ << time_fn(FUNC, ##__VA_ARGS__) << " msec" << std::endl;
<< std::endl;
template <typename F, typename... Args> template <typename F, typename... Args>
double time_fn(F fn, Args... args) { double time_fn(F fn, Args&&... args) {
// warmup // warmup
for (int i = 0; i < 5; ++i) { for (int i = 0; i < 5; ++i) {
eval(fn(std::forward<Args>(args)...)); eval(fn(std::forward<Args>(args)...));

View File

@@ -0,0 +1,123 @@
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_1D(strides=1, padding=0, groups=1):
def mx_conv_1D(a, b):
ys = []
for _ in range(N_iter_func):
y = mx.conv1d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_1D
def make_pt_conv_1D(strides=1, padding=0, groups=1):
@torch.no_grad()
def pt_conv_1D(a, b):
ys = []
for _ in range(N_iter_func):
y = torch.conv1d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_1D
def bench_shape(N, iH, C, wH, O, strides, padding, np_dtype, groups):
scale = 1.0 / math.sqrt(wH * C)
a_np = np.random.uniform(0, 0.5, (N, iH, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, wH, int(C / groups))).astype(np_dtype)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 2, 1))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((0, 2, 1))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_1D(strides, padding, groups)
f_pt = make_pt_conv_1D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv1d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv1d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, iH, C)}, {(O, wH, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 5, 32, 1, 2, 1),
(4, 32, 32, 5, 32, 1, 2, 2),
(4, 32, 32, 5, 32, 1, 2, 4),
(4, 32, 32, 5, 32, 1, 2, 8),
(4, 32, 32, 5, 32, 1, 2, 8),
(4, 32, 32, 5, 32, 1, 2, 16),
(4, 32, 32, 5, 32, 1, 2, 32),
(4, 32, 256, 5, 512, 1, 2, 2),
(4, 32, 256, 5, 512, 1, 2, 128),
(4, 32, 256, 5, 512, 1, 2, 256),
)
for dtype in dtypes:
print("(N, iH, C), (O, wH, C), dtype, stride, pads, groups, diff%")
for N, iH, C, wH, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, iH, C, wH, O, strides, padding, np_dtype, groups
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {iH:3d}, {C:3d}), ({O:3d}, {wH:2d}, {C:3d}), {dtype}, {strides:5d}, {padding:4d}, {groups:6d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -0,0 +1,57 @@
# Copyright © 2024 Apple Inc.
import matplotlib
import mlx.core as mx
import numpy as np
from time_utils import measure_runtime
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def bandwidth_gb(runtime_ms, system_size):
bytes_per_fft = np.dtype(np.complex64).itemsize * 2
bytes_per_gb = 1e9
ms_per_s = 1e3
return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb
def run_bench(system_size):
def fft(x):
out = mx.fft.fft(x)
mx.eval(out)
return out
bandwidths = []
for k in range(4, 12):
n = 2**k
x = mx.random.uniform(shape=(system_size // n, n)).astype(mx.float32)
x = x.astype(mx.complex64)
mx.eval(x)
runtime_ms = measure_runtime(fft, x=x)
bandwidths.append(bandwidth_gb(runtime_ms, system_size))
return bandwidths
def time_fft():
with mx.stream(mx.cpu):
cpu_bandwidths = run_bench(system_size=int(2**22))
with mx.stream(mx.gpu):
gpu_bandwidths = run_bench(system_size=int(2**29))
# plot bandwidths
x = [2**k for k in range(4, 12)]
plt.scatter(x, gpu_bandwidths, color="green", label="GPU")
plt.scatter(x, cpu_bandwidths, color="red", label="CPU")
plt.title("MLX FFT Benchmark")
plt.xlabel("N")
plt.ylabel("Bandwidth (GB/s)")
plt.legend()
plt.savefig("fft_plot.png")
if __name__ == "__main__":
time_fft()

View File

@@ -0,0 +1,41 @@
# Copyright © 2023-2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
def layer_norm(x, w, b, eps):
ot = x.dtype
x = x.astype(mx.float32)
mu = mx.mean(x, -1, keepdims=True)
v = mx.var(x, -1, keepdims=True)
return (x - mu) * mx.rsqrt(v + eps) * w + b
def time_layer_norm():
f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0, 1, 2))
g2 = mx.grad(f2, argnums=(0, 1, 2))
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
mx.eval(x, w, b, y)
def layer_norm_loop(g, x, w, b):
gx, gw, gb = x, w, b
for _ in range(32):
gx, gw, gb = g(gx, gw, gb, y)
return gx, gw, gb
time_fn(layer_norm_loop, g1, x, w, b)
time_fn(layer_norm_loop, g2, x, w, b)
time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
if __name__ == "__main__":
time_layer_norm()

View File

@@ -0,0 +1,39 @@
# Copyright © 2023-2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
def rms_norm(x, w, eps):
ot = x.dtype
x = x.astype(mx.float32)
n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
return (x * n).astype(ot) * w
def time_rms_norm():
f1 = lambda x, w, y: (rms_norm(x, w, 1e-5) * y).sum()
f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0, 1))
g2 = mx.grad(f2, argnums=(0, 1))
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
mx.eval(x, w, y)
def rms_norm_loop(g, x, w):
gx, gw = x, w
for _ in range(32):
gx, gw = g(gx, gw, y)
return gx, gw
time_fn(rms_norm_loop, g1, x, w)
time_fn(rms_norm_loop, g2, x, w)
time_fn(rms_norm_loop, mx.compile(g1), x, w)
time_fn(rms_norm_loop, mx.compile(g2), x, w)
if __name__ == "__main__":
time_rms_norm()

View File

@@ -6,21 +6,21 @@ from time_utils import time_fn
def time_rope(): def time_rope():
rope = nn.RoPE(4096) rope = nn.RoPE(64)
# vec # vec
x = mx.random.uniform(shape=(1, 4096)).astype(mx.float16) x = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.float16)
mx.eval(x) mx.eval(x)
def rope_vec(x): def rope_vec(x):
for _ in range(32): for _ in range(32):
x = rope(x) x = rope(x, offset=100)
return x return x
time_fn(rope_vec, x) time_fn(rope_vec, x)
# matrix # matrix
x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16) x = mx.random.uniform(shape=(1, 32, 1024, 128)).astype(mx.float16)
mx.eval(x) mx.eval(x)
def rope_mat(x): def rope_mat(x):

cmake/metal.14.0.diff (new file, 36 lines)
View File

@@ -0,0 +1,36 @@
diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
--- Metal/MTLEvent.hpp 2023-06-01 12:18:26
+++ MetalNew/MTLEvent.hpp 2024-04-15 07:36:59
@@ -62,6 +62,7 @@
uint64_t signaledValue() const;
void setSignaledValue(uint64_t signaledValue);
+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
};
class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
@@ -138,6 +139,11 @@
_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
+}
+
+// method: waitUntilSignaledValue
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
+ return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
}
// static method: alloc
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
--- Metal/MTLHeaderBridge.hpp 2023-06-01 12:18:26
+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:37:29
@@ -1906,6 +1906,9 @@
"setShouldMaximizeConcurrentCompilation:");
_MTL_PRIVATE_DEF_SEL(setSignaledValue_,
"setSignaledValue:");
+_MTL_PRIVATE_DEF_SEL(
+ waitUntilSignaledValue_timeoutMS_,
+ "waitUntilSignaledValue:timeoutMS:");
_MTL_PRIVATE_DEF_SEL(setSize_,
"setSize:");
_MTL_PRIVATE_DEF_SEL(setSlice_,

cmake/metal.14.2.diff (new file, 36 lines)
View File

@@ -0,0 +1,36 @@
diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
--- Metal/MTLEvent.hpp 2024-04-15 07:12:10
+++ MetalNew/MTLEvent.hpp 2024-04-15 07:15:50
@@ -62,6 +62,7 @@
uint64_t signaledValue() const;
void setSignaledValue(uint64_t signaledValue);
+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
};
class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
@@ -138,6 +139,11 @@
_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
+}
+
+// method: waitUntilSignaledValue
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
+ return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
}
// static method: alloc
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
--- Metal/MTLHeaderBridge.hpp 2024-04-15 07:12:10
+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:16:15
@@ -1918,6 +1918,9 @@
"setShouldMaximizeConcurrentCompilation:");
_MTL_PRIVATE_DEF_SEL(setSignaledValue_,
"setSignaledValue:");
+_MTL_PRIVATE_DEF_SEL(
+ waitUntilSignaledValue_timeoutMS_,
+ "waitUntilSignaledValue:timeoutMS:");
_MTL_PRIVATE_DEF_SEL(setSize_,
"setSize:");
_MTL_PRIVATE_DEF_SEL(setSlice_,

docs/Doxyfile (new file, 50 lines)
View File

@@ -0,0 +1,50 @@
################################################################################
# Primary project setup. #
################################################################################
PROJECT_NAME = "MLX"
OUTPUT_DIRECTORY = build
XML_OUTPUT = xml
HTML_OUTPUT = html
STRIP_FROM_PATH = ../
INPUT = ../mlx
FILE_PATTERNS = *.h
EXCLUDE_PATTERNS = */private/*
CREATE_SUBDIRS = NO
FULL_PATH_NAMES = YES
RECURSIVE = YES
GENERATE_HTML = YES
GENERATE_LATEX = NO
GENERATE_XML = YES
XML_PROGRAMLISTING = YES
################################################################################
# Doxygen preprocessor / parser control. #
################################################################################
ENABLE_PREPROCESSING = YES
MACRO_EXPANSION = YES
EXPAND_ONLY_PREDEF = NO
SKIP_FUNCTION_MACROS = NO
################################################################################
# Compound extraction control. #
################################################################################
EXTRACT_ALL = YES
EXTRACT_PACKAGE = YES
EXTRACT_STATIC = YES
CASE_SENSE_NAMES = NO
################################################################################
# Docstring control / customization. #
################################################################################
JAVADOC_AUTOBRIEF = YES
################################################################################
# Warning suppression. #
################################################################################
QUIET = YES
WARN_IF_UNDOCUMENTED = NO

View File

@@ -2,12 +2,16 @@
### Setup (do once) ### Setup (do once)
Install [sphinx](https://www.sphinx-doc.org/en/master/usage/installation.html) Install Doxygen:
for example with `conda`:
``` ```
conda install sphinx brew install doxygen
pip install sphinx-book-theme ```
Install Python packages:
```
pip install -r requirements.txt
``` ```
### Build ### Build
@@ -15,7 +19,7 @@ pip install sphinx-book-theme
Build the docs from `mlx/docs/` Build the docs from `mlx/docs/`
``` ```
make html doxygen && make html
``` ```
View the docs by running a server in `mlx/docs/build/html/`: View the docs by running a server in `mlx/docs/build/html/`:

docs/requirements.txt (new file, 3 lines)
View File

@@ -0,0 +1,3 @@
sphinx
breathe
sphinx-book-theme

Binary file added (not shown): image, 1.2 MiB

Binary file added (not shown): image, 746 KiB

View File

@@ -0,0 +1,20 @@
{{ fullname | escape | underline}}
.. currentmodule:: {{ module }}
.. autoclass:: {{ objname }}
{% block methods %}
{% if methods %}
.. rubric:: {{ _('Methods') }}
.. autosummary::
{% for item in methods %}
{%- if item not in inherited_members and item != "__init__" %}
~{{ name }}.{{ item }}
{%- endif %}
{%- endfor %}
{% endif %}
{% endblock %}

View File

@@ -22,6 +22,7 @@ extensions = [
"sphinx.ext.autosummary", "sphinx.ext.autosummary",
"sphinx.ext.intersphinx", "sphinx.ext.intersphinx",
"sphinx.ext.napoleon", "sphinx.ext.napoleon",
"breathe",
] ]
python_use_unqualified_type_names = True python_use_unqualified_type_names = True
@@ -29,16 +30,20 @@ autosummary_generate = True
autosummary_filename_map = {"mlx.core.Stream": "stream_class"} autosummary_filename_map = {"mlx.core.Stream": "stream_class"}
intersphinx_mapping = { intersphinx_mapping = {
"https://docs.python.org/3": None, "python": ("https://docs.python.org/3", None),
"https://numpy.org/doc/stable/": None, "numpy": ("https://numpy.org/doc/stable/", None),
} }
breathe_projects = {"mlx": "../build/xml"}
breathe_default_project = "mlx"
templates_path = ["_templates"] templates_path = ["_templates"]
html_static_path = ["_static"] html_static_path = ["_static"]
source_suffix = ".rst" source_suffix = ".rst"
master_doc = "index" main_doc = "index"
highlight_language = "python" highlight_language = "python"
pygments_style = "sphinx" pygments_style = "sphinx"
add_module_names = False
# -- Options for HTML output ------------------------------------------------- # -- Options for HTML output -------------------------------------------------
@@ -59,3 +64,22 @@ html_theme_options = {
# -- Options for HTMLHelp output --------------------------------------------- # -- Options for HTMLHelp output ---------------------------------------------
htmlhelp_basename = "mlx_doc" htmlhelp_basename = "mlx_doc"
def setup(app):
from sphinx.util import inspect
wrapped_isfunc = inspect.isfunction
def isfunc(obj):
type_name = str(type(obj))
if "nanobind.nb_method" in type_name or "nanobind.nb_func" in type_name:
return True
return wrapped_isfunc(obj)
inspect.isfunction = isfunc
# -- Options for LaTeX output ------------------------------------------------
latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")]

View File

@@ -3,4 +3,5 @@
Operations Operations
========== ==========
.. doxygengroup:: ops
:content-only:


@@ -1,24 +1,16 @@
Developer Documentation
=======================

You can extend MLX with custom operations on the CPU or GPU. This guide
explains how to do that with a simple example.

Introducing the Example
-----------------------

Let's say you would like an operation that takes in two arrays, ``x`` and
``y``, scales them both by coefficients ``alpha`` and ``beta`` respectively,
and then adds them together to get the result ``z = alpha * x + beta * y``.
You can do that in MLX directly:

.. code-block:: python

@@ -27,44 +19,35 @@ writing out a function as follows:
    def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
        return alpha * x + beta * y

This function performs that operation while leaving the implementation and
function transformations to MLX.

However, you may need to customize the underlying implementation, perhaps to
make it faster or to provide custom differentiation. In this tutorial we will
go through adding custom extensions. It will cover:

* The structure of the MLX library.
* Implementing a CPU operation that redirects to Accelerate_ when appropriate.
* Implementing a GPU operation using Metal.
* Adding the ``vjp`` and ``jvp`` function transformations.
* Building a custom extension and binding it to Python.

Operations and Primitives
-------------------------

Operations in MLX build the computation graph. Primitives provide the rules for
evaluating and transforming the graph. Let's start by discussing operations in
more detail.
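As a quick illustrative sketch of that split (using only core MLX calls),
operations below merely record the graph, and evaluation is what finally runs
the primitives behind them:

.. code-block:: python

    import mlx.core as mx

    x = mx.ones((4,))
    y = mx.ones((4,))

    # Calling operations only records the graph; nothing is computed yet.
    z = 4.0 * x + 2.0 * y

    # Evaluation invokes the primitives that back each operation.
    mx.eval(z)
    print(z)  # array([6, 6, 6, 6], dtype=float32)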
Operations
^^^^^^^^^^^

Operations are the front-end functions that operate on arrays. They are defined
in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.

We would like an operation, :meth:`axpby`, that takes in two arrays ``x`` and
``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
C++:

.. code-block:: C++

@@ -83,10 +66,7 @@ C++ API:
        StreamOrDevice s = {} // Stream on which to schedule the operation
    );

The simplest way to implement this operation is in terms of existing
operations:

.. code-block:: C++

@@ -100,25 +80,23 @@ of existing operations.
      // Scale x and y on the provided stream
      auto ax = multiply(array(alpha), x, s);
      auto by = multiply(array(beta), y, s);

      // Add and return
      return add(ax, by, s);
    }

The operations themselves do not contain the implementations that act on the
data, nor do they contain the rules of transformations. Rather, they are an
easy-to-use interface built on top of :class:`Primitive` building blocks.

Primitives
^^^^^^^^^^^

A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create output arrays given input arrays. Further, a
:class:`Primitive` has methods to run on the CPU or GPU and for function
transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
more concrete:
.. code-block:: C++

@@ -134,11 +112,15 @@ back and go to our example to give ourselves a more concrete image.
        * To avoid unnecessary allocations, the evaluation function
        * is responsible for allocating space for the array.
        */
        void eval_cpu(
            const std::vector<array>& inputs,
            std::vector<array>& outputs) override;
        void eval_gpu(
            const std::vector<array>& inputs,
            std::vector<array>& outputs) override;

        /** The Jacobian-vector product. */
        std::vector<array> jvp(
            const std::vector<array>& primals,
            const std::vector<array>& tangents,
            const std::vector<int>& argnums) override;
@@ -147,7 +129,8 @@ back and go to our example to give ourselves a more concrete image.
        std::vector<array> vjp(
            const std::vector<array>& primals,
            const array& cotan,
            const std::vector<int>& argnums,
            const std::vector<array>& outputs) override;

        /**
        * The primitive must know how to vectorize itself across
@@ -155,7 +138,7 @@ back and go to our example to give ourselves a more concrete image.
        * representing the vectorized computation and the axis which
        * corresponds to the output vectorized dimension.
        */
        virtual std::pair<std::vector<array>, std::vector<int>> vmap(
            const std::vector<array>& inputs,
            const std::vector<int>& axes) override;
@@ -175,22 +158,22 @@ back and go to our example to give ourselves a more concrete image.
        void eval(const std::vector<array>& inputs, array& out);
    };

The :class:`Axpby` class derives from the base :class:`Primitive` class and
treats ``alpha`` and ``beta`` as parameters. It then provides implementations
of how the output array is produced given the inputs through
:meth:`Axpby::eval_cpu` and :meth:`Axpby::eval_gpu`. It also provides rules
of transformations in :meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and
:meth:`Axpby::vmap`.

Using the Primitive
^^^^^^^^^^^^^^^^^^^

Operations can use this :class:`Primitive` to add a new :class:`array` to the
computation graph. An :class:`array` can be constructed by providing its data
type, shape, the :class:`Primitive` that computes it, and the :class:`array`
inputs that are passed to the primitive.

Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
.. code-block:: C++

@@ -223,7 +206,7 @@ Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
        /* const std::vector<int>& shape = */ out_shape,
        /* Dtype dtype = */ out_dtype,
        /* std::unique_ptr<Primitive> primitive = */
        std::make_shared<Axpby>(to_stream(s), alpha, beta),
        /* const std::vector<array>& inputs = */ broadcasted_inputs);
    }

@@ -238,27 +221,26 @@ This operation now handles the following:
Implementing the Primitive
--------------------------

No computation happens when we call the operation alone. The operation only
builds the computation graph. When we evaluate the output array, MLX schedules
the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
:meth:`Axpby::eval_gpu` depending on the stream/device specified by the user.

.. warning::
    When :meth:`Primitive::eval_cpu` or :meth:`Primitive::eval_gpu` are called,
    no memory has been allocated for the output array. It falls on the
    implementation of these functions to allocate memory as needed.

Implementing the CPU Back-end
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Let's start by implementing a naive and generic version of
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
:class:`Axpby` earlier called :meth:`Axpby::eval`.

Our naive method will go over each element of the output array, find the
corresponding input elements of ``x`` and ``y``, and perform the operation
point-wise. This is captured in the templated function :meth:`axpby_impl`.
.. code-block:: C++

@@ -296,19 +278,19 @@ pointwise. This is captured in the templated function :meth:`axpby_impl`.
      }
    }

Our implementation should work for all incoming floating point arrays.
Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
``complex64``. We throw an error if we encounter an unexpected type.

.. code-block:: C++

    /** Fall back implementation for evaluation on CPU */
    void Axpby::eval(
        const std::vector<array>& inputs,
        const std::vector<array>& outputs) {
      auto& x = inputs[0];
      auto& y = inputs[1];
      auto& out = outputs[0];

      // Dispatch to the correct dtype
      if (out.dtype() == float32) {
@@ -321,28 +303,26 @@ if we encounter an unexpected type.
        return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
      } else {
        throw std::runtime_error(
            "[Axpby] Only supports floating point types.");
      }
    }

This is good as a fallback implementation. We can use the ``axpby`` routine
provided by the Accelerate_ framework for a faster implementation in certain
cases:

#. Accelerate does not provide implementations of ``axpby`` for half precision
   floats. We can only use it for ``float32`` types.
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all
   elements have fixed strides between them. We only direct to Accelerate
   if both ``x`` and ``y`` are row contiguous or column contiguous.
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` in-place.
   MLX expects to write the output to a new array. We must copy the elements
   of ``y`` into the output and use that as an input to ``axpby``.

Let's write an implementation that uses Accelerate in the right conditions.
It allocates data for the output, copies ``y`` into it, and then calls
:func:`catlas_saxpby` from Accelerate.
.. code-block:: C++

@@ -356,17 +336,7 @@ and then call the :meth:`catlas_saxpby` from accelerate.
      // Accelerate library provides catlas_saxpby which does
      // Y = (alpha * X) + (beta * Y) in place
      // To use it, we first copy the data in y over to the output array
      out.set_data(allocator::malloc_or_wait(out.nbytes()));

      // We then copy over the elements using the contiguous vector specialization
      copy_inplace(y, out, CopyType::Vector);
@@ -389,18 +359,20 @@ and then call the :meth:`catlas_saxpby` from accelerate.
          /* INCY = */ 1);
    }

For inputs that do not fit the criteria for Accelerate, we fall back to
:meth:`Axpby::eval`. With this in mind, let's finish our
:meth:`Axpby::eval_cpu`.

.. code-block:: C++

    /** Evaluate primitive on CPU using accelerate specializations */
    void Axpby::eval_cpu(
        const std::vector<array>& inputs,
        const std::vector<array>& outputs) {
      assert(inputs.size() == 2);
      auto& x = inputs[0];
      auto& y = inputs[1];
      auto& out = outputs[0];

      // Accelerate specialization for contiguous single precision float arrays
      if (out.dtype() == float32 &&
@@ -410,35 +382,33 @@ With this in mind, lets finally implement our :meth:`Axpby::eval_cpu`.
        return;
      }

      // Fall back to common back-end if specializations are not available
      eval(inputs, outputs);
    }
Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
you do not plan on running the operation on the GPU or using transforms on
computation graphs that contain :class:`Axpby`, you can stop implementing the
primitive here and enjoy the speed-ups you get from the Accelerate library.

Implementing the GPU Back-end
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Apple silicon devices address their GPUs using the Metal_ shading language, and
GPU kernels in MLX are written using Metal.

.. note::

    Here are some helpful resources if you are new to Metal:

    * A walkthrough of the metal compute pipeline: `Metal Example`_
    * Documentation for metal shading language: `Metal Specification`_
    * Using metal from C++: `Metal-cpp`_

Let's keep the GPU kernel simple. We will launch exactly as many threads as
there are elements in the output. Each thread will pick the element it needs
from ``x`` and ``y``, do the point-wise operation, and update its assigned
element in the output.

.. code-block:: C++

@@ -457,15 +427,14 @@ element in the output.
      // Convert linear indices to offsets in array
      auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
      auto y_offset = elem_to_loc(index, shape, y_strides, ndim);

      // Do the operation and update the output
      out[index] =
          static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
    }

We then need to instantiate this template for all floating point types and give
each instantiation a unique host name so we can identify it.
.. code-block:: C++

@@ -488,29 +457,21 @@ each data type.
    instantiate_axpby(bfloat16, bfloat16_t);
    instantiate_axpby(complex64, complex64_t);

The logic to determine the kernel, set the inputs, resolve the grid dimensions,
and dispatch to the GPU is contained in :meth:`Axpby::eval_gpu` as shown
below.

.. code-block:: C++

    /** Evaluate primitive on GPU */
    void Axpby::eval_gpu(
        const std::vector<array>& inputs,
        std::vector<array>& outputs) {
      // Prepare inputs
      assert(inputs.size() == 2);
      auto& x = inputs[0];
      auto& y = inputs[1];
      auto& out = outputs[0];

      // Each primitive carries the stream it should execute on
      // and each stream carries its device identifiers
@@ -518,10 +479,10 @@ below.
      // We get the needed metal device using the stream
      auto& d = metal::device(s.device);

      // Allocate output memory
      out.set_data(allocator::malloc_or_wait(out.nbytes()));

      // Resolve name of kernel
      std::ostringstream kname;
      kname << "axpby_" << "general_" << type_to_name(out);
@@ -552,7 +513,7 @@ below.
      compute_encoder->setBytes(&alpha_, sizeof(float), 3);
      compute_encoder->setBytes(&beta_, sizeof(float), 4);

      // Encode shape, strides and ndim
      compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
      compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
      compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
@@ -575,28 +536,25 @@ below.
We can now call the :meth:`axpby` operation on both the CPU and the GPU!

A few things to note about MLX and Metal before moving on. MLX keeps track of
the active ``command_buffer`` and the ``MTLCommandBuffer`` to which it is
associated. We rely on :meth:`d.get_command_encoder` to give us the active
metal compute command encoder instead of building a new one and calling
:meth:`compute_encoder->end_encoding` at the end. MLX adds kernels (compute
pipelines) to the active command buffer until some specified limit is hit or
the command buffer needs to be flushed for synchronization.
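As a rough usage sketch (assuming the Python binding and package described
later in this guide), the same primitive can be dispatched to either back-end
by picking a device or stream:

.. code-block:: python

    import mlx.core as mx
    from mlx_sample_extensions import axpby  # the package built later in this guide

    x = mx.ones((3, 4))
    y = mx.ones((3, 4))

    # Runs Axpby::eval_gpu on the default GPU stream ...
    z_gpu = axpby(x, y, alpha=4.0, beta=2.0, stream=mx.gpu)

    # ... and Axpby::eval_cpu (with the Accelerate fast path) on the CPU.
    z_cpu = axpby(x, y, alpha=4.0, beta=2.0, stream=mx.cpu)

    mx.eval(z_gpu, z_cpu)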
Primitive Transforms
^^^^^^^^^^^^^^^^^^^^^

Next, let's add implementations for transformations in a :class:`Primitive`.
These transformations can be built on top of other operations, including the
one we just defined:

.. code-block:: C++

    /** The Jacobian-vector product. */
    std::vector<array> Axpby::jvp(
        const std::vector<array>& primals,
        const std::vector<array>& tangents,
        const std::vector<int>& argnums) {
@@ -611,12 +569,12 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
      if (argnums.size() > 1) {
        auto scale = argnums[0] == 0 ? alpha_ : beta_;
        auto scale_arr = array(scale, tangents[0].dtype());
        return {multiply(scale_arr, tangents[0], stream())};
      }
      // If argnums = {0, 1}, we take contributions from both,
      // which gives us jvp = tangent_x * alpha + tangent_y * beta
      else {
        return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
      }
    }

@@ -625,34 +583,35 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
    /** The vector-Jacobian product. */
    std::vector<array> Axpby::vjp(
        const std::vector<array>& primals,
        const std::vector<array>& cotangents,
        const std::vector<int>& argnums,
        const std::vector<int>& /* unused */) {
      // Reverse mode diff
      std::vector<array> vjps;
      for (auto arg : argnums) {
        auto scale = arg == 0 ? alpha_ : beta_;
        auto scale_arr = array(scale, cotangents[0].dtype());
        vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
      }
      return vjps;
    }
Note that a transformation does not need to be fully defined to start using
the :class:`Primitive`.

.. code-block:: C++

    /** Vectorize primitive along given axis */
    std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
        const std::vector<array>& inputs,
        const std::vector<int>& axes) {
      throw std::runtime_error("[Axpby] vmap not implemented.");
    }
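As a rough sketch of what this buys us (again assuming the extension is
installed as ``mlx_sample_extensions``), :func:`grad` already works through the
``vjp`` rule above, while :func:`vmap` surfaces the error we just threw:

.. code-block:: python

    import mlx.core as mx
    from mlx_sample_extensions import axpby

    def loss(x, y):
        return mx.sum(axpby(x, y, alpha=4.0, beta=2.0))

    x = mx.ones((3,))
    y = mx.ones((3,))

    # Reverse-mode differentiation goes through Axpby::vjp:
    # the gradient with respect to x is alpha everywhere.
    print(mx.grad(loss)(x, y))

    # vmap is not implemented for Axpby, so this raises the error thrown above.
    try:
        mx.vmap(lambda a, b: axpby(a, b, alpha=4.0, beta=2.0))(x, y)
    except Exception as e:
        print(e)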
Building and Binding
--------------------

Let's look at the overall directory structure first.

| extensions
| ├── axpby
@@ -666,40 +625,39 @@ Let's look at the overall directory structure first.
| └── setup.py

* ``extensions/axpby/`` defines the C++ extension library
* ``extensions/mlx_sample_extensions`` sets out the structure for the
  associated Python package
* ``extensions/bindings.cpp`` provides Python bindings for our operation
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
  Python bindings
* ``extensions/setup.py`` holds the ``setuptools`` rules to build and install
  the Python package

Binding to Python
^^^^^^^^^^^^^^^^^^

We use nanobind_ to build a Python API for the C++ library. Since bindings for
components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are
already provided, adding our :meth:`axpby` is simple.

.. code-block:: C++

    NB_MODULE(_ext, m) {
      m.doc() = "Sample extension for MLX";

      m.def(
          "axpby",
          &axpby,
          "x"_a,
          "y"_a,
          "alpha"_a,
          "beta"_a,
          nb::kw_only(),
          "stream"_a = nb::none(),
          R"(
            Scale and sum two vectors element-wise
            ``z = alpha * x + beta * y``

            Follows numpy style broadcasting between ``x`` and ``y``
            Inputs are upcasted to floats if needed
@@ -711,17 +669,17 @@ already provided, adding our :meth:`axpby` is simple!
            Returns:
                array: ``alpha * x + beta * y``
          )");
    }
Most of the complexity in the above example comes from additional bells and
whistles such as the literal names and doc-strings.

.. warning::

    :mod:`mlx.core` must be imported before importing
    :mod:`mlx_sample_extensions` as defined by the nanobind module above to
    ensure that the casters for :mod:`mlx.core` components like
    :class:`mlx.core.array` are available.
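In practice the safe import pattern is simply:

.. code-block:: python

    import mlx.core as mx  # must come first so the array/stream casters are registered
    from mlx_sample_extensions import axpby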
.. _Building with CMake:

@@ -729,8 +687,8 @@ whistles such as the literal names and doc-strings.
Building with CMake
^^^^^^^^^^^^^^^^^^^^

Building the C++ extension library only requires that you ``find_package(MLX
CONFIG)`` and then link it to your library.

.. code-block:: cmake

@@ -752,12 +710,12 @@ Building the C++ extension library itself is simple, it only requires that you
    # Link to mlx
    target_link_libraries(mlx_ext PUBLIC mlx)

We also need to build the attached Metal library. For convenience, we provide a
:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
automatically imported with the MLX package).

Here is what that looks like in practice:

.. code-block:: cmake

@@ -779,27 +737,29 @@ Here is what that looks like in practice!
    endif()

Finally, we build the nanobind_ bindings:

.. code-block:: cmake

    nanobind_add_module(
      _ext
      NB_STATIC STABLE_ABI LTO NOMINSIZE
      NB_DOMAIN mlx
      ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
    )
    target_link_libraries(_ext PRIVATE mlx_ext)

    if(BUILD_SHARED_LIBS)
      target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
    endif()
Building with ``setuptools``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Once we have set out the CMake build rules as described above, we can use the
build utilities defined in :mod:`mlx.extension`:

.. code-block:: python

    from mlx import extension
    from setuptools import setup

@@ -809,48 +769,50 @@ build utilities defined in :mod:`mlx.extension` for a simple build process.
        name="mlx_sample_extensions",
        version="0.0.0",
        description="Sample C++ and Metal extensions for MLX primitives.",
        ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
        cmdclass={"build_ext": extension.CMakeBuild},
        packages=["mlx_sample_extensions"],
        package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
        extras_require={"dev": []},
        zip_safe=False,
        python_requires=">=3.8",
    )

.. note::
    We treat ``extensions/mlx_sample_extensions`` as the package directory
    even though it only contains a ``__init__.py`` to ensure the following:

    * :mod:`mlx.core` must be imported before importing :mod:`_ext`
    * The C++ extension library and the Metal library are co-located with the
      Python bindings and copied together if the package is installed

To build the package, first install the build dependencies with ``pip install
-r requirements.txt``. You can then build inplace for development using
``python setup.py build_ext -j8 --inplace`` (in ``extensions/``).

This results in the directory structure:

| extensions
| ├── mlx_sample_extensions
| │   ├── __init__.py
| │   ├── libmlx_ext.dylib # C++ extension library
| │   ├── mlx_ext.metallib # Metal library
| │   └── _ext.cpython-3x-darwin.so # Python Binding
| ...

When you try to install using the command ``python -m pip install .`` (in
``extensions/``), the package will be installed with the same structure as
``extensions/mlx_sample_extensions``, and the C++ and Metal libraries will be
copied along with the Python binding since they are specified as
``package_data``.
Usage
-----

After installing the extension as described above, you should be able to simply
import the Python package and play with it as you would any other MLX operation.

Let's look at a simple script and its results:

.. code-block:: python

@@ -874,12 +836,12 @@ Output:
    c correctness: True

Results
^^^^^^^

Let's run a quick benchmark and see how our new ``axpby`` operation compares
with the naive :meth:`simple_axpby` we first defined on the CPU.

.. code-block:: python

    import mlx.core as mx
    from mlx_sample_extensions import axpby

@@ -898,7 +860,7 @@ with the naive :meth:`simple_axpby` we defined at first on the CPU.
    alpha = 4.0
    beta = 2.0

    mx.eval(x, y)

    def bench(f):
        # Warm up
@@ -919,30 +881,23 @@ with the naive :meth:`simple_axpby` we defined at first on the CPU.
    print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")

The results are ``Simple axpby: 0.114 s | Custom axpby: 0.109 s``. We see
modest improvements right away!

This operation is now good to be used to build other operations, in
:class:`mlx.nn.Module` calls, and also as a part of graph transformations like
:meth:`grad`.
Scripts
-------

.. admonition:: Download the code

    The full example code is available in `mlx <https://github.com/ml-explore/mlx/tree/main/examples/extensions/>`_.

.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
.. _Metal: https://developer.apple.com/documentation/metal?language=objc
.. _Metal-cpp: https://developer.apple.com/metal/cpp/
.. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
.. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc
.. _nanobind: https://nanobind.readthedocs.io/en/latest/


@@ -0,0 +1,68 @@
Metal Debugger
==============
.. currentmodule:: mlx.core
Profiling is a key step for performance optimization. You can build MLX with
the ``MLX_METAL_DEBUG`` option to improve the Metal debugging and
optimization workflow. The ``MLX_METAL_DEBUG`` debug option:
* Records source during Metal compilation, for later inspection while
debugging.
* Labels Metal objects such as command queues, improving capture readability.
To build with debugging enabled in Python prepend
``CMAKE_ARGS="-DMLX_METAL_DEBUG=ON"`` to the build call.
The :func:`metal.start_capture` function initiates a capture of all MLX GPU
work.
.. note::
To capture a GPU trace you must run the application with
``MTL_CAPTURE_ENABLED=1``.
.. code-block:: python
import mlx.core as mx
a = mx.random.uniform(shape=(512, 512))
b = mx.random.uniform(shape=(512, 512))
mx.eval(a, b)
trace_file = "mlx_trace.gputrace"
# Make sure to run with MTL_CAPTURE_ENABLED=1 and
# that the path trace_file does not already exist.
mx.metal.start_capture(trace_file)
for _ in range(10):
mx.eval(mx.add(a, b))
mx.metal.stop_capture()
You can open and replay the GPU trace in Xcode. The ``Dependencies`` view
has a great overview of all operations. Check out the `Metal debugger
documentation`_ for more information.
.. image:: ../_static/metal_debugger/capture.png
:class: dark-light
Xcode Workflow
--------------
You can skip saving to a path by running within Xcode. First, generate an
Xcode project using CMake.
.. code-block::
mkdir build && cd build
cmake .. -DMLX_METAL_DEBUG=ON -G Xcode
open mlx.xcodeproj
Select the ``metal_capture`` example schema and run.
.. image:: ../_static/metal_debugger/schema.png
:class: dark-light
.. _`Metal debugger documentation`: https://developer.apple.com/documentation/xcode/metal-debugger


@@ -58,10 +58,12 @@ are the CPU and GPU.
   :maxdepth: 1

   python/array
   python/data_types
   python/devices_and_streams
   python/ops
   python/random
   python/transforms
   python/fast
   python/fft
   python/linalg
   python/metal
@@ -80,3 +82,4 @@ are the CPU and GPU.
   :maxdepth: 1

   dev/extensions
   dev/metal_debugger


@@ -15,10 +15,10 @@ To install from PyPI you must meet the following requirements:
- Using an M series chip (Apple silicon)
- Using a native Python >= 3.8
- macOS >= 13.5

.. note::
    MLX is only available on devices running macOS >= 13.5
    It is highly recommended to use macOS 14 (Sonoma)

@@ -54,7 +54,7 @@ Build Requirements
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
- `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
- Xcode >= 15.0 and macOS SDK >= 14.0

.. note::
    Ensure your shell environment is native ``arm``, not ``x86`` via Rosetta. If
@@ -70,16 +70,13 @@ To build and install the MLX python library from source, first, clone MLX from
    git clone git@github.com:ml-explore/mlx.git mlx && cd mlx

Install `nanobind <https://nanobind.readthedocs.io/en/latest/>`_ with:

.. code-block:: shell

    pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4

Then simply build and install MLX using pip:

.. code-block:: shell

@@ -123,7 +120,7 @@ Create a build directory and run CMake and make:
.. code-block:: shell

    mkdir -p build && cd build
    cmake .. && make -j

Run tests with:

@@ -142,7 +139,7 @@ directory as the executable statically linked to ``libmlx.a`` or the
preprocessor constant ``METAL_PATH`` should be defined at build time and it
should point to the path to the built metal library.

.. list-table:: Build Options
   :widths: 25 8
   :header-rows: 1

@@ -158,19 +155,21 @@ should point to the path to the built metal library.
     - ON
   * - MLX_BUILD_PYTHON_BINDINGS
     - OFF
   * - MLX_METAL_DEBUG
     - OFF

.. note::

    If you have multiple Xcode installations and wish to use
    a specific one while building, you can do so by adding the
    following environment variable before building

    .. code-block:: shell

      export DEVELOPER_DIR="/path/to/Xcode.app/Contents/Developer/"

    Further, you can use the following command to find out which
    macOS SDK will be used

    .. code-block:: shell

@@ -202,7 +201,7 @@ Then set the active developer directory:
    sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer

x86 Shell
~~~~~~~~~

.. _build shell:


@@ -10,27 +10,38 @@ Array
   array
   array.astype
   array.at
   array.item
   array.tolist
   array.dtype
   array.itemsize
   array.nbytes
   array.ndim
   array.shape
   array.size
   array.abs
   array.all
   array.any
   array.argmax
   array.argmin
   array.cos
   array.cummax
   array.cummin
   array.cumprod
   array.cumsum
   array.diag
   array.diagonal
   array.exp
   array.flatten
   array.log
   array.log10
   array.log1p
   array.log2
   array.logsumexp
   array.max
   array.mean
   array.min
   array.moveaxis
   array.prod
   array.reciprocal
   array.reshape
@@ -40,6 +51,8 @@ Array
   array.split
   array.sqrt
   array.square
   array.squeeze
   array.swapaxes
   array.sum
   array.transpose
   array.T


@@ -1,7 +1,5 @@
.. _data_types:

Data Types
==========
@@ -44,9 +42,27 @@ The default floating point type is ``float32`` and the default integer type is
   * - ``int64``
     - 8
     - 64-bit signed integer
   * - ``bfloat16``
     - 2
     - 16-bit brain float (e8, m7)
   * - ``float16``
     - 2
     - 16-bit IEEE float (e5, m10)
   * - ``float32``
     - 4
     - 32-bit float
   * - ``complex64``
     - 8
     - 64-bit complex float

Data types are arranged in a hierarchy. See the :obj:`DtypeCategory` object
documentation for more information. Use :func:`issubdtype` to determine if one
``dtype`` (or category) is a subtype of another category.

.. autosummary::
   :toctree: _autosummary

   Dtype
   DtypeCategory
   issubdtype
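A small illustrative example of the hierarchy:

.. code-block:: python

    import mlx.core as mx

    print(mx.issubdtype(mx.float32, mx.floating))   # True
    print(mx.issubdtype(mx.bfloat16, mx.inexact))   # True
    print(mx.issubdtype(mx.int64, mx.floating))     # False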


@@ -16,3 +16,4 @@ Devices and Streams
   new_stream
   set_default_stream
   stream
   synchronize
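A minimal sketch of :func:`synchronize` (assuming work was scheduled on the
default stream with :func:`async_eval`):

.. code-block:: python

    import mlx.core as mx

    a = mx.random.uniform(shape=(1024, 1024))
    b = a @ a

    mx.async_eval(b)   # schedule the computation without blocking
    mx.synchronize()   # block until the default stream has finished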

docs/src/python/fast.rst (new file)

@@ -0,0 +1,14 @@
.. _fast:
Fast
====
.. currentmodule:: mlx.core.fast
.. autosummary::
:toctree: _autosummary
rms_norm
layer_norm
rope
scaled_dot_product_attention
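For illustration, a sketch of calling :func:`rms_norm` (shapes chosen
arbitrarily):

.. code-block:: python

    import mlx.core as mx

    x = mx.random.normal((8, 512))
    weight = mx.ones((512,))

    # Fused RMS norm over the last dimension.
    y = mx.fast.rms_norm(x, weight, eps=1e-5)
    print(y.shape)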


@@ -3,12 +3,16 @@ Metal
.. currentmodule:: mlx.core.metal

.. autosummary::
   :toctree: _autosummary

   is_available
   device_info
   get_active_memory
   get_peak_memory
   get_cache_memory
   set_memory_limit
   set_cache_limit
   clear_cache
   start_capture
   stop_capture
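A quick sketch of inspecting the device and the memory counters (the exact
keys returned by :func:`device_info` may vary):

.. code-block:: python

    import mlx.core as mx

    print(mx.metal.is_available())
    print(mx.metal.device_info())       # e.g. architecture and memory sizes

    a = mx.random.uniform(shape=(4096, 4096))
    mx.eval(a @ a)

    print(mx.metal.get_peak_memory())   # bytes used at the high-water mark
    mx.metal.clear_cache()              # return cached buffers to the system
    print(mx.metal.get_cache_memory())  # should now be close to zero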


@@ -173,6 +173,7 @@ In detail:
   :toctree: _autosummary

   value_and_grad
   quantize
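A minimal sketch of :func:`quantize` applied to a small model (group size and
bit width passed explicitly):

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))

    # Swap compatible layers (e.g. Linear) for their quantized counterparts in place.
    nn.quantize(model, group_size=64, bits=4)

    x = mx.random.normal((1, 512))
    mx.eval(model(x))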
.. toctree::


@@ -21,17 +21,21 @@ Layers
   Embedding
   GELU
   GroupNorm
   GRU
   InstanceNorm
   LayerNorm
   Linear
   LSTM
   MaxPool1d
   MaxPool2d
   Mish
   MultiHeadAttention
   PReLU
   QuantizedEmbedding
   QuantizedLinear
   RMSNorm
   ReLU
   RNN
   RoPE
   SELU
   Sequential
@@ -40,4 +44,4 @@ Layers
   Softshrink
   Step
   Transformer
   Upsample


@@ -30,6 +30,7 @@ Module
   Module.named_modules
   Module.parameters
   Module.save_weights
   Module.set_dtype
   Module.train
   Module.trainable_parameters
   Module.unfreeze


@@ -5,13 +5,13 @@ Operations
.. currentmodule:: mlx.core

.. autosummary::
   :toctree: _autosummary

   abs
   add
   all
   allclose
   any
   arange
   arccos
@@ -28,6 +28,11 @@ Operations
   atleast_1d
   atleast_2d
   atleast_3d
   bitwise_and
   bitwise_or
   bitwise_xor
   block_masked_mm
   block_sparse_mm
   broadcast_to
   ceil
   clip
@@ -38,6 +43,11 @@ Operations
   conv_general
   cos
   cosh
   cummax
   cummin
   cumprod
   cumsum
   degrees
   dequantize
   diag
   diagonal
@@ -47,6 +57,7 @@ Operations
   erf
   erfinv
   exp
   expm1
   expand_dims
   eye
   flatten
@@ -58,10 +69,11 @@ Operations
   identity
   inner
   isclose
   isinf
   isnan
   isneginf
   isposinf
   left_shift
   less
   less_equal
   linspace
@@ -79,11 +91,13 @@ Operations
   max
   maximum
   mean
   meshgrid
   min
   minimum
   moveaxis
   multiply
   negative
   not_equal
   ones
   ones_like
   outer
@@ -92,9 +106,11 @@ Operations
   prod
   quantize
   quantized_matmul
   radians
   reciprocal
   repeat
   reshape
   right_shift
   round
   rsqrt
   save
@@ -113,6 +129,7 @@ Operations
   square
   squeeze
   stack
   std
   stop_gradient
   subtract
   sum
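A few of the newly documented operations in action (an illustrative sketch):

.. code-block:: python

    import mlx.core as mx

    a = mx.array([0b1100, 0b1010])
    print(mx.bitwise_and(a, mx.array(0b1000)))       # array([8, 8], dtype=int32)

    x = mx.array([1.0, 2.0, 3.0])
    print(mx.cumsum(x))                              # array([1, 3, 6], dtype=float32)
    print(mx.degrees(mx.radians(mx.array(180.0))))   # approximately 180.0

    xs, ys = mx.meshgrid(mx.arange(3), mx.arange(2)) # two 2 x 3 coordinate grids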


@@ -38,6 +38,7 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
   gumbel
   key
   normal
   multivariate_normal
   randint
   seed
   split
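A sketch of :func:`multivariate_normal` (the ``shape`` argument is assumed to
follow the other random sampling functions):

.. code-block:: python

    import mlx.core as mx

    mean = mx.array([0.0, 0.0])
    cov = mx.array([[1.0, 0.5], [0.5, 1.0]])

    # Draw 1000 correlated samples, each of dimension 2 (float32 only).
    samples = mx.random.multivariate_normal(mean, cov, shape=(1000,))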


@@ -19,3 +19,4 @@ return python trees will be using the default python ``dict``, ``list`` and
   tree_flatten
   tree_unflatten
   tree_map
   tree_map_with_path
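A sketch of :func:`tree_map_with_path`; the callback is assumed to receive the
dotted key path along with each leaf:

.. code-block:: python

    from mlx.utils import tree_map_with_path

    params = {"layers": [{"w": 1.0}, {"w": 2.0}]}

    # Print the path of every leaf while leaving the values unchanged.
    tree_map_with_path(lambda path, leaf: print(path, leaf), params)
    # layers.0.w 1.0
    # layers.1.w 2.0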


@@ -40,7 +40,7 @@ getting higher order derivatives.
Any of the MLX function transformations can be composed in any order to any
depth. See the following sections for more information on :ref:`automatic
differentiation <auto diff>` and :ref:`automatic vectorization <vmap>`.

For more information on :func:`compile` see the :ref:`compile documentation <compile>`.


@@ -18,7 +18,7 @@ describe below.
Transforming Compute Graphs
^^^^^^^^^^^^^^^^^^^^^^^^^^^

Lazy evaluation lets us record a compute graph without actually doing any
computations. This is useful for function transformations like :func:`grad` and
:func:`vmap` and graph optimizations.


@@ -49,7 +49,7 @@ it will be added. You can load the array with:
.. code-block:: shell

  >>> mx.load("array.npy")
  array([1], dtype=float32)

Here's an example of saving several arrays to a single file:


@@ -8,3 +8,4 @@ endfunction(build_example)
build_example(tutorial.cpp)
build_example(linear_regression.cpp)
build_example(logistic_regression.cpp)
build_example(metal_capture.cpp)


@@ -0,0 +1,31 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include <iostream>
#include "mlx/mlx.h"
using namespace mlx::core;
int main() {
// To use Metal debugging and profiling:
// 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
// 2. Run with MTL_CAPTURE_ENABLED=1.
metal::start_capture("mlx_trace.gputrace");
// Start at index two because the default GPU and CPU streams have indices
// zero and one, respectively. This naming matches the label assigned to each
// stream's command queue.
auto s2 = new_stream(Device::gpu);
auto s3 = new_stream(Device::gpu);
auto a = arange(1.f, 10.f, 1.f, float32, s2);
auto b = arange(1.f, 10.f, 1.f, float32, s3);
auto x = add(a, a, s2);
auto y = add(b, b, s3);
// The multiply will happen on the default stream.
std::cout << multiply(x, y) << std::endl;
metal::stop_capture();
}


@@ -1,6 +1,6 @@
cmake_minimum_required(VERSION 3.27) cmake_minimum_required(VERSION 3.27)
project(mlx_sample_extensions LANGUAGES CXX) project(_ext LANGUAGES CXX)
# ----------------------------- Setup ----------------------------- # ----------------------------- Setup -----------------------------
set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD 17)
@@ -11,8 +11,12 @@ option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
# ----------------------------- Dependencies ----------------------------- # ----------------------------- Dependencies -----------------------------
find_package(MLX CONFIG REQUIRED) find_package(MLX CONFIG REQUIRED)
find_package(Python COMPONENTS Interpreter Development) find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
find_package(pybind11 CONFIG REQUIRED) execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED)
# ----------------------------- Extensions ----------------------------- # ----------------------------- Extensions -----------------------------
@@ -38,7 +42,6 @@ target_link_libraries(mlx_ext PUBLIC mlx)
# Build metallib # Build metallib
if(MLX_BUILD_METAL) if(MLX_BUILD_METAL)
mlx_build_metallib( mlx_build_metallib(
TARGET mlx_ext_metallib TARGET mlx_ext_metallib
TITLE mlx_ext TITLE mlx_ext
@@ -54,13 +57,15 @@ if(MLX_BUILD_METAL)
endif() endif()
# ----------------------------- Pybind ----------------------------- # ----------------------------- Python Bindings -----------------------------
pybind11_add_module( nanobind_add_module(
mlx_sample_extensions _ext
NB_STATIC STABLE_ABI LTO NOMINSIZE
NB_DOMAIN mlx
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
) )
target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext) target_link_libraries(_ext PRIVATE mlx_ext)
if(BUILD_SHARED_LIBS) if(BUILD_SHARED_LIBS)
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path) target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
endif() endif()


@@ -0,0 +1,18 @@
## Build the extensions
```
pip install -e .
```
For faster builds during development, you can also pre-install the requirements:
```
pip install -r requirements.txt
```
And then run:
```
python setup.py build_ext -j8 --inplace
```
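Once built, the extension should be importable from Python. A small usage sketch based on the binding defined in `bindings.cpp` below; the import assumes the package re-exports `axpby` from the compiled `_ext` module:

```
import mlx.core as mx
from mlx_sample_extensions import axpby  # assumed re-export of _ext.axpby

x = mx.ones((3, 4))
y = mx.ones((3, 4))

# z = alpha * x + beta * y
z = axpby(x, y, 2.0, 3.0)
print(z)  # every entry equals 5.0
```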


@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <cassert> #include <cassert>
#include <iostream> #include <iostream>
@@ -43,7 +43,7 @@ array axpby(
auto promoted_dtype = promote_types(x.dtype(), y.dtype()); auto promoted_dtype = promote_types(x.dtype(), y.dtype());
// Upcast to float32 for non-floating point inputs x and y // Upcast to float32 for non-floating point inputs x and y
auto out_dtype = is_floating_point(promoted_dtype) auto out_dtype = issubdtype(promoted_dtype, float32)
? promoted_dtype ? promoted_dtype
: promote_types(promoted_dtype, float32); : promote_types(promoted_dtype, float32);
@@ -61,7 +61,7 @@ array axpby(
/* const std::vector<int>& shape = */ out_shape, /* const std::vector<int>& shape = */ out_shape,
/* Dtype dtype = */ out_dtype, /* Dtype dtype = */ out_dtype,
/* std::unique_ptr<Primitive> primitive = */ /* std::unique_ptr<Primitive> primitive = */
std::make_unique<Axpby>(to_stream(s), alpha, beta), std::make_shared<Axpby>(to_stream(s), alpha, beta),
/* const std::vector<array>& inputs = */ broadcasted_inputs); /* const std::vector<array>& inputs = */ broadcasted_inputs);
} }
@@ -106,12 +106,12 @@ void axpby_impl(
/** Fall back implementation for evaluation on CPU */ /** Fall back implementation for evaluation on CPU */
void Axpby::eval( void Axpby::eval(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& out_arr) { std::vector<array>& outputs) {
auto out = out_arr[0];
// Check the inputs (registered in the op while constructing the out array) // Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& x = inputs[0]; auto& x = inputs[0];
auto& y = inputs[1]; auto& y = inputs[1];
auto& out = outputs[0];
// Dispatch to the correct dtype // Dispatch to the correct dtype
if (out.dtype() == float32) { if (out.dtype() == float32) {
@@ -150,11 +150,7 @@ void axpby_impl_accelerate(
// The data in the output array is allocated to match the strides in y // The data in the output array is allocated to match the strides in y
// such that x, y, and out are contiguous in the same mode and // such that x, y, and out are contiguous in the same mode and
// no transposition is needed // no transposition is needed
out.set_data( out.set_data(allocator::malloc_or_wait(out.nbytes()));
allocator::malloc_or_wait(y.data_size() * out.itemsize()),
y.data_size(),
y.strides(),
y.flags());
// We then copy over the elements using the contiguous vector specialization // We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, CopyType::Vector); copy_inplace(y, out, CopyType::Vector);
@@ -180,11 +176,11 @@ void axpby_impl_accelerate(
/** Evaluate primitive on CPU using accelerate specializations */ /** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu( void Axpby::eval_cpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outarr) { std::vector<array>& outputs) {
auto out = outarr[0];
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& x = inputs[0]; auto& x = inputs[0];
auto& y = inputs[1]; auto& y = inputs[1];
auto& out = outputs[0];
// Accelerate specialization for contiguous single precision float arrays // Accelerate specialization for contiguous single precision float arrays
if (out.dtype() == float32 && if (out.dtype() == float32 &&
@@ -195,7 +191,7 @@ void Axpby::eval_cpu(
} }
// Fall back to common backend if specializations are not available // Fall back to common backend if specializations are not available
eval(inputs, outarr); eval(inputs, outputs);
} }
#else // Accelerate not available #else // Accelerate not available
@@ -203,8 +199,8 @@ void Axpby::eval_cpu(
/** Evaluate primitive on CPU falling back to common backend */ /** Evaluate primitive on CPU falling back to common backend */
void Axpby::eval_cpu( void Axpby::eval_cpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& out) { const std::vector<array>& outputs) {
eval(inputs, out); eval(inputs, outputs);
} }
#endif #endif
@@ -218,12 +214,12 @@ void Axpby::eval_cpu(
/** Evaluate primitive on GPU */ /** Evaluate primitive on GPU */
void Axpby::eval_gpu( void Axpby::eval_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outarr) { std::vector<array>& outputs) {
// Prepare inputs // Prepare inputs
auto out = outarr[0];
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& x = inputs[0]; auto& x = inputs[0];
auto& y = inputs[1]; auto& y = inputs[1];
auto& out = outputs[0];
// Each primitive carries the stream it should execute on // Each primitive carries the stream it should execute on
// and each stream carries its device identifiers // and each stream carries its device identifiers
@@ -372,4 +368,4 @@ bool Axpby::is_equivalent(const Primitive& other) const {
return alpha_ == r_other.alpha_ && beta_ == r_other.beta_; return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
} }
} // namespace mlx::core } // namespace mlx::core

View File

@@ -33,7 +33,7 @@ array axpby(
class Axpby : public Primitive { class Axpby : public Primitive {
public: public:
explicit Axpby(Stream stream, float alpha, float beta) explicit Axpby(Stream stream, float alpha, float beta)
: Primitive(stream), alpha_(alpha), beta_(beta){}; : Primitive(stream), alpha_(alpha), beta_(beta) {};
/** /**
* A primitive must know how to evaluate itself on the CPU/GPU * A primitive must know how to evaluate itself on the CPU/GPU
@@ -42,9 +42,9 @@ class Axpby : public Primitive {
* To avoid unnecessary allocations, the evaluation function * To avoid unnecessary allocations, the evaluation function
* is responsible for allocating space for the array. * is responsible for allocating space for the array.
*/ */
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& out) void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override; override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& out) void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override; override;
/** The Jacobian-vector product. */ /** The Jacobian-vector product. */
@@ -83,7 +83,7 @@ class Axpby : public Primitive {
float beta_; float beta_;
/** Fall back implementation for evaluation on CPU */ /** Fall back implementation for evaluation on CPU */
void eval(const std::vector<array>& inputs, std::vector<array>& out); void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
}; };
} // namespace mlx::core } // namespace mlx::core

View File

@@ -19,7 +19,7 @@ template <typename T>
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
auto x_offset = elem_to_loc(index, shape, x_strides, ndim); auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
auto y_offset = elem_to_loc(index, shape, y_strides, ndim); auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
out[index] = out[index] =
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset]; static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
} }
@@ -31,30 +31,30 @@ template <typename T>
constant const float& alpha [[buffer(3)]], constant const float& alpha [[buffer(3)]],
constant const float& beta [[buffer(4)]], constant const float& beta [[buffer(4)]],
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
out[index] = out[index] =
static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index]; static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index];
} }
#define instantiate_axpby(type_name, type) \ #define instantiate_axpby(type_name, type) \
template [[host_name("axpby_general_" #type_name)]] \ template [[host_name("axpby_general_" #type_name)]] [[kernel]] void \
[[kernel]] void axpby_general<type>( \ axpby_general<type>( \
device const type* x [[buffer(0)]], \ device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \ device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \ device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \ constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \ constant const float& beta [[buffer(4)]], \
constant const int* shape [[buffer(5)]], \ constant const int* shape [[buffer(5)]], \
constant const size_t* x_strides [[buffer(6)]], \ constant const size_t* x_strides [[buffer(6)]], \
constant const size_t* y_strides [[buffer(7)]], \ constant const size_t* y_strides [[buffer(7)]], \
constant const int& ndim [[buffer(8)]], \ constant const int& ndim [[buffer(8)]], \
uint index [[thread_position_in_grid]]); \ uint index [[thread_position_in_grid]]); \
template [[host_name("axpby_contiguous_" #type_name)]] \ template [[host_name("axpby_contiguous_" #type_name)]] [[kernel]] void \
[[kernel]] void axpby_contiguous<type>( \ axpby_contiguous<type>( \
device const type* x [[buffer(0)]], \ device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \ device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \ device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \ constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \ constant const float& beta [[buffer(4)]], \
uint index [[thread_position_in_grid]]); uint index [[thread_position_in_grid]]);
instantiate_axpby(float32, float); instantiate_axpby(float32, float);


@@ -1,31 +1,31 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <pybind11/pybind11.h> #include <nanobind/nanobind.h>
#include <pybind11/stl.h> #include <nanobind/stl/variant.h>
#include "axpby/axpby.h" #include "axpby/axpby.h"
namespace py = pybind11; namespace nb = nanobind;
using namespace py::literals; using namespace nb::literals;
using namespace mlx::core; using namespace mlx::core;
PYBIND11_MODULE(mlx_sample_extensions, m) { NB_MODULE(_ext, m) {
m.doc() = "Sample C++ and metal extensions for MLX"; m.doc() = "Sample extension for MLX";
m.def( m.def(
"axpby", "axpby",
&axpby, &axpby,
"x"_a, "x"_a,
"y"_a, "y"_a,
py::pos_only(),
"alpha"_a, "alpha"_a,
"beta"_a, "beta"_a,
py::kw_only(), nb::kw_only(),
"stream"_a = py::none(), "stream"_a = nb::none(),
R"pbdoc( R"(
Scale and sum two vectors element-wise Scale and sum two vectors element-wise
``z = alpha * x + beta * y`` ``z = alpha * x + beta * y``
Follows numpy style broadcasting between ``x`` and ``y`` Follows numpy style broadcasting between ``x`` and ``y``
Inputs are upcasted to floats if needed Inputs are upcasted to floats if needed
@@ -37,5 +37,5 @@ PYBIND11_MODULE(mlx_sample_extensions, m) {
Returns: Returns:
array: ``alpha * x + beta * y`` array: ``alpha * x + beta * y``
)pbdoc"); )");
} }


@@ -1,3 +1,8 @@
[build-system] [build-system]
requires = ["setuptools>=42", "pybind11>=2.10", "cmake>=3.24", "mlx @ git+https://github.com/mlx-explore/mlx@main"] requires = [
build-backend = "setuptools.build_meta" "setuptools>=42",
"cmake>=3.24",
"mlx>=0.9.0",
"nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4",
]
build-backend = "setuptools.build_meta"


@@ -0,0 +1,4 @@
setuptools>=42
cmake>=3.24
mlx>=0.9.0
nanobind@git+https://github.com/wjakob/nanobind.git#egg=4148debcf91f5ccab0c3b8d67b5c3cabd61f407f


@@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023-2024 Apple Inc.
from setuptools import setup from setuptools import setup
@@ -9,11 +9,11 @@ if __name__ == "__main__":
name="mlx_sample_extensions", name="mlx_sample_extensions",
version="0.0.0", version="0.0.0",
description="Sample C++ and Metal extensions for MLX primitives.", description="Sample C++ and Metal extensions for MLX primitives.",
ext_modules=[extension.CMakeExtension("mlx_sample_extensions")], ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
cmdclass={"build_ext": extension.CMakeBuild}, cmdclass={"build_ext": extension.CMakeBuild},
packages=["mlx_sample_extensions"], packages=["mlx_sample_extensions"],
package_dir={"": "."},
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]}, package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
extras_require={"dev": []},
zip_safe=False, zip_safe=False,
python_requires=">=3.8", python_requires=">=3.8",
) )


@@ -14,7 +14,7 @@ class Buffer {
void* ptr_; void* ptr_;
public: public:
Buffer(void* ptr) : ptr_(ptr){}; Buffer(void* ptr) : ptr_(ptr) {};
// Get the raw data pointer from the buffer // Get the raw data pointer from the buffer
void* raw_ptr(); void* raw_ptr();


@@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <functional> #include <functional>
#include "mlx/array.h" #include "mlx/array.h"
@@ -12,16 +11,6 @@ namespace mlx::core {
namespace { namespace {
std::pair<size_t, std::vector<size_t>> cum_prod(const std::vector<int>& shape) {
std::vector<size_t> strides(shape.size());
size_t cum_prod = 1;
for (int i = shape.size() - 1; i >= 0; --i) {
strides[i] = cum_prod;
cum_prod *= shape[i];
}
return {cum_prod, strides};
}
/** Return true if we are currently performing a function transformation in /** Return true if we are currently performing a function transformation in
* order to keep the graph when evaluating tracer arrays. */ * order to keep the graph when evaluating tracer arrays. */
bool in_tracing() { bool in_tracing() {
@@ -36,22 +25,11 @@ array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
init(&cval); init(&cval);
} }
array::array(
const std::vector<int>& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs)
: array_desc_(std::make_shared<ArrayDesc>(
shape,
dtype,
std::move(primitive),
inputs)) {}
array::array( array::array(
std::vector<int> shape, std::vector<int> shape,
Dtype dtype, Dtype dtype,
std::shared_ptr<Primitive> primitive, std::shared_ptr<Primitive> primitive,
std::vector<array>&& inputs) std::vector<array> inputs)
: array_desc_(std::make_shared<ArrayDesc>( : array_desc_(std::make_shared<ArrayDesc>(
std::move(shape), std::move(shape),
dtype, dtype,
@@ -59,15 +37,16 @@ array::array(
std::move(inputs))) {} std::move(inputs))) {}
std::vector<array> array::make_arrays( std::vector<array> array::make_arrays(
const std::vector<std::vector<int>>& shapes, std::vector<std::vector<int>> shapes,
const std::vector<Dtype>& dtypes, const std::vector<Dtype>& dtypes,
std::shared_ptr<Primitive> primitive, const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs) { const std::vector<array>& inputs) {
std::vector<array> outputs; std::vector<array> outputs;
for (int i = 0; i < shapes.size(); ++i) { for (size_t i = 0; i < shapes.size(); ++i) {
outputs.push_back(array(shapes[i], dtypes[i], primitive, inputs)); outputs.emplace_back(std::move(shapes[i]), dtypes[i], primitive, inputs);
} }
for (int i = 0; i < outputs.size(); ++i) { // For each node in |outputs|, its siblings are the other nodes.
for (size_t i = 0; i < outputs.size(); ++i) {
auto siblings = outputs; auto siblings = outputs;
siblings.erase(siblings.begin() + i); siblings.erase(siblings.begin() + i);
outputs[i].set_siblings(std::move(siblings), i); outputs[i].set_siblings(std::move(siblings), i);
@@ -92,10 +71,10 @@ array::array(std::initializer_list<int> data, Dtype dtype)
/* Build an array from a shared buffer */ /* Build an array from a shared buffer */
array::array( array::array(
allocator::Buffer data, allocator::Buffer data,
const std::vector<int>& shape, std::vector<int> shape,
Dtype dtype, Dtype dtype,
deleter_t deleter) deleter_t deleter)
: array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) { : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
set_data(data, deleter); set_data(data, deleter);
} }
@@ -104,18 +83,22 @@ void array::detach() {
s.array_desc_->inputs.clear(); s.array_desc_->inputs.clear();
s.array_desc_->siblings.clear(); s.array_desc_->siblings.clear();
s.array_desc_->position = 0; s.array_desc_->position = 0;
s.array_desc_->depth = 0;
s.array_desc_->primitive = nullptr; s.array_desc_->primitive = nullptr;
} }
array_desc_->inputs.clear(); array_desc_->inputs.clear();
array_desc_->siblings.clear(); array_desc_->siblings.clear();
array_desc_->position = 0; array_desc_->position = 0;
array_desc_->depth = 0;
array_desc_->primitive = nullptr; array_desc_->primitive = nullptr;
} }
void array::eval() { void array::eval() {
mlx::core::eval({*this}); // Ensure the array is ready to be read
if (status() == Status::scheduled) {
event().wait();
set_status(Status::available);
} else if (status() == Status::unscheduled) {
mlx::core::eval({*this});
}
} }
bool array::is_tracer() const { bool array::is_tracer() const {
@@ -164,51 +147,116 @@ void array::copy_shared_buffer(const array& other) {
copy_shared_buffer(other, other.strides(), other.flags(), other.data_size()); copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
} }
void array::move_shared_buffer(array other) { void array::move_shared_buffer(
array other,
const std::vector<size_t>& strides,
Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
array_desc_->data = std::move(other.array_desc_->data); array_desc_->data = std::move(other.array_desc_->data);
array_desc_->strides = other.strides(); array_desc_->strides = strides;
array_desc_->flags = other.flags(); array_desc_->flags = flags;
array_desc_->data_size = other.data_size(); array_desc_->data_size = data_size;
array_desc_->data_ptr = other.array_desc_->data_ptr; auto char_offset = sizeof(char) * itemsize() * offset;
array_desc_->data_ptr = static_cast<void*>(
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
} }
array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype) void array::move_shared_buffer(array other) {
: shape(shape), dtype(dtype) { move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
std::tie(size, strides) = cum_prod(shape);
} }
array::ArrayDesc::ArrayDesc( array::~array() {
const std::vector<int>& shape, if (array_desc_ == nullptr) {
Dtype dtype, return;
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs)
: shape(shape),
dtype(dtype),
primitive(std::move(primitive)),
inputs(inputs) {
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : this->inputs) {
is_tracer |= in.is_tracer();
depth = std::max(in.graph_depth(), depth);
} }
depth++;
// Ignore arrays that will be detached
if (status() != array::Status::unscheduled) {
return;
}
// Break circular reference for non-detached arrays with siblings
if (auto n = siblings().size(); n > 0) {
bool do_detach = true;
// If all siblings have siblings.size() references except
// the one we are currently destroying (which has siblings.size() + 1)
// then there are no more external references
do_detach &= (array_desc_.use_count() == (n + 1));
for (auto& s : siblings()) {
do_detach &= (s.array_desc_.use_count() == n);
if (!do_detach) {
break;
}
}
if (do_detach) {
for (auto& s : siblings()) {
for (auto& ss : s.siblings()) {
ss.array_desc_ = nullptr;
}
s.array_desc_->siblings.clear();
}
}
}
}
void array::ArrayDesc::init() {
strides.resize(shape.size());
size = 1;
for (int i = shape.size() - 1; i >= 0; --i) {
strides[i] = size;
size *= shape[i];
}
for (auto& in : inputs) {
is_tracer |= in.is_tracer();
}
}
array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
: shape(std::move(shape)), dtype(dtype), status(Status::available) {
init();
} }
array::ArrayDesc::ArrayDesc( array::ArrayDesc::ArrayDesc(
std::vector<int>&& shape, std::vector<int> shape,
Dtype dtype, Dtype dtype,
std::shared_ptr<Primitive> primitive, std::shared_ptr<Primitive> primitive,
std::vector<array>&& inputs) std::vector<array> inputs)
: shape(std::move(shape)), : shape(std::move(shape)),
dtype(dtype), dtype(dtype),
status(Status::unscheduled),
primitive(std::move(primitive)), primitive(std::move(primitive)),
inputs(std::move(inputs)) { inputs(std::move(inputs)) {
std::tie(size, strides) = cum_prod(this->shape); init();
for (auto& in : this->inputs) { }
is_tracer |= in.is_tracer();
depth = std::max(in.graph_depth(), depth); array::ArrayDesc::~ArrayDesc() {
// When an array description is destroyed it will delete a bunch of arrays
// that may also destroy their corresponding descriptions and so on and so
// forth.
//
// This calls recursively the destructor and can result in stack overflow, we
// instead put them in a vector and destroy them one at a time resulting in a
// max stack depth of 2.
std::vector<std::shared_ptr<ArrayDesc>> for_deletion;
for (array& a : inputs) {
if (a.array_desc_.use_count() == 1) {
for_deletion.push_back(std::move(a.array_desc_));
}
}
while (!for_deletion.empty()) {
// top is going to be deleted at the end of the block *after* the arrays
// with inputs have been moved into the vector
auto top = std::move(for_deletion.back());
for_deletion.pop_back();
for (array& a : top->inputs) {
if (a.array_desc_.use_count() == 1) {
for_deletion.push_back(std::move(a.array_desc_));
}
}
} }
depth++;
} }
array::ArrayIterator::ArrayIterator(const array& arr, int idx) array::ArrayIterator::ArrayIterator(const array& arr, int idx)


@@ -1,5 +1,6 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#pragma once #pragma once
#include <algorithm> #include <algorithm>
#include <cstdint> #include <cstdint>
#include <functional> #include <functional>
@@ -8,6 +9,7 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/dtype.h" #include "mlx/dtype.h"
#include "mlx/event.h"
namespace mlx::core { namespace mlx::core {
@@ -31,7 +33,7 @@ class array {
template <typename It> template <typename It>
array( array(
It data, It data,
const std::vector<int>& shape, std::vector<int> shape,
Dtype dtype = Dtype dtype =
TypeToDtype<typename std::iterator_traits<It>::value_type>()); TypeToDtype<typename std::iterator_traits<It>::value_type>());
@@ -47,13 +49,13 @@ class array {
template <typename T> template <typename T>
array( array(
std::initializer_list<T> data, std::initializer_list<T> data,
const std::vector<int>& shape, std::vector<int> shape,
Dtype dtype = TypeToDtype<T>()); Dtype dtype = TypeToDtype<T>());
/* Build an array from a buffer */ /* Build an array from a buffer */
array( array(
allocator::Buffer data, allocator::Buffer data,
const std::vector<int>& shape, std::vector<int> shape,
Dtype dtype, Dtype dtype,
deleter_t deleter = allocator::free); deleter_t deleter = allocator::free);
@@ -112,6 +114,15 @@ class array {
return array_desc_->strides; return array_desc_->strides;
}; };
/**
* Get the stride of the corresponding dimension.
*
* This function supports negative indexing and provides
* bounds checking. */
size_t strides(int dim) const {
return strides().at(dim < 0 ? dim + ndim() : dim);
};
/** Get the arrays data type. */ /** Get the arrays data type. */
Dtype dtype() const { Dtype dtype() const {
return array_desc_->dtype; return array_desc_->dtype;
@@ -172,22 +183,16 @@ class array {
* API may change. * API may change.
*/ */
array(
const std::vector<int>& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs);
array( array(
std::vector<int> shape, std::vector<int> shape,
Dtype dtype, Dtype dtype,
std::shared_ptr<Primitive> primitive, std::shared_ptr<Primitive> primitive,
std::vector<array>&& inputs); std::vector<array> inputs);
static std::vector<array> make_arrays( static std::vector<array> make_arrays(
const std::vector<std::vector<int>>& shapes, std::vector<std::vector<int>> shapes,
const std::vector<Dtype>& dtypes, const std::vector<Dtype>& dtypes,
std::shared_ptr<Primitive> primitive, const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs); const std::vector<array>& inputs);
/** A unique identifier for an array. */ /** A unique identifier for an array. */
@@ -204,7 +209,7 @@ class array {
allocator::Buffer buffer; allocator::Buffer buffer;
deleter_t d; deleter_t d;
Data(allocator::Buffer buffer, deleter_t d = allocator::free) Data(allocator::Buffer buffer, deleter_t d = allocator::free)
: buffer(buffer), d(d){}; : buffer(buffer), d(d) {};
// Not copyable // Not copyable
Data(const Data& d) = delete; Data(const Data& d) = delete;
Data& operator=(const Data& d) = delete; Data& operator=(const Data& d) = delete;
@@ -256,6 +261,11 @@ class array {
return array_desc_->siblings; return array_desc_->siblings;
}; };
/** The array's siblings. */
std::vector<array>& siblings() {
return array_desc_->siblings;
};
void set_siblings(std::vector<array> siblings, uint16_t position) { void set_siblings(std::vector<array> siblings, uint16_t position) {
array_desc_->siblings = std::move(siblings); array_desc_->siblings = std::move(siblings);
array_desc_->position = position; array_desc_->position = position;
@@ -273,11 +283,6 @@ class array {
return outputs; return outputs;
}; };
/** The depth of the array in the graph. Evaluated arrays have depth 0. */
uint16_t graph_depth() const {
return array_desc_->depth;
}
/** Detach the array from the graph. */ /** Detach the array from the graph. */
void detach(); void detach();
@@ -314,9 +319,27 @@ class array {
return static_cast<T*>(array_desc_->data_ptr); return static_cast<T*>(array_desc_->data_ptr);
}; };
// Check if the array has been evaluated enum Status { unscheduled, scheduled, available };
bool is_evaled() const {
return array_desc_->data != nullptr; bool is_available() const {
return status() == Status::available;
}
const Status status() const {
return array_desc_->status;
}
void set_status(Status s) const {
array_desc_->status = s;
}
// Get the array's shared event
Event& event() const {
return array_desc_->event;
}
// Attach an event to a not yet evaluated array
void attach_event(Event e) const {
array_desc_->event = std::move(e);
} }
// Mark the array as a tracer array (true) or not. // Mark the array as a tracer array (true) or not.
@@ -344,12 +367,21 @@ class array {
void copy_shared_buffer(const array& other); void copy_shared_buffer(const array& other);
void move_shared_buffer(
array other,
const std::vector<size_t>& strides,
Flags flags,
size_t data_size,
size_t offset = 0);
void move_shared_buffer(array other); void move_shared_buffer(array other);
void overwrite_descriptor(const array& other) { void overwrite_descriptor(const array& other) {
array_desc_ = other.array_desc_; array_desc_ = other.array_desc_;
} }
~array();
private: private:
// Initialize the arrays data // Initialize the arrays data
template <typename It> template <typename It>
@@ -360,7 +392,12 @@ class array {
std::vector<size_t> strides; std::vector<size_t> strides;
size_t size; size_t size;
Dtype dtype; Dtype dtype;
std::shared_ptr<Primitive> primitive{nullptr}; std::shared_ptr<Primitive> primitive;
Status status;
// An event on the array used for synchronization
Event event;
// Indicates an array is being used in a graph transform // Indicates an array is being used in a graph transform
// and should not be detached from the graph // and should not be detached from the graph
@@ -368,7 +405,7 @@ class array {
// This is a shared pointer so that *different* arrays // This is a shared pointer so that *different* arrays
// can share the underlying data buffer. // can share the underlying data buffer.
std::shared_ptr<Data> data{nullptr}; std::shared_ptr<Data> data;
// Properly offset data pointer // Properly offset data pointer
void* data_ptr{nullptr}; void* data_ptr{nullptr};
@@ -388,29 +425,26 @@ class array {
// The arrays position in the output list // The arrays position in the output list
uint32_t position{0}; uint32_t position{0};
// The depth of the array in the graph. explicit ArrayDesc(std::vector<int> shape, Dtype dtype);
uint16_t depth{0};
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
explicit ArrayDesc( explicit ArrayDesc(
const std::vector<int>& shape, std::vector<int> shape,
Dtype dtype, Dtype dtype,
std::shared_ptr<Primitive> primitive, std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs); std::vector<array> inputs);
explicit ArrayDesc( ~ArrayDesc();
std::vector<int>&& shape,
Dtype dtype, private:
std::shared_ptr<Primitive> primitive, // Initialize size, strides, and other metadata
std::vector<array>&& inputs); void init();
}; };
// The ArrayDesc contains the details of the materialized array including the // The ArrayDesc contains the details of the materialized array including the
// shape, strides, the data type. It also includes // shape, strides, the data type. It also includes
// the primitive which knows how to compute the array's data from its inputs // the primitive which knows how to compute the array's data from its inputs
// and the list of array's inputs for the primitive. // and the list of array's inputs for the primitive.
std::shared_ptr<ArrayDesc> array_desc_{nullptr}; std::shared_ptr<ArrayDesc> array_desc_;
}; };
template <typename T> template <typename T>
@@ -422,9 +456,9 @@ array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
template <typename It> template <typename It>
array::array( array::array(
It data, It data,
const std::vector<int>& shape, std::vector<int> shape,
Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) : Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) { array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
init(data); init(data);
} }
@@ -441,9 +475,9 @@ array::array(
template <typename T> template <typename T>
array::array( array::array(
std::initializer_list<T> data, std::initializer_list<T> data,
const std::vector<int>& shape, std::vector<int> shape,
Dtype dtype /* = TypeToDtype<T>() */) Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) { : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
if (data.size() != size()) { if (data.size() != size()) {
throw std::invalid_argument( throw std::invalid_argument(
"Data size and provided shape mismatch in array construction."); "Data size and provided shape mismatch in array construction.");
@@ -465,10 +499,11 @@ T array::item() const {
if (size() != 1) { if (size() != 1) {
throw std::invalid_argument("item can only be called on arrays of size 1."); throw std::invalid_argument("item can only be called on arrays of size 1.");
} }
if (!is_evaled()) { if (status() == Status::unscheduled) {
throw std::invalid_argument( throw std::invalid_argument(
"item() const can only be called on evaled arrays"); "item() const can only be called on evaled arrays");
} }
const_cast<array*>(this)->eval();
return *data<T>(); return *data<T>();
} }
@@ -518,4 +553,15 @@ void array::init(It src) {
} }
} }
/* Utilities for determining whether a template parameter is array. */
template <typename T>
inline constexpr bool is_array_v =
std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>, array>;
template <typename... T>
inline constexpr bool is_arrays_v = (is_array_v<T> && ...);
template <typename... T>
using enable_for_arrays_t = typename std::enable_if_t<is_arrays_v<T...>>;
} // namespace mlx::core } // namespace mlx::core


@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <cassert> #include <cassert>
@@ -196,6 +196,40 @@ inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
return matmul_bnns_general(a_pre, b_pre, out); return matmul_bnns_general(a_pre, b_pre, out);
} }
template <typename T>
inline void mask_matrix(
T* data,
const bool* mask,
int tile_size,
const int X,
const int Y,
const size_t X_data_str,
const size_t Y_data_str,
const size_t X_mask_str,
const size_t Y_mask_str) {
int tX = (X + tile_size - 1) / tile_size;
int tY = (Y + tile_size - 1) / tile_size;
for (int i = 0; i < tX; i++) {
for (int j = 0; j < tY; j++) {
bool do_mask = mask[i * X_mask_str + j * Y_mask_str];
if (!do_mask) {
int loc_x = i * tile_size;
int loc_y = j * tile_size;
T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;
int size_x = std::min(tile_size, X - loc_x);
int size_y = std::min(tile_size, Y - loc_y);
for (int ii = 0; ii < size_x; ii++) {
for (int jj = 0; jj < size_y; jj++) {
data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
}
}
}
}
}
}
} // namespace } // namespace
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) { void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {


@@ -31,6 +31,8 @@ DEFAULT(ArgPartition)
DEFAULT(ArgReduce) DEFAULT(ArgReduce)
DEFAULT(ArgSort) DEFAULT(ArgSort)
DEFAULT(AsStrided) DEFAULT(AsStrided)
DEFAULT(BlockMaskedMM)
DEFAULT(BlockSparseMM)
DEFAULT(Broadcast) DEFAULT(Broadcast)
DEFAULT(Ceil) DEFAULT(Ceil)
DEFAULT(Concatenate) DEFAULT(Concatenate)
@@ -38,6 +40,7 @@ DEFAULT(Copy)
DEFAULT_MULTI(CustomVJP) DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends) DEFAULT_MULTI(Depends)
DEFAULT_MULTI(DivMod) DEFAULT_MULTI(DivMod)
DEFAULT(NumberOfElements)
DEFAULT(Equal) DEFAULT(Equal)
DEFAULT(Erf) DEFAULT(Erf)
DEFAULT(ErfInv) DEFAULT(ErfInv)
@@ -68,10 +71,13 @@ DEFAULT(Select)
DEFAULT(Sigmoid) DEFAULT(Sigmoid)
DEFAULT(Sign) DEFAULT(Sign)
DEFAULT(Slice) DEFAULT(Slice)
DEFAULT(SliceUpdate)
DEFAULT_MULTI(Split) DEFAULT_MULTI(Split)
DEFAULT(Sort) DEFAULT(Sort)
DEFAULT(StopGradient) DEFAULT(StopGradient)
DEFAULT_MULTI(SVD)
DEFAULT(Transpose) DEFAULT(Transpose)
DEFAULT(Inverse)
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) { void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
@@ -297,7 +303,7 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
set_unary_output_data(in, out); set_unary_output_data(in, out);
auto size = in.data_size(); auto size = in.data_size();
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size)); vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (is_floating_point(out.dtype())) { } else if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, [](auto x) { return std::exp(x); }); unary_fp(in, out, [](auto x) { return std::exp(x); });
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@@ -306,6 +312,19 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
} }
} }
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
vvexpm1f(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else {
eval(inputs, out);
}
}
void Full::eval_cpu(const std::vector<array>& inputs, array& out) { void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
@@ -351,7 +370,7 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
auto size = in.data_size(); auto size = in.data_size();
vvlog1pf( vvlog1pf(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size)); out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (is_floating_point(out.dtype())) { } else if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, [](auto x) { return std::log1p(x); }); unary_fp(in, out, [](auto x) { return std::log1p(x); });
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(


@@ -10,78 +10,65 @@
namespace mlx::core { namespace mlx::core {
template <typename T, typename VT, int N> namespace {
void _vectorized_strided_sum(const T* x, T* accum, int size, size_t stride) {
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
VT val = (*(VT*)x);
*(VT*)a += val;
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a++ += *x++;
}
}
}
// TODO: Add proper templates for the strided reduce algorithm so we don't have template <typename T, typename VT>
// to write max/min/sum etc. struct MinReduction {
template <typename T, typename VT, int N> T operator()(const T& a, const T& b) {
void _vectorized_strided_max(const T* x, T* accum, int size, size_t stride) { return std::min(a, b);
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
*(VT*)a = simd_max((*(VT*)x), (*(VT*)a));
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a = std::max(*a, *x);
a++;
x++;
}
} }
}
template <typename T, typename VT, int N> VT operator()(VT a, VT b) {
void _vectorized_strided_min(const T* x, T* accum, int size, size_t stride) { return simd_min(a, b);
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
*(VT*)a = simd_min((*(VT*)x), (*(VT*)a));
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a = std::min(*a, *x);
a++;
x++;
}
} }
} };
template <typename T, typename VT, int N> template <typename T, typename VT>
void _vectorized_sum(const T* x, T* accum, int size) { struct MaxReduction {
VT _sum = {0}; T operator()(const T& a, const T& b) {
while (size >= N) { return std::max(a, b);
_sum += (*(VT*)x);
x += N;
size -= N;
} }
T sum = _sum[0];
for (int i = 1; i < N; i++) { VT operator()(VT a, VT b) {
sum += _sum[i]; return simd_max(a, b);
} }
*accum += sum; };
}
template <typename T, typename VT>
struct SumReduction {
T operator()(const T& a, const T& b) {
return a + b;
}
VT operator()(VT a, VT b) {
return a + b;
}
};
template <typename T, typename VT, int N, typename Reduction>
struct StridedReduce {
void operator()(const T* x, T* accum, int size, size_t stride) {
Reduction op;
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
*(VT*)a = op((*(VT*)x), (*(VT*)a));
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a = op(*a, *x);
a++;
x++;
}
}
}
};
} // namespace
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) { void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
@@ -94,10 +81,11 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
out, out,
axes_, axes_,
0, 0,
[](const auto* x, auto* accum, int size, size_t stride) { StridedReduce<
_vectorized_strided_sum<float, simd_float16, 16>( float,
(const float*)x, (float*)accum, size, stride); simd_float16,
}, 16,
SumReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size) { [](const auto* x, auto* accum, int size) {
float acc; float acc;
vDSP_sve((const float*)x, 1, &acc, size); vDSP_sve((const float*)x, 1, &acc, size);
@@ -111,10 +99,11 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
out, out,
axes_, axes_,
-std::numeric_limits<float>::infinity(), -std::numeric_limits<float>::infinity(),
[](const auto* x, auto* accum, int size, size_t stride) { StridedReduce<
_vectorized_strided_max<float, simd_float16, 16>( float,
(const float*)x, (float*)accum, size, stride); simd_float16,
}, 16,
MaxReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size) { [](const auto* x, auto* accum, int size) {
float max; float max;
vDSP_maxv((const float*)x, 1, &max, size); vDSP_maxv((const float*)x, 1, &max, size);
@@ -128,10 +117,11 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
out, out,
axes_, axes_,
std::numeric_limits<float>::infinity(), std::numeric_limits<float>::infinity(),
[](const auto* x, auto* accum, int size, size_t stride) { StridedReduce<
_vectorized_strided_min<float, simd_float16, 16>( float,
(const float*)x, (float*)accum, size, stride); simd_float16,
}, 16,
MinReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size) { [](const auto* x, auto* accum, int size) {
float min; float min;
vDSP_minv((const float*)x, 1, &min, size); vDSP_minv((const float*)x, 1, &min, size);


@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <cassert> #include <cassert>
#include <limits> #include <limits>
@@ -201,7 +201,7 @@ struct NeonFp16SimdOps {
} }
}; };
template <typename T, typename VT, typename Ops, int N> template <typename T, typename AccT, typename VT, typename Ops, int N>
void softmax(const array& in, array& out) { void softmax(const array& in, array& out) {
Ops ops; Ops ops;
@@ -218,13 +218,21 @@ void softmax(const array& in, array& out) {
VT vmaximum = ops.init(-std::numeric_limits<float>::infinity()); VT vmaximum = ops.init(-std::numeric_limits<float>::infinity());
size_t s = M; size_t s = M;
while (s >= N) { while (s >= N) {
vmaximum = ops.max(ops.load(current_in_ptr), vmaximum); VT vals;
if constexpr (std::is_same<T, AccT>::value) {
vals = ops.load(current_in_ptr);
} else {
for (int i = 0; i < N; ++i) {
vals[i] = static_cast<AccT>(current_in_ptr[i]);
}
}
vmaximum = ops.max(vals, vmaximum);
current_in_ptr += N; current_in_ptr += N;
s -= N; s -= N;
} }
T maximum = ops.reduce_max(vmaximum); AccT maximum = ops.reduce_max(vmaximum);
while (s-- > 0) { while (s-- > 0) {
maximum = std::max(maximum, *current_in_ptr); maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
current_in_ptr++; current_in_ptr++;
} }
@@ -234,18 +242,29 @@ void softmax(const array& in, array& out) {
current_in_ptr = in_ptr; current_in_ptr = in_ptr;
s = M; s = M;
while (s >= N) { while (s >= N) {
VT vexp = ops.exp(ops.sub(*(VT*)current_in_ptr, maximum)); VT vexp;
ops.store(current_out_ptr, vexp); if constexpr (std::is_same<T, AccT>::value) {
*(VT*)current_out_ptr = vexp; vexp = ops.load(current_in_ptr);
} else {
for (int i = 0; i < N; ++i) {
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
}
}
vexp = ops.exp(ops.sub(vexp, maximum));
if constexpr (std::is_same<T, AccT>::value) {
ops.store(current_out_ptr, vexp);
}
vnormalizer = ops.add(vnormalizer, vexp); vnormalizer = ops.add(vnormalizer, vexp);
current_in_ptr += N; current_in_ptr += N;
current_out_ptr += N; current_out_ptr += N;
s -= N; s -= N;
} }
T normalizer = ops.reduce_add(vnormalizer); AccT normalizer = ops.reduce_add(vnormalizer);
while (s-- > 0) { while (s-- > 0) {
T _exp = std::exp(*current_in_ptr - maximum); AccT _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = _exp; if (std::is_same<T, AccT>::value) {
*current_out_ptr = _exp;
}
normalizer += _exp; normalizer += _exp;
current_in_ptr++; current_in_ptr++;
current_out_ptr++; current_out_ptr++;
@@ -254,14 +273,33 @@ void softmax(const array& in, array& out) {
// Normalize // Normalize
current_out_ptr = out_ptr; current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M; s = M;
while (s >= N) { while (s >= N) {
ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer)); if constexpr (std::is_same<T, AccT>::value) {
ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
} else {
VT vexp;
for (int i = 0; i < N; ++i) {
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
}
vexp = ops.mul(ops.exp(ops.sub(vexp, maximum)), normalizer);
for (int i = 0; i < N; ++i) {
current_out_ptr[i] = vexp[i];
}
current_in_ptr += N;
}
current_out_ptr += N; current_out_ptr += N;
s -= N; s -= N;
} }
while (s-- > 0) { while (s-- > 0) {
*current_out_ptr *= normalizer; if constexpr (std::is_same<T, AccT>::value) {
*current_out_ptr *= normalizer;
} else {
AccT _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = static_cast<T>(_exp * normalizer);
current_in_ptr++;
}
current_out_ptr++; current_out_ptr++;
} }
} }
@@ -308,15 +346,29 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
"Softmax is defined only for floating point types"); "Softmax is defined only for floating point types");
break; break;
case float32: case float32:
softmax<float, simd_float16, AccelerateSimdOps<float, simd_float16>, 16>( softmax<
in, out); float,
float,
simd_float16,
AccelerateSimdOps<float, simd_float16>,
16>(in, out);
break; break;
case float16: case float16:
softmax< if (precise_) {
float16_t, softmax<
float16x8_t, float16_t,
NeonFp16SimdOps<float16_t, float16x8_t>, float,
8>(in, out); simd_float16,
AccelerateSimdOps<float, simd_float16>,
16>(in, out);
} else {
softmax<
float16_t,
float16_t,
float16x8_t,
NeonFp16SimdOps<float16_t, float16x8_t>,
8>(in, out);
}
break; break;
case bfloat16: case bfloat16:
eval(inputs, out); eval(inputs, out);


@@ -41,10 +41,10 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
@@ -53,6 +53,8 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
) )


@@ -179,18 +179,16 @@ void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
if (is_floating_point(out.dtype())) { if (out.dtype() == float32) {
if (out.dtype() == float32) { binary_op<float>(a, b, out, detail::LogAddExp());
binary_op<float>(a, b, out, detail::LogAddExp()); } else if (out.dtype() == float16) {
} else if (out.dtype() == float16) { binary_op<float16_t>(a, b, out, detail::LogAddExp());
binary_op<float16_t>(a, b, out, detail::LogAddExp()); } else if (out.dtype() == bfloat16) {
} else if (out.dtype() == bfloat16) { binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp()); } else if (issubdtype(out.dtype(), inexact)) {
} else { std::ostringstream err;
std::ostringstream err; err << "[logaddexp] Does not support " << out.dtype();
err << "[logaddexp] Does not support " << out.dtype(); throw std::invalid_argument(err.str());
throw std::invalid_argument(err.str());
}
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
"[logaddexp] Cannot compute logaddexp for arrays with" "[logaddexp] Cannot compute logaddexp for arrays with"
@@ -238,4 +236,61 @@ void Subtract::eval(const std::vector<array>& inputs, array& out) {
binary(a, b, out, detail::Subtract()); binary(a, b, out, detail::Subtract());
} }
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto dispatch_type = [&a, &b, &out](auto op) {
switch (out.dtype()) {
case bool_:
binary_op<bool>(a, b, out, op);
break;
case uint8:
binary_op<uint8_t>(a, b, out, op);
break;
case uint16:
binary_op<uint16_t>(a, b, out, op);
break;
case uint32:
binary_op<uint32_t>(a, b, out, op);
break;
case uint64:
binary_op<uint64_t>(a, b, out, op);
break;
case int8:
binary_op<int8_t>(a, b, out, op);
break;
case int16:
binary_op<int16_t>(a, b, out, op);
break;
case int32:
binary_op<int32_t>(a, b, out, op);
break;
case int64:
binary_op<int64_t>(a, b, out, op);
break;
default:
throw std::runtime_error(
"[BitwiseBinary::eval_cpu] Type not supported");
break;
}
};
switch (op_) {
case BitwiseBinary::And:
dispatch_type(detail::BitwiseAnd());
break;
case BitwiseBinary::Or:
dispatch_type(detail::BitwiseOr());
break;
case BitwiseBinary::Xor:
dispatch_type(detail::BitwiseXor());
break;
case BitwiseBinary::LeftShift:
dispatch_type(detail::LeftShift());
break;
case BitwiseBinary::RightShift:
dispatch_type(detail::RightShift());
break;
}
}
} // namespace mlx::core } // namespace mlx::core


@@ -1,6 +1,7 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled.h"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/utils.h" #include "mlx/utils.h"
@@ -81,13 +82,27 @@ std::string build_lib_name(
const std::vector<array>& outputs, const std::vector<array>& outputs,
const std::vector<array>& tape, const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids) { const std::unordered_set<uintptr_t>& constant_ids) {
NodeNamer namer;
std::ostringstream os; std::ostringstream os;
std::ostringstream constant_hasher; std::ostringstream constant_hasher;
// Fill the input names. This is not really necessary, I just like having A,
// B, C, ... as the inputs.
for (auto& x : inputs) {
namer.get_name(x);
}
// The primitives describing the tape. For unary and binary primitives this // The primitives describing the tape. For unary and binary primitives this
// must be enough to describe the full computation. // must be enough to describe the full computation.
for (auto& a : tape) { for (auto& a : tape) {
// name and type of output
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
// computation performed
a.primitive().print(os); a.primitive().print(os);
// name of inputs to the function
for (auto& inp : a.inputs()) {
os << namer.get_name(inp);
}
} }
os << "_"; os << "_";
@@ -111,4 +126,102 @@ std::string build_lib_name(
return os.str(); return os.str();
} }
bool compiled_check_contiguity(
const std::vector<array>& inputs,
const std::vector<int>& shape) {
bool contiguous = true;
bool all_contig = true;
bool all_row_contig = true;
bool all_col_contig = true;
int non_scalar_inputs = 0;
for (const auto& x : inputs) {
if (is_scalar(x)) {
continue;
}
non_scalar_inputs++;
bool shape_eq = x.shape() == shape;
all_contig &= (x.flags().contiguous && shape_eq);
all_row_contig &= (x.flags().row_contiguous && shape_eq);
all_col_contig &= (x.flags().col_contiguous && shape_eq);
}
if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) {
contiguous = false;
} else if (non_scalar_inputs == 1 && !all_contig) {
contiguous = false;
} else if (non_scalar_inputs == 0 && !shape.empty()) {
contiguous = false;
}
return contiguous;
}
void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::vector<array>& inputs_,
const std::unordered_set<uintptr_t>& constant_ids_,
bool contiguous,
bool move_buffers /* = false */) {
if (contiguous) {
int o = 0;
std::vector<size_t> strides;
size_t data_size;
array::Flags flags;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Correct size
// - Not a scalar
// - Donatable
// - Not a constant
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
if (move_buffers) {
outputs[o++].move_shared_buffer(in);
} else {
outputs[o++].copy_shared_buffer(in);
}
}
// Get representative input flags to properly set non-donated outputs
if (strides.empty() && in.size() == outputs[0].size()) {
strides = in.strides();
flags = in.flags();
data_size = in.data_size();
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
data_size,
strides,
flags);
}
} else {
int o = 0;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Row contiguous
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
if (move_buffers) {
outputs[o].move_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size());
} else {
outputs[o].copy_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size());
}
o++;
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
}
}
}
} // namespace mlx::core } // namespace mlx::core


@@ -53,4 +53,18 @@ inline bool is_scalar(const array& x) {
return x.ndim() == 0; return x.ndim() == 0;
} }
// Check if we can use a contiguous operation given inputs and the output shape
bool compiled_check_contiguity(
const std::vector<array>& inputs,
const std::vector<int>& shape);
// Allocate space for the outputs possibly with input donation
void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::vector<array>& inputs_,
const std::unordered_set<uintptr_t>& constant_ids_,
bool contiguous,
bool move_buffers = false);
} // namespace mlx::core } // namespace mlx::core


@@ -52,8 +52,25 @@ void* compile(
return nullptr; return nullptr;
} }
std::string kernel_file_name;
// Deal with long kernel names. Maximum length for files on macOS is 255
// characters. Clip file name with a little extra room and append a 16
// character hash.
constexpr int max_file_name_length = 245;
if (kernel_name.size() > max_file_name_length) {
std::ostringstream file_name;
file_name
<< std::string_view(kernel_name).substr(0, max_file_name_length - 16);
auto file_id = std::hash<std::string>{}(kernel_name);
file_name << "_" << std::hex << std::setw(16) << file_id << std::dec;
kernel_file_name = file_name.str();
} else {
kernel_file_name = kernel_name;
}
std::ostringstream shared_lib_name; std::ostringstream shared_lib_name;
shared_lib_name << "lib" << kernel_name << ".so"; shared_lib_name << "lib" << kernel_file_name << ".so";
auto shared_lib_path = get_temp_file(shared_lib_name.str()); auto shared_lib_path = get_temp_file(shared_lib_name.str());
bool lib_exists = false; bool lib_exists = false;
{ {
@@ -64,7 +81,7 @@ void* compile(
if (!lib_exists) { if (!lib_exists) {
// Open source file and write source code to it // Open source file and write source code to it
std::ostringstream source_file_name; std::ostringstream source_file_name;
source_file_name << kernel_name << ".cpp"; source_file_name << kernel_file_name << ".cpp";
auto source_file_path = get_temp_file(source_file_name.str()); auto source_file_path = get_temp_file(source_file_name.str());
std::ofstream source_file(source_file_path); std::ofstream source_file(source_file_path);
@@ -248,28 +265,7 @@ void Compiled::eval_cpu(
// Figure out which kernel we are using // Figure out which kernel we are using
auto& shape = outputs[0].shape(); auto& shape = outputs[0].shape();
bool contiguous = true; bool contiguous = compiled_check_contiguity(inputs, shape);
{
bool all_contig = true;
bool all_row_contig = true;
bool all_col_contig = true;
int non_scalar_inputs = 0;
for (auto& x : inputs) {
if (is_scalar(x)) {
continue;
}
non_scalar_inputs++;
bool shape_eq = x.shape() == shape;
all_contig &= (x.flags().contiguous && shape_eq);
all_row_contig &= (x.flags().row_contiguous && shape_eq);
all_col_contig &= (x.flags().col_contiguous && shape_eq);
}
if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) {
contiguous = false;
} else if (non_scalar_inputs == 1 && !all_contig) {
contiguous = false;
}
}
// Handle all broadcasting and collect function input arguments // Handle all broadcasting and collect function input arguments
std::vector<void*> args; std::vector<void*> args;
@@ -342,56 +338,8 @@ void Compiled::eval_cpu(
fn_ptr = compile(kernel_name, kernel.str()); fn_ptr = compile(kernel_name, kernel.str());
} }
// Allocate space for the outputs possibly with input donation compiled_allocate_outputs(
if (contiguous) { inputs, outputs, inputs_, constant_ids_, contiguous, false);
int o = 0;
std::vector<size_t> strides;
size_t data_size;
array::Flags flags;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Contiguous
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().contiguous && !is_scalar(in) && in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
outputs[o++].copy_shared_buffer(in);
}
// Get representative input flags to properly set non-donated outputs
if (strides.empty() && in.size() == outputs[0].size()) {
strides = in.strides();
flags = in.flags();
data_size = in.data_size();
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
data_size,
strides,
flags);
}
} else {
int o = 0;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Row contiguous
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
outputs[o++].copy_shared_buffer(in);
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
}
}
for (auto& x : outputs) { for (auto& x : outputs) {
args.push_back(x.data<void>()); args.push_back(x.data<void>());
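The hunk at the top of this file clips over-long kernel names so the generated source and shared-library files stay under the 255-character macOS filename limit: keep the first 229 characters and append a 16-hex-digit std::hash of the full name. A minimal standalone sketch of that scheme (the hash is zero-padded here for a tidier name, where the diff uses setw alone):

// Minimal sketch of the file-name clipping shown above: keep the first
// (245 - 16) characters and append a 16-hex-digit hash of the full name.
#include <functional>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <string_view>

std::string clip_kernel_name(const std::string& kernel_name) {
  constexpr int max_file_name_length = 245;
  if (kernel_name.size() <= max_file_name_length) {
    return kernel_name;
  }
  std::ostringstream file_name;
  file_name << std::string_view(kernel_name).substr(0, max_file_name_length - 16);
  auto file_id = std::hash<std::string>{}(kernel_name);
  file_name << "_" << std::hex << std::setfill('0') << std::setw(16) << file_id;
  return file_name.str();
}

int main() {
  std::string long_name(300, 'k');  // hypothetical over-long kernel name
  std::cout << clip_kernel_name(long_name).size() << "\n";  // 229 + 1 + 16 = 246
}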

View File

@@ -38,11 +38,15 @@ void slow_conv_1D(
const int N = in.shape(0); // Batch size, should be the same as out.shape(0) const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
const int C = in.shape(2); // Input channels
const int oH = out.shape(1); // Output spatial dim const int oH = out.shape(1); // Output spatial dim
const int O = wt.shape(0); // Out channels const int O = wt.shape(0); // Out channels
const int C = wt.shape(2); // In channels
const int wH = wt.shape(1); // Weight spatial dim const int wH = wt.shape(1); // Weight spatial dim
const int groups = C / wt.shape(2);
const int C_per_group = wt.shape(2);
const int O_per_group = O / groups;
const size_t in_stride_N = in.strides()[0]; const size_t in_stride_N = in.strides()[0];
const size_t in_stride_H = in.strides()[1]; const size_t in_stride_H = in.strides()[1];
const size_t in_stride_C = in.strides()[2]; const size_t in_stride_C = in.strides()[2];
@@ -57,35 +61,36 @@ void slow_conv_1D(
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
for (int oh = 0; oh < oH; ++oh) { for (int oh = 0; oh < oH; ++oh) {
for (int o = 0; o < O; ++o) { for (int g = 0; g < groups; ++g) {
const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O; for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
float r = 0.; const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O;
float r = 0.;
for (int wh = 0; wh < wH; ++wh) { for (int wh = 0; wh < wH; ++wh) {
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H; const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
int wh_flip = flip ? (wH - wh - 1) : wh; int wh_flip = flip ? (wH - wh - 1) : wh;
int ih = oh * wt_strides[0] - padding[0] + wh_flip * wt_dilation[0]; int ih = oh * wt_strides[0] - padding[0] + wh_flip * wt_dilation[0];
auto ih_div = std::div(ih, in_dilation[0]); auto ih_div = std::div(ih, in_dilation[0]);
if (ih >= 0 && ih < iH && ih_div.rem == 0) { if (ih >= 0 && ih < iH && ih_div.rem == 0) {
for (int c = 0; c < C; ++c) { for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
r += static_cast<float>( r += static_cast<float>(
in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) * in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) *
static_cast<float>(wt_ptr[c * wt_stride_C]); static_cast<float>(wt_ptr[(c % C_per_group) * wt_stride_C]);
} // c } // c
} // ih check } // ih check
} // wh } // wh
out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast<T>(r); out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast<T>(r);
} // o } // o
} // g
} // oh } // oh
in_ptr += in_stride_N; in_ptr += in_stride_N;
out_ptr += out_stride_N; out_ptr += out_stride_N;
} // n } // n
} }
@@ -366,11 +371,15 @@ void explicit_gemm_conv_1D_cpu(
const std::vector<int>& wt_dilation) { const std::vector<int>& wt_dilation) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0) const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = in.shape(1); // Input spatial dim const int iH = in.shape(1); // Input spatial dim
const int C = in.shape(2); // Input channels
const int oH = out.shape(1); // Output spatial dim const int oH = out.shape(1); // Output spatial dim
const int O = wt.shape(0); // Out channels const int O = wt.shape(0); // Out channels
const int C = wt.shape(2); // In channels
const int wH = wt.shape(1); // Weight spatial dim const int wH = wt.shape(1); // Weight spatial dim
const int groups = C / wt.shape(2);
const int C_per_group = wt.shape(2);
const int O_per_group = O / groups;
auto conv_dtype = float32; auto conv_dtype = float32;
// Pad input // Pad input
@@ -402,6 +411,11 @@ void explicit_gemm_conv_1D_cpu(
in_padded.strides()[1], in_padded.strides()[1],
in_padded.strides()[2]}; in_padded.strides()[2]};
auto flags = in_padded.flags(); auto flags = in_padded.flags();
if (groups > 1) {
// Transpose the last two dimensions for grouped convolutions
std::swap(strided_shape[2], strided_shape[3]);
std::swap(strided_strides[2], strided_strides[3]);
}
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {}); array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer( in_strided_view.copy_shared_buffer(
@@ -416,7 +430,19 @@ void explicit_gemm_conv_1D_cpu(
auto gemm_wt = wt; auto gemm_wt = wt;
auto gemm_out = out; auto gemm_out = out;
if (wt.dtype() != float32 || !wt.flags().row_contiguous) { if (groups > 1) {
// Transpose the last two dimensions for grouped convolutions
array wt_transpose(
{wt.shape(0), wt.shape(2), wt.shape(1)}, wt.dtype(), nullptr, {});
wt_transpose.copy_shared_buffer(
wt,
{wt.strides(0), wt.strides(2), wt.strides(1)},
wt.flags(),
wt.size(),
0);
gemm_wt = array(wt_transpose.shape(), float32, nullptr, {});
copy(wt_transpose, gemm_wt, CopyType::General);
} else if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
auto ctype = auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General; wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {}); gemm_wt = array(wt.shape(), float32, nullptr, {});
@@ -428,27 +454,29 @@ void explicit_gemm_conv_1D_cpu(
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes())); gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
} }
// Perform gemm for (int g = 0; g < groups; ++g) {
cblas_sgemm( // Perform gemm
CblasRowMajor, cblas_sgemm(
CblasNoTrans, // no trans A CblasRowMajor,
CblasTrans, // transB CblasNoTrans, // no trans A
strided_reshape[0], // M CblasTrans, // transB
O, // N strided_reshape[0], // M
strided_reshape[1], // K O_per_group, // N
1.0f, // alpha C_per_group * wH, // K
in_strided.data<float>(), 1.0f, // alpha
strided_reshape[1], // lda in_strided.data<float>() + g * C_per_group * wH, // A
gemm_wt.data<float>(), wH * C, // lda
strided_reshape[1], // ldb gemm_wt.data<float>() + g * O_per_group * C_per_group * wH, // B
0.0f, // beta wH * C_per_group, // ldb
gemm_out.data<float>(), 0.0f, // beta
O // ldc gemm_out.data<float>() + g * O_per_group, // C
); O // ldc
);
// Copy results if needed // Copy results if needed
if (out.dtype() != float32) { if (out.dtype() != float32) {
copy(gemm_out, out, CopyType::Vector); copy(gemm_out, out, CopyType::Vector);
}
} }
} }
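The grouped path above splits the C input channels into groups of C_per_group and the O output channels into groups of O_per_group; output channel o in group g reads only input channels [g*C_per_group, (g+1)*C_per_group) and indexes the weight with c % C_per_group. A toy grouped 1-D convolution showing just that indexing, with stride 1 and no padding, dilation or flipping, so it is an illustration rather than the mlx kernel:

// Toy grouped 1-D convolution (NHC layout), stride 1, no padding/dilation.
// in:  [N][iH][C], wt: [O][wH][C_per_group], out: [N][oH][O], oH = iH - wH + 1
#include <vector>

std::vector<float> grouped_conv1d(
    const std::vector<float>& in, const std::vector<float>& wt,
    int N, int iH, int C, int O, int wH, int groups) {
  const int C_per_group = C / groups;
  const int O_per_group = O / groups;
  const int oH = iH - wH + 1;
  std::vector<float> out(size_t(N) * oH * O, 0.f);
  for (int n = 0; n < N; ++n) {
    for (int oh = 0; oh < oH; ++oh) {
      for (int g = 0; g < groups; ++g) {
        for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
          float r = 0.f;
          for (int wh = 0; wh < wH; ++wh) {
            for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
              // Channel c lives in group g of the input but maps to the
              // weight's per-group channel index c % C_per_group.
              r += in[(size_t(n) * iH + (oh + wh)) * C + c] *
                   wt[(size_t(o) * wH + wh) * C_per_group + (c % C_per_group)];
            }
          }
          out[(size_t(n) * oH + oh) * O + o] = r;
        }
      }
    }
  }
  return out;
}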

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <numeric> #include <numeric>
@@ -25,121 +25,196 @@ void copy_vector(const array& src, array& dst) {
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr); std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
} }
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim1(const array& src, array& dst) { void copy_general_dim1(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
const SrcT* src_ptr = src.data<SrcT>(); const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>(); DstT* dst_ptr = dst.data<DstT>();
size_t src_idx = 0; stride_t src_idx = i_offset;
size_t dst_idx = 0; stride_t dst_idx = 0;
for (size_t i = 0; i < src.shape()[0]; ++i) { for (int i = 0; i < data_shape[0]; ++i) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]); dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += src.strides()[0]; src_idx += i_strides[0];
} }
} }
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
void copy_general_dim2(const array& src, array& dst) { inline void copy_general_dim1(const array& src, array& dst) {
return copy_general_dim1<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim2(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
const SrcT* src_ptr = src.data<SrcT>(); const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>(); DstT* dst_ptr = dst.data<DstT>();
size_t src_idx = 0; stride_t src_idx = i_offset;
size_t dst_idx = 0; stride_t dst_idx = 0;
for (size_t i = 0; i < src.shape()[0]; ++i) { for (int i = 0; i < data_shape[0]; ++i) {
for (size_t j = 0; j < src.shape()[1]; ++j) { for (int j = 0; j < data_shape[1]; ++j) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]); dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += src.strides()[1]; src_idx += i_strides[1];
} }
src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1]; src_idx += i_strides[0] - i_strides[1] * data_shape[1];
} }
} }
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
void copy_general_dim3(const array& src, array& dst) { inline void copy_general_dim2(const array& src, array& dst) {
return copy_general_dim2<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim3(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
const SrcT* src_ptr = src.data<SrcT>(); const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>(); DstT* dst_ptr = dst.data<DstT>();
size_t src_idx = 0; stride_t src_idx = i_offset;
size_t dst_idx = 0; stride_t dst_idx = 0;
for (size_t i = 0; i < src.shape()[0]; ++i) { for (int i = 0; i < data_shape[0]; ++i) {
for (size_t j = 0; j < src.shape()[1]; ++j) { for (int j = 0; j < data_shape[1]; ++j) {
for (size_t k = 0; k < src.shape()[2]; ++k) { for (int k = 0; k < data_shape[2]; ++k) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]); dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += src.strides()[2]; src_idx += i_strides[2];
} }
src_idx += src.strides()[1] - src.strides()[2] * src.shape()[2]; src_idx += i_strides[1] - i_strides[2] * data_shape[2];
} }
src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1]; src_idx += i_strides[0] - i_strides[1] * data_shape[1];
} }
} }
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
void copy_general_dim4(const array& src, array& dst) { inline void copy_general_dim3(const array& src, array& dst) {
return copy_general_dim3<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim4(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
const SrcT* src_ptr = src.data<SrcT>(); const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>(); DstT* dst_ptr = dst.data<DstT>();
size_t src_idx = 0; stride_t src_idx = i_offset;
size_t dst_idx = 0; stride_t dst_idx = 0;
for (size_t i = 0; i < src.shape()[0]; ++i) { for (int i = 0; i < data_shape[0]; ++i) {
for (size_t j = 0; j < src.shape()[1]; ++j) { for (int j = 0; j < data_shape[1]; ++j) {
for (size_t k = 0; k < src.shape()[2]; ++k) { for (int k = 0; k < data_shape[2]; ++k) {
for (size_t ii = 0; ii < src.shape()[3]; ++ii) { for (int ii = 0; ii < data_shape[3]; ++ii) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]); dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += src.strides()[3]; src_idx += i_strides[3];
} }
src_idx += src.strides()[2] - src.strides()[3] * src.shape()[3]; src_idx += i_strides[2] - i_strides[3] * data_shape[3];
} }
src_idx += src.strides()[1] - src.strides()[2] * src.shape()[2]; src_idx += i_strides[1] - i_strides[2] * data_shape[2];
} }
src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1]; src_idx += i_strides[0] - i_strides[1] * data_shape[1];
} }
} }
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
void copy_general(const array& src, array& dst) { inline void copy_general_dim4(const array& src, array& dst) {
return copy_general_dim4<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
switch (src.ndim()) { switch (src.ndim()) {
case 1: case 1:
copy_general_dim1<SrcT, DstT>(src, dst); copy_general_dim1<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
return; return;
case 2: case 2:
copy_general_dim2<SrcT, DstT>(src, dst); copy_general_dim2<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
return; return;
case 3: case 3:
copy_general_dim3<SrcT, DstT>(src, dst); copy_general_dim3<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
return; return;
case 4: case 4:
copy_general_dim4<SrcT, DstT>(src, dst); copy_general_dim4<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
return; return;
} }
auto src_ptr = src.data<SrcT>(); auto src_ptr = src.data<SrcT>() + i_offset;
auto dst_ptr = dst.data<DstT>(); auto dst_ptr = dst.data<DstT>();
for (size_t i = 0; i < dst.size(); ++i) { for (size_t i = 0; i < dst.size(); ++i) {
size_t src_elem = elem_to_loc(i, src.shape(), src.strides()); stride_t src_elem = elem_to_loc(i, data_shape, i_strides);
dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]); dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]);
} }
} }
template <typename SrcT, typename DstT, int D> template <typename SrcT, typename DstT>
inline void copy_general(const array& src, array& dst) {
return copy_general<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
inline void copy_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset) {
return copy_general<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
}
template <typename SrcT, typename DstT, typename stride_t, int D>
inline void copy_general_general_dims( inline void copy_general_general_dims(
const array& src, const array& src,
array& dst, array& dst,
size_t offset_src, const std::vector<int>& data_shape,
size_t offset_dst) { const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
stride_t i_offset,
stride_t o_offset) {
if constexpr (D > 1) { if constexpr (D > 1) {
int axis = src.ndim() - D; int axis = src.ndim() - D;
auto stride_src = src.strides()[axis]; auto stride_src = i_strides[axis];
auto stride_dst = dst.strides()[axis]; auto stride_dst = o_strides[axis];
auto N = src.shape(axis); auto N = data_shape[axis];
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
copy_general_general_dims<SrcT, DstT, D - 1>( copy_general_general_dims<SrcT, DstT, stride_t, D - 1>(
src, dst, offset_src, offset_dst); src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
offset_src += stride_src; i_offset += stride_src;
offset_dst += stride_dst; o_offset += stride_dst;
} }
} else { } else {
int axis = src.ndim() - 1; int axis = src.ndim() - 1;
auto stride_src = src.strides()[axis]; auto stride_src = i_strides[axis];
auto stride_dst = dst.strides()[axis]; auto stride_dst = o_strides[axis];
auto N = src.shape(axis); auto N = data_shape[axis];
const SrcT* src_ptr = src.data<SrcT>() + offset_src; const SrcT* src_ptr = src.data<SrcT>() + i_offset;
DstT* dst_ptr = dst.data<DstT>() + offset_dst; DstT* dst_ptr = dst.data<DstT>() + o_offset;
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
*dst_ptr = static_cast<DstT>(*src_ptr); *dst_ptr = static_cast<DstT>(*src_ptr);
src_ptr += stride_src; src_ptr += stride_src;
@@ -148,37 +223,56 @@ inline void copy_general_general_dims(
} }
} }
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT, typename stride_t>
void copy_general_general(const array& src, array& dst) { void copy_general_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
stride_t i_offset,
stride_t o_offset) {
switch (src.ndim()) { switch (src.ndim()) {
case 1: case 1:
copy_general_general_dims<SrcT, DstT, 1>(src, dst, 0, 0); copy_general_general_dims<SrcT, DstT, stride_t, 1>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return; return;
case 2: case 2:
copy_general_general_dims<SrcT, DstT, 2>(src, dst, 0, 0); copy_general_general_dims<SrcT, DstT, stride_t, 2>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return; return;
case 3: case 3:
copy_general_general_dims<SrcT, DstT, 3>(src, dst, 0, 0); copy_general_general_dims<SrcT, DstT, stride_t, 3>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return; return;
case 4: case 4:
copy_general_general_dims<SrcT, DstT, 4>(src, dst, 0, 0); copy_general_general_dims<SrcT, DstT, stride_t, 4>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return; return;
case 5: case 5:
copy_general_general_dims<SrcT, DstT, 5>(src, dst, 0, 0); copy_general_general_dims<SrcT, DstT, stride_t, 5>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return; return;
} }
int size = std::accumulate( int size = std::accumulate(
src.shape().begin() - 5, src.shape().end(), 1, std::multiplies<int>()); data_shape.begin() - 5, data_shape.end(), 1, std::multiplies<int>());
for (int i = 0; i < src.size(); i += size) { for (int i = 0; i < src.size(); i += size) {
size_t offset_src = elem_to_loc(i, src.shape(), src.strides()); stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides);
size_t offset_dst = elem_to_loc(i, dst.shape(), dst.strides()); stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides);
copy_general_general_dims<SrcT, DstT, 5>(src, dst, offset_src, offset_dst); copy_general_general_dims<SrcT, DstT, stride_t, 5>(
src, dst, data_shape, i_strides, o_strides, src_offset, dst_offset);
} }
} }
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
void copy(const array& src, array& dst, CopyType ctype) { inline void copy_general_general(const array& src, array& dst) {
return copy_general_general<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
}
template <typename SrcT, typename DstT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
switch (ctype) { switch (ctype) {
case CopyType::Scalar: case CopyType::Scalar:
copy_single<SrcT, DstT>(src, dst); copy_single<SrcT, DstT>(src, dst);
@@ -187,54 +281,103 @@ void copy(const array& src, array& dst, CopyType ctype) {
copy_vector<SrcT, DstT>(src, dst); copy_vector<SrcT, DstT>(src, dst);
return; return;
case CopyType::General: case CopyType::General:
copy_general<SrcT, DstT>(src, dst); copy_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
return; return;
case CopyType::GeneralGeneral: case CopyType::GeneralGeneral:
copy_general_general<SrcT, DstT>(src, dst); copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
} }
} }
template <typename SrcT> template <typename SrcT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype) { void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
switch (dst.dtype()) { switch (dst.dtype()) {
case bool_: case bool_:
copy<SrcT, bool>(src, dst, ctype); copy<SrcT, bool>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case uint8: case uint8:
copy<SrcT, uint8_t>(src, dst, ctype); copy<SrcT, uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case uint16: case uint16:
copy<SrcT, uint16_t>(src, dst, ctype); copy<SrcT, uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case uint32: case uint32:
copy<SrcT, uint32_t>(src, dst, ctype); copy<SrcT, uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case uint64: case uint64:
copy<SrcT, uint64_t>(src, dst, ctype); copy<SrcT, uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case int8: case int8:
copy<SrcT, int8_t>(src, dst, ctype); copy<SrcT, int8_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case int16: case int16:
copy<SrcT, int16_t>(src, dst, ctype); copy<SrcT, int16_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case int32: case int32:
copy<SrcT, int32_t>(src, dst, ctype); copy<SrcT, int32_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case int64: case int64:
copy<SrcT, int64_t>(src, dst, ctype); copy<SrcT, int64_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case float16: case float16:
copy<SrcT, float16_t>(src, dst, ctype); copy<SrcT, float16_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case float32: case float32:
copy<SrcT, float>(src, dst, ctype); copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case bfloat16: case bfloat16:
copy<SrcT, bfloat16_t>(src, dst, ctype); copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case complex64: case complex64:
copy<SrcT, complex64_t>(src, dst, ctype); copy<SrcT, complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
}
}
template <typename... Args>
inline void copy_inplace_dispatch(
const array& src,
array& dst,
CopyType ctype,
Args&&... args) {
switch (src.dtype()) {
case bool_:
copy<bool>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint8:
copy<uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint16:
copy<uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint32:
copy<uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint64:
copy<uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int8:
copy<int8_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int16:
copy<int16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int32:
copy<int32_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int64:
copy<int64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float16:
copy<float16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float32:
copy<float>(src, dst, ctype, std::forward<Args>(args)...);
break;
case bfloat16:
copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case complex64:
copy<complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
} }
} }
@@ -242,47 +385,7 @@ void copy(const array& src, array& dst, CopyType ctype) {
} // namespace } // namespace
void copy_inplace(const array& src, array& dst, CopyType ctype) { void copy_inplace(const array& src, array& dst, CopyType ctype) {
switch (src.dtype()) { return copy_inplace_dispatch(src, dst, ctype);
case bool_:
copy<bool>(src, dst, ctype);
break;
case uint8:
copy<uint8_t>(src, dst, ctype);
break;
case uint16:
copy<uint16_t>(src, dst, ctype);
break;
case uint32:
copy<uint32_t>(src, dst, ctype);
break;
case uint64:
copy<uint64_t>(src, dst, ctype);
break;
case int8:
copy<int8_t>(src, dst, ctype);
break;
case int16:
copy<int16_t>(src, dst, ctype);
break;
case int32:
copy<int32_t>(src, dst, ctype);
break;
case int64:
copy<int64_t>(src, dst, ctype);
break;
case float16:
copy<float16_t>(src, dst, ctype);
break;
case float32:
copy<float>(src, dst, ctype);
break;
case bfloat16:
copy<bfloat16_t>(src, dst, ctype);
break;
case complex64:
copy<complex64_t>(src, dst, ctype);
break;
}
} }
void copy(const array& src, array& dst, CopyType ctype) { void copy(const array& src, array& dst, CopyType ctype) {
@@ -312,4 +415,62 @@ void copy(const array& src, array& dst, CopyType ctype) {
copy_inplace(src, dst, ctype); copy_inplace(src, dst, ctype);
} }
template <typename stride_t>
void copy_inplace(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype) {
switch (ctype) {
case CopyType::General:
case CopyType::GeneralGeneral:
return copy_inplace_dispatch(
src,
dst,
ctype,
data_shape,
i_strides,
o_strides,
i_offset,
o_offset);
case CopyType::Scalar:
case CopyType::Vector:
return copy_inplace_dispatch(src, dst, ctype);
}
}
template <>
void copy_inplace<int64_t>(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<int64_t>& i_strides,
const std::vector<int64_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype) {
switch (ctype) {
case CopyType::General:
case CopyType::GeneralGeneral:
return copy_inplace_dispatch(
src,
dst,
ctype,
data_shape,
i_strides,
o_strides,
i_offset,
o_offset);
case CopyType::Scalar:
case CopyType::Vector:
return copy_inplace_dispatch(src, dst, ctype);
}
}
} // namespace mlx::core } // namespace mlx::core
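The reworked copy routines above take an explicit shape, strides and offsets so that strided views (used by Slice and SliceUpdate later in this comparison) can be copied without materializing them first. A standalone example of the dim-2 access pattern, with plain pointers standing in for mlx arrays:

// Standalone illustration of the 2-D strided copy with an input offset; the
// same access pattern as copy_general_dim2 above, with simplified types.
#include <cstdio>
#include <vector>

template <typename SrcT, typename DstT>
void copy_dim2(const SrcT* src, DstT* dst,
               const std::vector<int>& shape,          // {rows, cols} to copy
               const std::vector<long long>& strides,  // source strides (elements)
               long long offset) {
  long long src_idx = offset;
  long long dst_idx = 0;
  for (int i = 0; i < shape[0]; ++i) {
    for (int j = 0; j < shape[1]; ++j) {
      dst[dst_idx++] = static_cast<DstT>(src[src_idx]);
      src_idx += strides[1];
    }
    // Step to the next row, undoing the column advances of the inner loop.
    src_idx += strides[0] - strides[1] * shape[1];
  }
}

int main() {
  // A 4x4 row-major matrix; copy the 2x2 bottom-right block (offset 10).
  std::vector<int> src(16);
  for (int i = 0; i < 16; ++i) src[i] = i;
  std::vector<float> dst(4);
  copy_dim2(src.data(), dst.data(), {2, 2}, {4, 1}, 10);
  for (float v : dst) std::printf("%g ", v);  // 10 11 14 15
}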

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#pragma once #pragma once
@@ -26,4 +26,15 @@ enum class CopyType {
void copy(const array& src, array& dst, CopyType ctype); void copy(const array& src, array& dst, CopyType ctype);
void copy_inplace(const array& src, array& dst, CopyType ctype); void copy_inplace(const array& src, array& dst, CopyType ctype);
template <typename stride_t>
void copy_inplace(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);
} // namespace mlx::core } // namespace mlx::core

View File

@@ -41,6 +41,8 @@ DEFAULT(ArgSort)
DEFAULT(AsType) DEFAULT(AsType)
DEFAULT(AsStrided) DEFAULT(AsStrided)
DEFAULT(Broadcast) DEFAULT(Broadcast)
DEFAULT(BlockMaskedMM)
DEFAULT(BlockSparseMM)
DEFAULT_MULTI(DivMod) DEFAULT_MULTI(DivMod)
DEFAULT(Ceil) DEFAULT(Ceil)
DEFAULT(Concatenate) DEFAULT(Concatenate)
@@ -51,11 +53,13 @@ DEFAULT(Cosh)
DEFAULT_MULTI(CustomVJP) DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends) DEFAULT_MULTI(Depends)
DEFAULT(Divide) DEFAULT(Divide)
DEFAULT(NumberOfElements)
DEFAULT(Remainder) DEFAULT(Remainder)
DEFAULT(Equal) DEFAULT(Equal)
DEFAULT(Erf) DEFAULT(Erf)
DEFAULT(ErfInv) DEFAULT(ErfInv)
DEFAULT(Exp) DEFAULT(Exp)
DEFAULT(Expm1)
DEFAULT(FFT) DEFAULT(FFT)
DEFAULT(Floor) DEFAULT(Floor)
DEFAULT(Full) DEFAULT(Full)
@@ -93,6 +97,7 @@ DEFAULT(Sign)
DEFAULT(Sin) DEFAULT(Sin)
DEFAULT(Sinh) DEFAULT(Sinh)
DEFAULT(Slice) DEFAULT(Slice)
DEFAULT(SliceUpdate)
DEFAULT(Softmax) DEFAULT(Softmax)
DEFAULT(Sort) DEFAULT(Sort)
DEFAULT_MULTI(Split) DEFAULT_MULTI(Split)
@@ -100,9 +105,11 @@ DEFAULT(Square)
DEFAULT(Sqrt) DEFAULT(Sqrt)
DEFAULT(StopGradient) DEFAULT(StopGradient)
DEFAULT(Subtract) DEFAULT(Subtract)
DEFAULT_MULTI(SVD)
DEFAULT(Tan) DEFAULT(Tan)
DEFAULT(Tanh) DEFAULT(Tanh)
DEFAULT(Transpose) DEFAULT(Transpose)
DEFAULT(Inverse)
namespace { namespace {

View File

@@ -0,0 +1,104 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
namespace mlx::core {
void inverse_impl(const array& a, array& inv) {
// Lapack uses the column-major convention. We take advantage of the following
// identity to avoid transposing (see
// https://math.stackexchange.com/a/340234):
// (A⁻¹)ᵀ = (Aᵀ)⁻¹
// The inverse is computed in place, so just copy the input to the output.
copy(a, inv, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
const int N = a.shape(-1);
const size_t num_matrices = a.size() / (N * N);
int info;
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
for (int i = 0; i < num_matrices; i++) {
// Compute LU factorization.
sgetrf_(
/* m = */ &N,
/* n = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU factorization failed with error code " << info;
throw std::runtime_error(ss.str());
}
static const int lwork_query = -1;
float workspace_size = 0;
// Compute workspace size.
sgetri_(
/* m = */ &N,
/* a = */ nullptr,
/* lda = */ &N,
/* ipiv = */ nullptr,
/* work = */ &workspace_size,
/* lwork = */ &lwork_query,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU workspace calculation failed with error code "
<< info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_size;
auto scratch =
array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
// Compute inverse.
sgetri_(
/* m = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: inversion failed with error code " << info;
throw std::runtime_error(ss.str());
}
}
}
void Inverse::eval(const std::vector<array>& inputs, array& output) {
if (inputs[0].dtype() != float32) {
throw std::runtime_error("[Inverse::eval] only supports float32.");
}
inverse_impl(inputs[0], output);
}
std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto ax = axes[0] >= 0 ? 0 : -1;
auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
return {{linalg::inv(a, stream())}, {ax}};
}
} // namespace mlx::core
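The comment in inverse_impl leans on a double transpose: handing the row-major buffer of A to a column-major LAPACK routine means the routine actually sees Aᵀ, and since

\[
(A^{\mathsf{T}})^{-1} = (A^{-1})^{\mathsf{T}},
\]

the buffer sgetrf_/sgetri_ write back, read once more as row-major, is exactly A⁻¹, so no explicit transposes are needed on either side of the factorization and inversion.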

View File

@@ -0,0 +1,23 @@
// Copyright © 2024 Apple Inc.
#pragma once
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
#if defined(LAPACK_GLOBAL) || defined(LAPACK_NAME)
// This is to work around a change in the function signatures of lapack >= 3.9.1
// where functions taking char* also include a strlen argument, see a similar
// change in OpenCV:
// https://github.com/opencv/opencv/blob/1eb061f89de0fb85c4c75a2deeb0f61a961a63ad/cmake/OpenCVFindLAPACK.cmake#L57
#define MLX_LAPACK_FUNC(f) LAPACK_##f
#else
#define MLX_LAPACK_FUNC(f) f##_
#endif
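The macro above selects between the post-3.9.1 LAPACK_<name> symbols and the classic trailing-underscore names. A self-contained demonstration of the token-pasting mechanism, with stand-in functions instead of the real LAPACK symbols (the routine name is illustrative only):

// Stand-in demo of the macro pattern above; no LAPACK headers are involved.
#include <cstdio>

#define LAPACK_GLOBAL_DEMO 1
#if defined(LAPACK_GLOBAL_DEMO)
#define DEMO_LAPACK_FUNC(f) LAPACK_##f
#else
#define DEMO_LAPACK_FUNC(f) f##_
#endif

// Pretend these are the two possible symbol spellings exported by a library.
void LAPACK_sgetri() { std::puts("new-style symbol"); }
void sgetri_() { std::puts("classic symbol"); }

int main() {
  DEMO_LAPACK_FUNC(sgetri)();  // resolves to one spelling at preprocessing time
}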

View File

@@ -11,7 +11,7 @@ GCC=$2
SRCDIR=$3 SRCDIR=$3
CLANG=$4 CLANG=$4
if [ $CLANG = "TRUE" ]; then if [ "$CLANG" = "TRUE" ]; then
read -r -d '' INCLUDES <<- EOM read -r -d '' INCLUDES <<- EOM
#include <cmath> #include <cmath>
#include <complex> #include <complex>

View File

@@ -0,0 +1,280 @@
// Copyright © 2024 Apple Inc.
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#endif
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename T>
inline void mask_matrix(
T* data,
const bool* mask,
int block_size,
const int X,
const int Y,
const size_t X_data_str,
const size_t Y_data_str,
const size_t X_mask_str,
const size_t Y_mask_str) {
int tX = (X + block_size - 1) / block_size;
int tY = (Y + block_size - 1) / block_size;
for (int i = 0; i < tX; i++) {
for (int j = 0; j < tY; j++) {
bool do_mask = mask[i * X_mask_str + j * Y_mask_str];
if (!do_mask) {
int loc_x = i * block_size;
int loc_y = j * block_size;
T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;
int size_x = std::min(block_size, X - loc_x);
int size_y = std::min(block_size, Y - loc_y);
for (int ii = 0; ii < size_x; ii++) {
for (int jj = 0; jj < size_y; jj++) {
data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
}
}
}
}
}
}
} // namespace
void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[BlockMaskedMM::eval] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
auto& out_mask = inputs[2];
auto check_transpose = [](const array& arr, bool do_copy) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector);
return std::make_tuple(false, stx, arr_copy);
}
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector);
return std::make_tuple(true, sty, arr_copy);
}
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
bool has_op_mask = inputs.size() > 3;
auto [a_transposed, lda, a] = check_transpose(a_pre, has_op_mask);
auto [b_transposed, ldb, b] = check_transpose(b_pre, has_op_mask);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
auto mask_array = [](const array& mask,
float* data,
int block_size,
int batch_idx,
int X,
int Y,
size_t X_data_str,
size_t Y_data_str) {
const bool* mask_ptr = mask.data<bool>() +
elem_to_loc(mask.shape(-1) * mask.shape(-2) * batch_idx,
mask.shape(),
mask.strides());
size_t X_mask_str = mask.strides()[mask.ndim() - 2];
size_t Y_mask_str = mask.strides()[mask.ndim() - 1];
return mask_matrix(
data,
mask_ptr,
block_size,
X,
Y,
X_data_str,
Y_data_str,
X_mask_str,
Y_mask_str);
};
for (int i = 0; i < (a.size() / (M * K)); ++i) {
// Adjust pointer
float* ai =
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides());
float* bi =
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides());
float* ci = out.data<float>() + M * N * i;
// Zero out blocks in a and b if needed
if (has_op_mask) {
auto& a_mask = inputs[3];
mask_array(
a_mask,
ai,
block_size_,
i,
M,
K,
a_transposed ? 1 : lda,
a_transposed ? lda : 1);
auto& b_mask = inputs[4];
mask_array(
b_mask,
bi,
block_size_,
i,
K,
N,
b_transposed ? 1 : ldb,
b_transposed ? ldb : 1);
}
// Do matmul
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
1.0, // alpha
ai,
lda,
bi,
ldb,
0.0, // beta
ci,
out.shape(-1) // ldc
);
// Zero out blocks in out
mask_array(out_mask, ci, block_size_, i, M, N, N, 1);
}
}
void BlockSparseMM::eval(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[BlockSparseMM::eval] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
auto check_transpose = [](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
// Get batch dims
auto batch_size_out = out.size() / (M * N);
size_t matrix_stride_out = M * N;
auto get_batch_dims = [](const auto& v) {
return decltype(v){v.begin(), v.end() - 2};
};
auto& lhs_indices = inputs[2];
auto& rhs_indices = inputs[3];
std::vector<int> batch_shape = get_batch_dims(out.shape());
int batch_ndim = batch_shape.size();
std::vector<int> batch_shape_A = get_batch_dims(a.shape());
std::vector<size_t> batch_strides_A = get_batch_dims(a.strides());
std::vector<int> batch_shape_B = get_batch_dims(b.shape());
std::vector<size_t> batch_strides_B = get_batch_dims(b.strides());
const uint32_t* lhs_indices_ptr = lhs_indices.data<uint32_t>();
const uint32_t* rhs_indices_ptr = rhs_indices.data<uint32_t>();
for (int i = 0; i < batch_size_out; i++) {
// Get index
uint32_t indx_A = lhs_indices_ptr[elem_to_loc(i, lhs_indices)];
uint32_t indx_B = rhs_indices_ptr[elem_to_loc(i, rhs_indices)];
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
1.0f, // alpha
a.data<float>() + elem_to_loc(indx_A, batch_shape_A, batch_strides_A),
lda,
b.data<float>() + elem_to_loc(indx_B, batch_shape_B, batch_strides_B),
ldb,
0.0f, // beta
out.data<float>() + matrix_stride_out * i,
out.shape(-1) // ldc
);
}
}
} // namespace mlx::core
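mask_matrix above visits the matrix in block_size by block_size tiles and zeroes every tile whose mask entry is false; BlockMaskedMM applies it to the (copied) A and B operands before the GEMM and to the output afterwards. A standalone sketch of the tile zeroing for a row-major matrix:

// Standalone sketch of the block masking above for a row-major X x Y matrix:
// tile (i, j) of size block_size is zeroed when its mask entry is false.
#include <algorithm>
#include <vector>

void mask_matrix_rowmajor(float* data, const bool* mask, int block_size,
                          int X, int Y) {
  int tX = (X + block_size - 1) / block_size;  // number of tile rows
  int tY = (Y + block_size - 1) / block_size;  // number of tile cols
  for (int i = 0; i < tX; i++) {
    for (int j = 0; j < tY; j++) {
      if (!mask[i * tY + j]) {
        int x0 = i * block_size, y0 = j * block_size;
        int sx = std::min(block_size, X - x0);  // clip the last partial tile
        int sy = std::min(block_size, Y - y0);
        for (int ii = 0; ii < sx; ii++) {
          std::fill_n(data + (x0 + ii) * Y + y0, sy, 0.f);
        }
      }
    }
  }
}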

View File

@@ -241,6 +241,13 @@ struct Exp {
} }
}; };
struct Expm1 {
template <typename T>
T operator()(T x) {
return expm1(x);
};
};
struct Floor { struct Floor {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
@@ -599,4 +606,39 @@ struct Select {
} }
}; };
struct BitwiseAnd {
template <typename T>
T operator()(T x, T y) {
return x & y;
};
};
struct BitwiseOr {
template <typename T>
T operator()(T x, T y) {
return x | y;
};
};
struct BitwiseXor {
template <typename T>
T operator()(T x, T y) {
return x ^ y;
};
};
struct LeftShift {
template <typename T>
T operator()(T x, T y) {
return x << y;
};
};
struct RightShift {
template <typename T>
T operator()(T x, T y) {
return x >> y;
};
};
} // namespace mlx::core::detail } // namespace mlx::core::detail
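The functors above are plain element-wise callables, so they compose with any generic transform. A minimal standalone illustration using a copy of BitwiseXor with std::transform (the mlx binary-op dispatch machinery is not shown here):

// The struct mirrors BitwiseXor above; std::transform applies it pairwise.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

struct BitwiseXor {
  template <typename T>
  T operator()(T x, T y) {
    return x ^ y;
  }
};

int main() {
  std::vector<uint8_t> a{0b1100, 0b1010}, b{0b1010, 0b0110}, out(2);
  std::transform(a.begin(), a.end(), b.begin(), out.begin(), BitwiseXor{});
  std::printf("%d %d\n", out[0], out[1]);  // 6 12
}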

View File

@@ -22,7 +22,7 @@ namespace mlx::core {
void Abs::eval(const std::vector<array>& inputs, array& out) { void Abs::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (is_unsigned(in.dtype())) { if (issubdtype(in.dtype(), unsignedinteger)) {
// No-op for unsigned types // No-op for unsigned types
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
} else { } else {
@@ -37,7 +37,7 @@ void Arange::eval(const std::vector<array>& inputs, array& out) {
void ArcCos::eval(const std::vector<array>& inputs, array& out) { void ArcCos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcCos()); unary_fp(in, out, detail::ArcCos());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@@ -49,7 +49,7 @@ void ArcCos::eval(const std::vector<array>& inputs, array& out) {
void ArcCosh::eval(const std::vector<array>& inputs, array& out) { void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcCosh()); unary_fp(in, out, detail::ArcCosh());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@@ -61,7 +61,7 @@ void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
void ArcSin::eval(const std::vector<array>& inputs, array& out) { void ArcSin::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcSin()); unary_fp(in, out, detail::ArcSin());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@@ -73,7 +73,7 @@ void ArcSin::eval(const std::vector<array>& inputs, array& out) {
void ArcSinh::eval(const std::vector<array>& inputs, array& out) { void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcSinh()); unary_fp(in, out, detail::ArcSinh());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@@ -85,7 +85,7 @@ void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
void ArcTan::eval(const std::vector<array>& inputs, array& out) { void ArcTan::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcTan()); unary_fp(in, out, detail::ArcTan());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@@ -97,7 +97,7 @@ void ArcTan::eval(const std::vector<array>& inputs, array& out) {
void ArcTanh::eval(const std::vector<array>& inputs, array& out) { void ArcTanh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcTanh()); unary_fp(in, out, detail::ArcTanh());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@@ -171,7 +171,7 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
void Ceil::eval(const std::vector<array>& inputs, array& out) { void Ceil::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (not is_integral(in.dtype())) { if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Ceil()); unary_fp(in, out, detail::Ceil());
} else { } else {
// No-op integer types // No-op integer types
@@ -211,7 +211,7 @@ void Copy::eval(const std::vector<array>& inputs, array& out) {
void Cos::eval(const std::vector<array>& inputs, array& out) { void Cos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Cos()); unary_fp(in, out, detail::Cos());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@@ -223,7 +223,7 @@ void Cos::eval(const std::vector<array>& inputs, array& out) {
void Cosh::eval(const std::vector<array>& inputs, array& out) { void Cosh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Cosh()); unary_fp(in, out, detail::Cosh());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@@ -251,6 +251,62 @@ void Depends::eval(
} }
} }
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
double numel = 1;
for (auto ax : axes_) {
numel *= inputs[0].shape(ax);
}
if (inverted_) {
numel = 1.0 / numel;
}
switch (out.dtype()) {
case bool_:
*out.data<bool>() = static_cast<bool>(numel);
break;
case uint8:
*out.data<uint8_t>() = static_cast<uint8_t>(numel);
break;
case uint16:
*out.data<uint16_t>() = static_cast<uint16_t>(numel);
break;
case uint32:
*out.data<uint32_t>() = static_cast<uint32_t>(numel);
break;
case uint64:
*out.data<uint64_t>() = static_cast<uint64_t>(numel);
break;
case int8:
*out.data<int8_t>() = static_cast<int8_t>(numel);
break;
case int16:
*out.data<int16_t>() = static_cast<int16_t>(numel);
break;
case int32:
*out.data<int32_t>() = static_cast<int32_t>(numel);
break;
case int64:
*out.data<int64_t>() = static_cast<int64_t>(numel);
break;
case float16:
*out.data<float16_t>() = static_cast<float16_t>(numel);
break;
case float32:
*out.data<float>() = static_cast<float>(numel);
break;
case bfloat16:
*out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
break;
case complex64:
*out.data<complex64_t>() = static_cast<complex64_t>(numel);
break;
}
}
void Erf::eval(const std::vector<array>& inputs, array& out) { void Erf::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
@@ -294,7 +350,7 @@ void ErfInv::eval(const std::vector<array>& inputs, array& out) {
void Exp::eval(const std::vector<array>& inputs, array& out) { void Exp::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Exp()); unary_fp(in, out, detail::Exp());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@@ -303,10 +359,22 @@ void Exp::eval(const std::vector<array>& inputs, array& out) {
} }
} }
void Expm1::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Expm1());
} else {
throw std::invalid_argument(
"[expm1] Cannot exponentiate elements in array"
" with non floating point type.");
}
}
void Floor::eval(const std::vector<array>& inputs, array& out) { void Floor::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (not is_integral(in.dtype())) { if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Floor()); unary_fp(in, out, detail::Floor());
} else { } else {
// No-op integer types // No-op integer types
@@ -332,7 +400,7 @@ void Full::eval(const std::vector<array>& inputs, array& out) {
void Log::eval(const std::vector<array>& inputs, array& out) { void Log::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
switch (base_) { switch (base_) {
case Base::e: case Base::e:
unary_fp(in, out, detail::Log()); unary_fp(in, out, detail::Log());
@@ -354,7 +422,7 @@ void Log::eval(const std::vector<array>& inputs, array& out) {
void Log1p::eval(const std::vector<array>& inputs, array& out) { void Log1p::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Log1p()); unary_fp(in, out, detail::Log1p());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@@ -468,27 +536,80 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
} }
} }
void Reshape::eval(const std::vector<array>& inputs, array& out) { std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
assert(inputs.size() == 1); const array& in,
const auto& in = inputs[0]; const array& out) {
if (in.flags().row_contiguous) { // Special case for empty arrays or row contiguous arrays
if (in.size() == 0 || in.flags().row_contiguous) {
return {false, out.strides()};
}
// Special case for scalars
if (in.ndim() == 0) {
std::vector<size_t> out_strides(out.ndim(), 0);
return {false, out_strides};
}
// Firstly let's collapse all the contiguous dimensions of the input
auto [shape, _strides] = collapse_contiguous_dims(in);
auto& strides = _strides[0];
// If shapes fit exactly in the contiguous dims then no copy is necessary so
// let's check.
std::vector<size_t> out_strides;
bool copy_necessary = false;
int j = 0;
for (int i = 0; i < out.ndim(); i++) {
int N = out.shape(i);
if (j < shape.size() && shape[j] % N == 0) {
shape[j] /= N;
out_strides.push_back(shape[j] * strides[j]);
j += (shape[j] == 1);
} else if (N == 1) {
// i > 0 because otherwise j < shape.size() && shape[j] % 1 == 0
out_strides.push_back(out_strides.back());
} else {
copy_necessary = true;
break;
}
}
return {copy_necessary, out_strides};
}
void Reshape::shared_buffer_reshape(
const array& in,
const std::vector<size_t>& out_strides,
array& out) {
auto flags = in.flags();
if (flags.row_contiguous) {
// For row contiguous reshapes: // For row contiguous reshapes:
// - Shallow copy the buffer // - Shallow copy the buffer
// - If reshaping into a vector (all singleton dimensions except one) it // - If reshaping into a vector (all singleton dimensions except one) it
// becomes col contiguous again. // becomes col contiguous again.
auto flags = in.flags();
auto max_dim = std::max_element(out.shape().begin(), out.shape().end()); auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim; flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.copy_shared_buffer(in, out.strides(), flags, in.data_size()); }
} else { out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}
void Reshape::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : CopyType::General); copy(in, out, in.data_size() == 1 ? CopyType::Scalar : CopyType::General);
} else {
shared_buffer_reshape(in, out_strides, out);
} }
} }
void Round::eval(const std::vector<array>& inputs, array& out) { void Round::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (not is_integral(in.dtype())) { if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Round()); unary_fp(in, out, detail::Round());
} else { } else {
// No-op integer types // No-op integer types
@@ -499,7 +620,7 @@ void Round::eval(const std::vector<array>& inputs, array& out) {
void Sigmoid::eval(const std::vector<array>& inputs, array& out) { void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Sigmoid()); unary_fp(in, out, detail::Sigmoid());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@@ -521,7 +642,7 @@ void Sign::eval(const std::vector<array>& inputs, array& out) {
void Sin::eval(const std::vector<array>& inputs, array& out) { void Sin::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Sin()); unary_fp(in, out, detail::Sin());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@@ -533,7 +654,7 @@ void Sin::eval(const std::vector<array>& inputs, array& out) {
void Sinh::eval(const std::vector<array>& inputs, array& out) { void Sinh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Sinh()); unary_fp(in, out, detail::Sinh());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@@ -542,36 +663,33 @@ void Sinh::eval(const std::vector<array>& inputs, array& out) {
} }
} }
-void Slice::eval(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  if (out.size() == 0) {
-    out.set_data(nullptr);
-    return;
-  }
-  auto& in = inputs[0];
-  auto strides = in.strides();
-  auto flags = in.flags();
-  size_t data_offset = 0;
+std::tuple<bool, int64_t, std::vector<int64_t>> Slice::prepare_slice(
+    const array& in) {
+  int64_t data_offset = 0;
+  bool copy_needed = false;
+  std::vector<int64_t> inp_strides(in.ndim(), 0);
  for (int i = 0; i < in.ndim(); ++i) {
    data_offset += start_indices_[i] * in.strides()[i];
-    strides[i] *= strides_[i];
+    inp_strides[i] = in.strides()[i] * strides_[i];
+    copy_needed |= strides_[i] < 0;
  }
+  return std::make_tuple(copy_needed, data_offset, inp_strides);
+}
+void Slice::shared_buffer_slice(
+    const array& in,
+    const std::vector<size_t>& out_strides,
+    size_t data_offset,
+    array& out) {
  // Compute row/col contiguity
-  size_t data_size = 1;
-  size_t f_stride = 1;
-  size_t b_stride = 1;
-  flags.row_contiguous = true;
-  flags.col_contiguous = true;
-  for (int i = 0, ri = out.ndim() - 1; ri >= 0; i++, ri--) {
-    flags.col_contiguous &= strides[i] == f_stride || out.shape(i) == 1;
-    flags.row_contiguous &= strides[ri] == b_stride || out.shape(ri) == 1;
-    f_stride *= out.shape(i);
-    b_stride *= out.shape(ri);
-    if (strides[i] > 0) {
-      data_size *= out.shape(i);
-    }
-  }
+  auto [data_size, is_row_contiguous, is_col_contiguous] =
+      check_contiguity(out.shape(), out_strides);
+  auto flags = in.flags();
+  flags.row_contiguous = is_row_contiguous;
+  flags.col_contiguous = is_col_contiguous;
  if (data_size == 1) {
    // Broadcasted scalar array is contiguous.
@@ -585,7 +703,87 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
    flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
  }
-  out.copy_shared_buffer(in, strides, flags, data_size, data_offset);
+  out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
}
void Slice::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
// Calculate out strides, initial offset and if copy needs to be made
auto [copy_needed, data_offset, inp_strides] = prepare_slice(in);
// Do copy if needed
if (copy_needed) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
copy_inplace<int64_t>(
/* const array& src = */ in,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ out.shape(),
/* const std::vector<stride_t>& i_strides = */ inp_strides,
/* const std::vector<stride_t>& o_strides = */ ostrides,
/* int64_t i_offset = */ data_offset,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::General);
} else {
std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, out);
}
}
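As a rough illustration of the slice arithmetic above (a standalone sketch with made-up shapes, not part of the change set): slicing a row-contiguous 4x6 array starting at {1, 2} with steps {2, 2} shares the same buffer at offset 8 with strides {12, 2}.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> in_strides{6, 1}; // hypothetical row-contiguous 4x6 array
  std::vector<int64_t> starts{1, 2};     // slice start indices
  std::vector<int64_t> steps{2, 2};      // slice strides

  // Same arithmetic as prepare_slice: a base offset into the shared buffer
  // plus per-axis input strides scaled by the slice step.
  int64_t data_offset = 0;
  std::vector<int64_t> out_strides(2);
  bool copy_needed = false;
  for (int i = 0; i < 2; ++i) {
    data_offset += starts[i] * in_strides[i];
    out_strides[i] = in_strides[i] * steps[i];
    copy_needed |= steps[i] < 0; // negative steps force a copy
  }
  std::cout << data_offset << " {" << out_strides[0] << ", " << out_strides[1]
            << "} copy=" << copy_needed << "\n"; // prints: 8 {12, 2} copy=0
}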
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
const array& in) {
int64_t data_offset = 0;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices_[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides_[i];
}
return std::make_tuple(data_offset, inp_strides);
}
void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto& upd = inputs[1];
if (upd.size() == 0) {
out.copy_shared_buffer(in);
return;
}
// Check if materialization is needed
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, out_strides] = prepare_slice(out);
// Do copy
std::vector<int64_t> upd_strides{upd.strides().begin(), upd.strides().end()};
copy_inplace<int64_t>(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(),
/* const std::vector<stride_t>& i_strides = */ upd_strides,
/* const std::vector<stride_t>& o_strides = */ out_strides,
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ data_offset,
/* CopyType ctype = */ CopyType::GeneralGeneral);
}
void Split::eval(
@@ -664,7 +862,7 @@ void StopGradient::eval(const std::vector<array>& inputs, array& out) {
void Tan::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
-  if (is_floating_point(out.dtype())) {
+  if (issubdtype(out.dtype(), inexact)) {
    unary_fp(in, out, detail::Tan());
  } else {
    throw std::invalid_argument(
@@ -676,7 +874,7 @@ void Tan::eval(const std::vector<array>& inputs, array& out) {
void Tanh::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
-  if (is_floating_point(out.dtype())) {
+  if (issubdtype(out.dtype(), inexact)) {
    unary_fp(in, out, detail::Tanh());
  } else {
    throw std::invalid_argument(


@@ -6,8 +6,6 @@
namespace mlx::core {
-namespace {
enum ReductionOpType {
  // Self-explanatory. Read everything and produce 1 output.
  ContiguousAllReduce,
@@ -38,6 +36,21 @@ enum ReductionOpType {
  GeneralReduce
};
struct ReductionPlan {
ReductionOpType type;
std::vector<int> shape;
std::vector<size_t> strides;
ReductionPlan(
ReductionOpType type_,
std::vector<int> shape_,
std::vector<size_t> strides_)
: type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
ReductionPlan(ReductionOpType type_) : type(type_) {}
};
namespace {
// Helper for the ndimensional strided loop
// Should this be in utils?
inline void nd_loop(
@@ -110,19 +123,6 @@ struct DefaultContiguousReduce {
  }
};
struct ReductionPlan {
ReductionOpType type;
std::vector<int> shape;
std::vector<size_t> strides;
ReductionPlan(
ReductionOpType type_,
std::vector<int> shape_,
std::vector<size_t> strides_)
: type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
ReductionPlan(ReductionOpType type_) : type(type_) {}
};
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
  // The data is all there and we are reducing over everything
  if (x.size() == x.data_size() && axes.size() == x.ndim() &&


@@ -1,13 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/fast_primitives.h"
namespace mlx::core::fast {
void RoPE::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("NYI");
}
} // namespace mlx::core::fast


@@ -222,7 +222,7 @@ void scan_dispatch(
    }
    case Scan::Min: {
      auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *x : *y; };
-      auto init = (is_floating_point(input.dtype()))
+      auto init = (issubdtype(input.dtype(), floating))
          ? static_cast<U>(std::numeric_limits<float>::infinity())
          : std::numeric_limits<U>::max();
      auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
@@ -232,7 +232,7 @@
    }
    case Scan::Max: {
      auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
-      auto init = (is_floating_point(input.dtype()))
+      auto init = (issubdtype(input.dtype(), floating))
          ? static_cast<U>(-std::numeric_limits<float>::infinity())
          : std::numeric_limits<U>::max();
      auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);


@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <cmath>
@@ -10,7 +10,7 @@ namespace mlx::core {
namespace {
-template <typename T>
+template <typename T, typename AccT>
void softmax(const array& in, array& out) {
  const T* in_ptr = in.data<T>();
  T* out_ptr = out.data<T>();
@@ -22,26 +22,36 @@ void softmax(const array& in, array& out) {
  for (int i = 0; i < M; i++, in_ptr += N, out_ptr += N) {
    // Find the maximum
    current_in_ptr = in_ptr;
-    T maximum = *current_in_ptr;
+    AccT maximum = *current_in_ptr;
    for (int j = 0; j < N; j++, current_in_ptr++) {
-      maximum = (maximum < *current_in_ptr) ? *current_in_ptr : maximum;
+      maximum = (maximum < *current_in_ptr) ? static_cast<AccT>(*current_in_ptr)
+                                            : maximum;
    }
    // Compute the normalizer and the exponentials
-    T normalizer = 0;
+    AccT normalizer = 0;
    current_out_ptr = out_ptr;
    current_in_ptr = in_ptr;
    for (int j = 0; j < N; j++, current_out_ptr++, current_in_ptr++) {
-      T expv = std::exp(*current_in_ptr - maximum);
+      AccT expv = std::exp(*current_in_ptr - maximum);
      normalizer += expv;
-      *current_out_ptr = expv;
+      if constexpr (std::is_same<T, AccT>::value) {
+        *current_out_ptr = expv;
+      }
    }
    normalizer = 1 / normalizer;
    // Normalize
+    current_in_ptr = in_ptr;
    current_out_ptr = out_ptr;
    for (int j = 0; j < N; j++, current_out_ptr++) {
-      *current_out_ptr *= normalizer;
+      if constexpr (std::is_same<T, AccT>::value) {
+        *current_out_ptr *= normalizer;
+      } else {
+        auto v = std::exp(*current_in_ptr - maximum);
+        *current_out_ptr = static_cast<T>(v * normalizer);
+        current_in_ptr++;
+      }
    }
  }
}
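The separate accumulation type matters because a narrow accumulator silently drops small contributions once the running sum grows. The snippet below is only an analogy (float standing in for float16, double for float) and is not code from this change set.

#include <cstdio>

int main() {
  // At 2^24 the spacing between adjacent floats is 2, so adding 1.0f is lost;
  // float16 hits the analogous limit already at 2^11. Accumulating the softmax
  // normalizer in a wider type (AccT) avoids this kind of truncation.
  float narrow = 16777216.0f; // 2^24
  narrow += 1.0f;
  double wide = 16777216.0;
  wide += 1.0;
  std::printf("narrow=%.1f wide=%.1f\n", narrow, wide); // 16777216.0 vs 16777217.0
}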
@@ -67,11 +77,15 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
    }
  };
  array in = check_input(std::move(inputs[0]));
-  out.set_data(
-      allocator::malloc_or_wait(in.data_size() * in.itemsize()),
-      in.data_size(),
-      in.strides(),
-      in.flags());
+  if (in.is_donatable()) {
+    out.copy_shared_buffer(in);
+  } else {
+    out.set_data(
+        allocator::malloc_or_wait(in.data_size() * in.itemsize()),
+        in.data_size(),
+        in.strides(),
+        in.flags());
+  }
  switch (in.dtype()) {
    case bool_:
@@ -87,13 +101,21 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
          "Softmax is defined only for floating point types");
      break;
    case float32:
-      softmax<float>(in, out);
+      softmax<float, float>(in, out);
      break;
    case float16:
-      softmax<float16_t>(in, out);
+      if (precise_) {
+        softmax<float16_t, float>(in, out);
+      } else {
+        softmax<float16_t, float16_t>(in, out);
+      }
      break;
    case bfloat16:
-      softmax<bfloat16_t>(in, out);
+      if (precise_) {
+        softmax<bfloat16_t, float>(in, out);
+      } else {
+        softmax<bfloat16_t, bfloat16_t>(in, out);
+      }
      break;
    case complex64:
      throw std::invalid_argument(

mlx/backend/common/svd.cpp Normal file

@@ -0,0 +1,156 @@
// Copyright © 2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack_helper.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
namespace mlx::core {
void svd_impl(const array& a, array& u, array& s, array& vt) {
// Lapack uses the column-major convention. To avoid having to transpose
// the input and then transpose the outputs, we swap the indices/sizes of the
// matrices and take advantage of the following identity (see
// https://math.stackexchange.com/a/30077)
// A = UΣVᵀ
// Aᵀ = VΣUᵀ
// As a result some of the indices/sizes are swapped as noted above.
// Rows and cols of the original matrix in row-major order.
const int M = a.shape(-2);
const int N = a.shape(-1);
const int K = std::min(M, N);
// A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
const int lda = N;
// U of shape M x M. (N x N in lapack).
const int ldu = N;
// Vᵀ of shape N x N. (M x M in lapack).
const int ldvt = M;
size_t num_matrices = a.size() / (M * N);
// lapack clobbers the input, so we have to make a copy.
array in(a.shape(), float32, nullptr, {});
copy(a, in, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
// Allocate outputs.
u.set_data(allocator::malloc_or_wait(u.nbytes()));
s.set_data(allocator::malloc_or_wait(s.nbytes()));
vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
static constexpr auto job_u = "V";
static constexpr auto job_vt = "V";
static constexpr auto range = "A";
// Will contain the number of singular values after the call has returned.
int ns = 0;
float workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not used
// here but required by lapack).
auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)};
static const int lwork_query = -1;
static const int ignored_int = 0;
static const float ignored_float = 0;
int info;
// Compute workspace size.
MLX_LAPACK_FUNC(sgesvdx)
(
/* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_dimension;
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
// Loop over matrices.
for (int i = 0; i < num_matrices; i++) {
MLX_LAPACK_FUNC(sgesvdx)
(
/* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ in.data<float>() + M * N * i,
/* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ s.data<float>() + K * i,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ vt.data<float>() + N * N * i,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ u.data<float>() + M * M * i,
/* ldvt = */ &ldvt,
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
if (ns != K) {
std::stringstream ss;
ss << "svd_impl: expected " << K << " singular values, but " << ns
<< " were computed.";
throw std::runtime_error(ss.str());
}
}
}
void SVD::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
if (!(inputs[0].dtype() == float32)) {
throw std::runtime_error("[SVD::eval] only supports float32.");
}
svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
}
std::pair<std::vector<array>, std::vector<int>> SVD::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto ax = axes[0] >= 0 ? 0 : -1;
auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
return {{linalg::svd(a, stream())}, {ax, ax, ax}};
}
} // namespace mlx::core
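The pointer swapping in svd_impl follows from the transpose identity noted in the comment; spelled out (standard linear algebra, stated here only for reference):

A = U \Sigma V^{\top} \quad\Longrightarrow\quad A^{\top} = (U \Sigma V^{\top})^{\top} = V \Sigma^{\top} U^{\top}

Since \Sigma is (rectangular) diagonal, \Sigma^{\top} carries the same singular values. The factorization LAPACK computes for the column-major view (our A^{\top}) therefore returns our V in its "u" slot and our U^{\top} in its "vt" slot, which is why the u and vt output pointers are exchanged in the sgesvdx call.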


@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
#pragma once
@@ -8,11 +8,12 @@
namespace mlx::core {
-inline size_t elem_to_loc(
+template <typename stride_t>
+inline stride_t elem_to_loc(
    int elem,
    const std::vector<int>& shape,
-    const std::vector<size_t>& strides) {
-  size_t loc = 0;
+    const std::vector<stride_t>& strides) {
+  stride_t loc = 0;
  for (int i = shape.size() - 1; i >= 0; --i) {
    auto q_and_r = ldiv(elem, shape[i]);
    loc += q_and_r.rem * strides[i];
@@ -28,4 +29,93 @@ inline size_t elem_to_loc(int elem, const array& a) {
  return elem_to_loc(elem, a.shape(), a.strides());
}
// Collapse dims that are contiguous to possibly route to a better kernel
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
// should return {{2, 4}, {{1, 2}}}.
//
// When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned.
template <typename stride_t>
inline std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<stride_t>> strides) {
// Make a vector that has axes separated with -1. Collapse all axes between
// -1.
std::vector<int> to_collapse;
if (shape.size() > 0) {
to_collapse.push_back(0);
for (int i = 1; i < shape.size(); i++) {
bool contiguous = true;
for (const std::vector<stride_t>& st : strides) {
if (st[i] * shape[i] != st[i - 1]) {
contiguous = false;
}
if (!contiguous) {
break;
}
}
if (!contiguous) {
to_collapse.push_back(-1);
}
to_collapse.push_back(i);
}
to_collapse.push_back(-1);
}
std::vector<int> out_shape;
std::vector<std::vector<stride_t>> out_strides(strides.size());
for (int i = 0; i < to_collapse.size(); i++) {
int current_shape = shape[to_collapse[i]];
while (to_collapse[++i] != -1) {
current_shape *= shape[to_collapse[i]];
}
out_shape.push_back(current_shape);
for (int j = 0; j < strides.size(); j++) {
const std::vector<stride_t>& st = strides[j];
out_strides[j].push_back(st[to_collapse[i - 1]]);
}
}
return std::make_tuple(out_shape, out_strides);
}
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(const std::vector<array>& xs) {
std::vector<std::vector<size_t>> strides;
for (auto& x : xs) {
strides.emplace_back(x.strides());
}
return collapse_contiguous_dims(xs[0].shape(), strides);
}
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
inline auto collapse_contiguous_dims(Arrays&&... xs) {
return collapse_contiguous_dims(
std::vector<array>{std::forward<Arrays>(xs)...});
}
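A small standalone restatement of the collapsing rule above, using the shapes from the comment (illustrative only; it handles just a single strides vector):

#include <cstdio>
#include <vector>

// Fold axis i into axis i - 1 whenever strides[i] * shape[i] == strides[i - 1],
// which is the same condition used by collapse_contiguous_dims above.
int main() {
  std::vector<int> shape{2, 2, 2};
  std::vector<size_t> strides{1, 4, 2}; // layout of transpose(.., {2, 0, 1})

  std::vector<int> out_shape{shape[0]};
  std::vector<size_t> out_strides{strides[0]};
  for (size_t i = 1; i < shape.size(); ++i) {
    if (strides[i] * shape[i] == strides[i - 1]) {
      out_shape.back() *= shape[i];     // merge into the previous axis
      out_strides.back() = strides[i];  // keep the innermost stride
    } else {
      out_shape.push_back(shape[i]);
      out_strides.push_back(strides[i]);
    }
  }
  std::printf("shape {%d, %d} strides {%zu, %zu}\n",
              out_shape[0], out_shape[1], out_strides[0], out_strides[1]);
  // Prints: shape {2, 4} strides {1, 2}
}

Axis 2 folds into axis 1 because 2 * 2 == 4, while axis 0 stays separate, reproducing the {{2, 4}, {{1, 2}}} result quoted in the comment.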
template <typename stride_t>
inline auto check_contiguity(
const std::vector<int>& shape,
const std::vector<stride_t>& strides) {
size_t data_size = 1;
size_t f_stride = 1;
size_t b_stride = 1;
bool is_row_contiguous = true;
bool is_col_contiguous = true;
for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
is_row_contiguous &= strides[i] == f_stride || shape[i] == 1;
is_col_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
f_stride *= shape[i];
b_stride *= shape[ri];
if (strides[i] > 0) {
data_size *= shape[i];
}
}
return std::make_tuple(data_size, is_row_contiguous, is_col_contiguous);
}
} // namespace mlx::core


@@ -26,6 +26,7 @@ target_sources(
  ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
@@ -33,6 +34,7 @@ target_sources(
  ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp


@@ -1,6 +1,7 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/allocator.h" #include "mlx/backend/metal/allocator.h"
#include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
#include <mach/vm_page_size.h> #include <mach/vm_page_size.h>
#include <unistd.h> #include <unistd.h>
@@ -164,6 +165,15 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
return Buffer{nullptr}; return Buffer{nullptr};
} }
// More helpful message if maximum buffer length is exceeded
if (size > device_->maxBufferLength()) {
std::ostringstream msg;
msg << "Attempting to allocate " << size << " bytes which is greater than"
<< " the maximum allowed buffer size of " << device_->maxBufferLength()
<< " bytes.";
throw std::runtime_error(msg.str());
}
// Align up memory // Align up memory
if (size > vm_page_size) { if (size > vm_page_size) {
size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size); size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size);
@@ -208,6 +218,11 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
return Buffer{static_cast<void*>(buf)}; return Buffer{static_cast<void*>(buf)};
} }
void MetalAllocator::clear_cache() {
std::unique_lock lk(mutex_);
buffer_cache_.clear();
}
void MetalAllocator::free(Buffer buffer) { void MetalAllocator::free(Buffer buffer) {
auto buf = static_cast<MTL::Buffer*>(buffer.ptr()); auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
std::unique_lock lk(mutex_); std::unique_lock lk(mutex_);
@@ -241,6 +256,9 @@ size_t get_peak_memory() {
size_t get_cache_memory() { size_t get_cache_memory() {
return allocator().get_cache_memory(); return allocator().get_cache_memory();
} }
void clear_cache() {
return allocator().clear_cache();
}
} // namespace metal


@@ -26,6 +26,7 @@ class BufferCache {
size_t cache_size() { size_t cache_size() {
return pool_size_; return pool_size_;
} }
void clear();
private: private:
struct BufferHolder { struct BufferHolder {
@@ -37,7 +38,6 @@ class BufferCache {
MTL::Buffer* buf; MTL::Buffer* buf;
}; };
void clear();
void add_at_head(BufferHolder* to_add); void add_at_head(BufferHolder* to_add);
void remove_from_list(BufferHolder* to_remove); void remove_from_list(BufferHolder* to_remove);
@@ -67,6 +67,7 @@ class MetalAllocator : public allocator::Allocator {
}; };
size_t set_cache_limit(size_t limit); size_t set_cache_limit(size_t limit);
size_t set_memory_limit(size_t limit, bool relaxed); size_t set_memory_limit(size_t limit, bool relaxed);
void clear_cache();
private: private:
MTL::Device* device_; MTL::Device* device_;


@@ -3,6 +3,7 @@
#include <sstream> #include <sstream>
#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/compiled_preamble.h" #include "mlx/backend/metal/compiled_preamble.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"
@@ -228,14 +229,7 @@ void Compiled::eval_gpu(
// Figure out which kernel we are using // Figure out which kernel we are using
auto& output_shape = outputs[0].shape(); auto& output_shape = outputs[0].shape();
bool contiguous = true; bool contiguous = compiled_check_contiguity(inputs, output_shape);
for (auto& x : inputs) {
if ((!x.flags().row_contiguous || x.shape() != output_shape) &&
!is_scalar(x)) {
contiguous = false;
break;
}
}
// Collapse contiguous dims to route to a faster kernel if possible. Also // Collapse contiguous dims to route to a faster kernel if possible. Also
// handle all broadcasting. // handle all broadcasting.
@@ -295,7 +289,7 @@ void Compiled::eval_gpu(
} }
} }
auto kernel = d.get_kernel(kernel_name, lib); auto kernel = d.get_kernel(kernel_name, lib);
auto compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
// Put the inputs in // Put the inputs in
@@ -306,7 +300,7 @@ void Compiled::eval_gpu(
continue; continue;
} }
auto& x = inputs[i]; auto& x = inputs[i];
set_array_buffer(compute_encoder, x, cnt++); compute_encoder.set_input_array(x, cnt++);
if (!contiguous && !is_scalar(x)) { if (!contiguous && !is_scalar(x)) {
compute_encoder->setBytes( compute_encoder->setBytes(
strides[stride_idx].data(), strides[stride_idx].data(),
@@ -316,30 +310,12 @@ void Compiled::eval_gpu(
} }
} }
-  // Allocate space for the outputs possibly with input donation
-  {
-    int o = 0;
-    for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
-      auto& in = inputs[i];
-      // Conditions for donation
-      // - Row contiguous
-      // - Donatable
-      // - Correct size
-      // - Not a constant
-      if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
-          in.is_donatable() &&
-          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
-        outputs[o++].move_shared_buffer(in);
-      }
-    }
-    for (; o < outputs.size(); ++o) {
-      outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
-    }
-  }
+  compiled_allocate_outputs(
+      inputs, outputs, inputs_, constant_ids_, contiguous, true);
  // Put the outputs in
  for (auto& x : outputs) {
-    set_array_buffer(compute_encoder, x, cnt++);
+    compute_encoder.set_output_array(x, cnt++);
} }
// Put the output shape and strides in // Put the output shape and strides in


@@ -28,10 +28,12 @@ void explicit_gemm_conv_ND_gpu(
const array& wt, const array& wt,
array out, array out,
const MLXConvParams<N>& conv_params) { const MLXConvParams<N>& conv_params) {
// Get gemm shapes
int implicit_M = out.size() / conv_params.O;
int implicit_K = wt.size() / conv_params.O;
int implicit_N = conv_params.O;
// Prepare unfolding array // Prepare unfolding array
std::vector<int> unfolded_shape = { std::vector<int> unfolded_shape{implicit_M, implicit_K};
static_cast<int>(out.size() / conv_params.O),
static_cast<int>(wt.size() / conv_params.O)};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {}); array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes())); in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
@@ -39,12 +41,12 @@ void explicit_gemm_conv_ND_gpu(
// Prepare unfolding kernel // Prepare unfolding kernel
std::ostringstream kname; std::ostringstream kname;
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N; kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
auto compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0); compute_encoder.set_input_array(in, 0);
set_array_buffer(compute_encoder, in_unfolded, 1); compute_encoder.set_output_array(in_unfolded, 1);
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2); compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
@@ -59,25 +61,118 @@ void explicit_gemm_conv_ND_gpu(
compute_encoder->dispatchThreads(grid_dims, group_dims); compute_encoder->dispatchThreads(grid_dims, group_dims);
// Reshape weight
std::vector<int> wt_reshape{implicit_K, implicit_N};
std::vector<size_t> wt_restride{1, static_cast<size_t>(implicit_K)};
array wt_reshaped(wt_reshape, wt.dtype(), nullptr, {});
auto wt_flags = wt.flags();
wt_flags.row_contiguous = false;
wt_flags.col_contiguous = true;
wt_reshaped.copy_shared_buffer(wt, wt_restride, wt_flags, wt.data_size());
// Perform gemm // Perform gemm
std::vector<array> copies; std::vector<array> copies = {in_unfolded, wt_reshaped};
return steel_matmul( return steel_matmul(
s, s,
d, d,
/*a = */ in_unfolded, /*a = */ in_unfolded,
/*b = */ wt, /*b = */ wt_reshaped,
/*c = */ out, /*c = */ out,
/*M = */ unfolded_shape[0], /*M = */ implicit_M,
/*N = */ conv_params.O, /*N = */ implicit_N,
/*K = */ unfolded_shape[1], /*K = */ implicit_K,
/*batch_size_out = */ 1, /*batch_size_out = */ 1,
/*a_cols = */ unfolded_shape[1], /*a_cols = */ implicit_K,
/*b_cols = */ unfolded_shape[1], /*b_cols = */ implicit_K,
/*a_transposed = */ false, /*a_transposed = */ false,
/*b_transposed = */ true, /*b_transposed = */ true,
/*copies = */ copies); /*copies = */ copies);
} }
template <int N>
void explicit_gemm_conv_group_ND_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<N>& conv_params) {
const int groups = conv_params.groups;
const int C_per_group = conv_params.C / conv_params.groups;
const int O_per_group = conv_params.O / conv_params.groups;
// Get gemm shapes
const int implicit_M = out.size() / conv_params.O;
const int implicit_K = wt.size() / conv_params.O;
const int implicit_N = O_per_group;
int kernel_size = 1;
for (int i = 0; i < N; ++i) {
kernel_size *= conv_params.wS[i];
}
// Prepare unfolding array
std::vector<int> unfolded_shape{implicit_M, implicit_K * groups};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
// Prepare unfolding kernel
std::ostringstream kname;
kname << "naive_unfold_transpose_nd_" << type_to_name(in_unfolded) << "_"
<< N;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(in_unfolded, 1);
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
// Launch unfolding kernel
int tgp_x = std::min(conv_params.C, 64);
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
int tgp_y = 256 / tgp_x;
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
compute_encoder->dispatchThreads(grid_dims, group_dims);
// Transpose kernel weights so that we can slice them by contiguous chunks
// of channel groups.
array wt_view(
{wt.shape(0), C_per_group, kernel_size}, wt.dtype(), nullptr, {});
wt_view.copy_shared_buffer(
wt,
{wt.strides(0), 1, static_cast<size_t>(C_per_group)},
wt.flags(),
wt.size());
// Materialize
auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {});
copy_gpu(wt_view, wt_transpose, CopyType::General, s);
// Perform gemm
std::vector<array> copies = {in_unfolded, wt_view, wt_transpose};
return steel_matmul_conv_groups(
s,
d,
/*a = */ in_unfolded,
/*b = */ wt_transpose,
/*c = */ out,
/*M = */ implicit_M,
/*N = */ implicit_N,
/*K = */ implicit_K,
/*a_cols = */ implicit_K * groups,
/*b_cols = */ implicit_K,
/*out_cols = */ implicit_N * groups,
/*a_transposed = */ false,
/*b_transposed = */ true,
/* groups = */ groups,
/*copies = */ copies);
}
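To see how the implicit GEMM sizes fall out, take a hypothetical grouped 1D convolution (numbers invented for illustration): batch 8, output length 98, C = 16 input channels, O = 32 filters, kernel width 3, groups = 4, so the weights have shape {32, 3, 4}.

  C_per_group = 16 / 4 = 4,  O_per_group = 32 / 4 = 8
  implicit_M  = out.size() / O = (8 * 98 * 32) / 32 = 784   (one row per output location)
  implicit_K  = wt.size() / O  = (32 * 3 * 4) / 32  = 12    (kernel_size * C_per_group)
  implicit_N  = O_per_group    = 8

Each group then multiplies a 784 x 12 block of the unfolded input (which holds implicit_K * groups = 48 columns in total) by a 12 x 8 slice of the transposed weights.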
void conv_1D_gpu( void conv_1D_gpu(
const Stream& s, const Stream& s,
metal::Device& d, metal::Device& d,
@@ -88,6 +183,7 @@ void conv_1D_gpu(
const std::vector<int>& wt_strides, const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation, const std::vector<int>& in_dilation,
int groups,
bool flip) { bool flip) {
// Make conv params // Make conv params
MLXConvParams<1> conv_params{ MLXConvParams<1> conv_params{
@@ -107,11 +203,15 @@ void conv_1D_gpu(
{wt.strides()[0], wt.strides()[1], wt.strides()[2]}, {wt.strides()[0], wt.strides()[1], wt.strides()[2]},
/* const size_t out_strides[NDIM + 2] = */ /* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2]}, {out.strides()[0], out.strides()[1], out.strides()[2]},
/* const int groups = */ 1, /* const int groups = */ groups,
/* const bool flip = */ flip}; /* const bool flip = */ flip};
// Direct to explicit gemm conv // Direct to explicit gemm conv
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params); if (groups > 1) {
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
} else {
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
} }
void slow_conv_2D_gpu( void slow_conv_2D_gpu(
@@ -129,7 +229,7 @@ void slow_conv_2D_gpu(
<< "_tm" << tm << "_tn" << tn; << "_tm" << tm << "_tn" << tn;
// Encode and dispatch kernel // Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
@@ -142,9 +242,9 @@ void slow_conv_2D_gpu(
MTL::Size group_dims = MTL::Size(bm, bn, 1); MTL::Size group_dims = MTL::Size(bm, bn, 1);
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z); MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
set_array_buffer(compute_encoder, in, 0); compute_encoder.set_input_array(in, 0);
set_array_buffer(compute_encoder, wt, 1); compute_encoder.set_input_array(wt, 1);
set_array_buffer(compute_encoder, out, 2); compute_encoder.set_output_array(out, 2);
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3); compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims); compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
@@ -230,7 +330,7 @@ void implicit_gemm_conv_2D_gpu(
<< "_filter_" << (small_filter ? 's' : 'l'); << "_filter_" << (small_filter ? 's' : 'l');
// Encode and dispatch kernel // Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
@@ -243,9 +343,9 @@ void implicit_gemm_conv_2D_gpu(
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1); MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);
// Encode arrays // Encode arrays
set_array_buffer(compute_encoder, in, 0); compute_encoder.set_input_array(in, 0);
set_array_buffer(compute_encoder, wt, 1); compute_encoder.set_input_array(wt, 1);
set_array_buffer(compute_encoder, out, 2); compute_encoder.set_output_array(out, 2);
// Encode params // Encode params
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3); compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
@@ -383,7 +483,7 @@ void implicit_gemm_conv_2D_general_gpu(
<< "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn; << "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
// Encode and dispatch kernel // Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
@@ -397,9 +497,9 @@ void implicit_gemm_conv_2D_general_gpu(
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z); MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
// Encode arrays // Encode arrays
set_array_buffer(compute_encoder, in, 0); compute_encoder.set_input_array(in, 0);
set_array_buffer(compute_encoder, wt, 1); compute_encoder.set_input_array(wt, 1);
set_array_buffer(compute_encoder, out, 2); compute_encoder.set_output_array(out, 2);
// Encode params // Encode params
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3); compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
@@ -500,12 +600,12 @@ void winograd_conv_2D_gpu(
std::ostringstream kname; std::ostringstream kname;
kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc" kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
<< bc; << bc;
auto compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, wt, 0); compute_encoder.set_input_array(wt, 0);
set_array_buffer(compute_encoder, filt_wg, 1); compute_encoder.set_output_array(filt_wg, 1);
compute_encoder->setBytes(&C_c, sizeof(int), 2); compute_encoder->setBytes(&C_c, sizeof(int), 2);
compute_encoder->setBytes(&O_c, sizeof(int), 3); compute_encoder->setBytes(&O_c, sizeof(int), 3);
@@ -528,12 +628,12 @@ void winograd_conv_2D_gpu(
std::ostringstream kname; std::ostringstream kname;
kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc" kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc"
<< bc; << bc;
auto compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in_padded, 0); compute_encoder.set_input_array(in_padded, 0);
set_array_buffer(compute_encoder, inp_wg, 1); compute_encoder.set_output_array(inp_wg, 1);
compute_encoder->setBytes( compute_encoder->setBytes(
&conv_params_updated, sizeof(MLXConvParams<2>), 2); &conv_params_updated, sizeof(MLXConvParams<2>), 2);
@@ -576,12 +676,12 @@ void winograd_conv_2D_gpu(
std::ostringstream kname; std::ostringstream kname;
kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo" kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo"
<< bc; << bc;
auto compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, out_wg, 0); compute_encoder.set_input_array(out_wg, 0);
set_array_buffer(compute_encoder, out, 1); compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes( compute_encoder->setBytes(
&conv_params_updated, sizeof(MLXConvParams<2>), 2); &conv_params_updated, sizeof(MLXConvParams<2>), 2);
@@ -710,6 +810,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel_strides_, kernel_strides_,
kernel_dilation_, kernel_dilation_,
input_dilation_, input_dilation_,
groups_,
flip_); flip_);
} }
// Throw error // Throw error


@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <sstream> #include <sstream>
@@ -12,8 +12,15 @@ namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
if (ctype == CopyType::Vector) { if (ctype == CopyType::Vector) {
// If the input is donateable, we are doing a vector copy and the types
// have the same size, then the input buffer can hold the output.
if (in.is_donatable() && in.itemsize() == out.itemsize()) { if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.move_shared_buffer(in); out.move_shared_buffer(in);
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
if (in.dtype() == out.dtype()) {
return;
}
} else { } else {
out.set_data( out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()), allocator::malloc_or_wait(in.data_size() * out.itemsize()),
@@ -37,15 +44,22 @@ void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream()); copy_gpu(in, out, ctype, out.primitive().stream());
} }
template <typename stride_t>
void copy_gpu_inplace(
    const array& in,
    array& out,
+    const std::vector<int>& data_shape,
+    const std::vector<stride_t>& strides_in_pre,
+    const std::vector<stride_t>& strides_out_pre,
+    int64_t inp_offset,
+    int64_t out_offset,
    CopyType ctype,
    const Stream& s) {
  // Try to collapse contiguous dims
-  auto [shape, strides] = collapse_contiguous_dims(in, out);
-  auto& strides_in = strides[0];
-  auto& strides_out = strides[1];
+  auto [shape, strides] = collapse_contiguous_dims(
+      data_shape, std::vector{strides_in_pre, strides_out_pre});
+  auto& strides_in_ = strides[0];
+  auto& strides_out_ = strides[1];
  auto& d = metal::device(s.device);
  std::ostringstream kname;
@@ -69,42 +83,47 @@ void copy_gpu_inplace(
    kname << "_" << shape.size();
  }
  auto kernel = d.get_kernel(kname.str());
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);
  bool donate_in = in.data_shared_ptr() == nullptr;
-  set_array_buffer(compute_encoder, donate_in ? out : in, 0);
-  set_array_buffer(compute_encoder, out, 1);
+  inp_offset *= size_of(in.dtype());
+  out_offset *= size_of(out.dtype());
+  compute_encoder.set_input_array(donate_in ? out : in, 0, inp_offset);
+  compute_encoder.set_output_array(out, 1, out_offset);
  if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
-    size_t ndim = shape.size();
+    int ndim = shape.size();
+    std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
+    std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
    if (ndim > 3) {
-      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
-      compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 3);
-      if (ctype == CopyType::GeneralGeneral) {
-        compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 4);
-      }
-    } else {
-      // The shape is implicit in the grid for <= 3D
-      compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 2);
-      if (ctype == CopyType::GeneralGeneral) {
-        compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 3);
-      }
+      set_vector_bytes(compute_encoder, shape, ndim, 2);
+    }
+    set_vector_bytes(compute_encoder, strides_in, ndim, 3);
+    if (ctype == CopyType::GeneralGeneral) {
+      set_vector_bytes(compute_encoder, strides_out, ndim, 4);
    }
    if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
-      compute_encoder->setBytes(
-          &ndim, sizeof(int), (ctype == CopyType::GeneralGeneral) ? 5 : 4);
+      compute_encoder->setBytes(&ndim, sizeof(int), 5);
    }
    int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
    int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
-    int rest = in.size() / (dim0 * dim1);
+    size_t data_size = 1;
+    for (auto& s : shape)
+      data_size *= s;
+    int rest = data_size / (dim0 * dim1);
    // NB assuming thread_group_size is a power of 2 larger than 32 x 32
    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    if (thread_group_size != 1024) {
      throw std::runtime_error("[Metal::copy] Must use 1024 sized block");
    }
    auto group_dims = get_block_dims(dim0, dim1, rest);
    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
@@ -120,4 +139,25 @@ void copy_gpu_inplace(
  }
}
void copy_gpu_inplace(
const array& in,
array& out,
CopyType ctype,
const Stream& s) {
return copy_gpu_inplace(
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
}
void copy_gpu_inplace(
const array& in,
array& out,
const std::vector<int64_t>& istride,
int64_t ioffset,
CopyType ctype,
const Stream& s) {
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
return copy_gpu_inplace(
in, out, in.shape(), istride, ostrides, ioffset, 0, ctype, s);
}
} // namespace mlx::core


@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#pragma once #pragma once
@@ -7,12 +7,34 @@
namespace mlx::core { namespace mlx::core {
// Generic copy inplace
template <typename stride_t>
void copy_gpu_inplace(
const array& in,
array& out,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype,
const Stream& s);
void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s); void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s);
void copy_gpu(const array& src, array& out, CopyType ctype); void copy_gpu(const array& src, array& out, CopyType ctype);
void copy_gpu_inplace( void copy_gpu_inplace(
const array& src, const array& src,
array& out, array& out,
CopyType ctype, CopyType ctype,
const Stream& s); const Stream& s);
void copy_gpu_inplace(
const array& in,
array& out,
const std::vector<int64_t>& istride,
int64_t ioffset,
CopyType ctype,
const Stream& s);
} // namespace mlx::core


@@ -1,4 +1,4 @@
// Copyright © 2023-24 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <dlfcn.h> #include <dlfcn.h>
#include <cstdlib> #include <cstdlib>
@@ -11,7 +11,9 @@
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/mps/gemm.h" #include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h"
namespace fs = std::filesystem; namespace fs = std::filesystem;
@@ -20,9 +22,9 @@ namespace mlx::core::metal {
namespace { namespace {
// TODO nicer way to set this or possibly expose as an environment variable // TODO nicer way to set this or possibly expose as an environment variable
static constexpr int MAX_BUFFERS_PER_QUEUE = 12; constexpr int MAX_BUFFERS_PER_QUEUE = 12;
static constexpr const char* default_mtllib_path = METAL_PATH; constexpr const char* default_mtllib_path = METAL_PATH;
auto load_device() { auto load_device() {
auto devices = MTL::CopyAllDevices(); auto devices = MTL::CopyAllDevices();
@@ -127,7 +129,7 @@ Device::~Device() {
b.second.second->release(); b.second.second->release();
} }
for (auto& e : encoder_map_) { for (auto& e : encoder_map_) {
e.second->release(); (*e.second)->release();
} }
for (auto& k : kernel_map_) { for (auto& k : kernel_map_) {
k.second->release(); k.second->release();
@@ -145,6 +147,7 @@ void Device::new_queue(int index) {
// We lock this as a critical section for safety // We lock this as a critical section for safety
const std::lock_guard<std::mutex> lock(mtx_); const std::lock_guard<std::mutex> lock(mtx_);
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE); auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
debug_set_stream_queue_label(q, index);
if (!q) { if (!q) {
throw std::runtime_error( throw std::runtime_error(
"[metal::Device] Failed to make new command queue."); "[metal::Device] Failed to make new command queue.");
@@ -197,22 +200,25 @@ void Device::commit_command_buffer(int index) {
void Device::end_encoding(int index) {
  auto eit = encoder_map_.find(index);
  if (eit != encoder_map_.end()) {
-    eit->second->endEncoding();
-    eit->second->release();
+    (*eit->second)->endEncoding();
+    (*eit->second)->release();
    encoder_map_.erase(eit);
  }
}
-MTL::ComputeCommandEncoder* Device::get_command_encoder(int index) {
+CommandEncoder& Device::get_command_encoder(int index) {
  auto eit = encoder_map_.find(index);
  if (eit == encoder_map_.end()) {
    auto cb = get_command_buffer(index);
-    auto compute_encoder = cb->computeCommandEncoder();
+    auto compute_encoder =
+        cb->computeCommandEncoder(MTL::DispatchTypeConcurrent);
    // Increment ref count so the buffer is not garbage collected
    compute_encoder->retain();
-    eit = encoder_map_.insert({index, compute_encoder}).first;
+    eit = encoder_map_
+              .emplace(index, std::make_unique<CommandEncoder>(compute_encoder))
+              .first;
  }
-  return eit->second;
+  return *(eit->second);
}
void Device::register_library(
@@ -259,8 +265,7 @@ MTL::Library* Device::get_library_(const std::string& source_string) {
// Throw error if unable to compile library // Throw error if unable to compile library
if (!mtl_lib) { if (!mtl_lib) {
std::ostringstream msg; std::ostringstream msg;
msg << "[metal::Device] Unable to load build metal library from source" msg << "[metal::Device] Unable to build metal library from source" << "\n";
<< "\n";
if (error) { if (error) {
msg << error->localizedDescription()->utf8String() << "\n"; msg << error->localizedDescription()->utf8String() << "\n";
} }
@@ -279,8 +284,7 @@ MTL::Library* Device::get_library_(const MTL::StitchedLibraryDescriptor* desc) {
// Throw error if unable to compile library // Throw error if unable to compile library
if (!mtl_lib) { if (!mtl_lib) {
std::ostringstream msg; std::ostringstream msg;
msg << "[metal::Device] Unable to load build stitched metal library" msg << "[metal::Device] Unable to build stitched metal library" << "\n";
<< "\n";
if (error) { if (error) {
msg << error->localizedDescription()->utf8String() << "\n"; msg << error->localizedDescription()->utf8String() << "\n";
} }
@@ -538,11 +542,12 @@ Device& device(mlx::core::Device) {
return metal_device; return metal_device;
} }
std::shared_ptr<void> new_scoped_memory_pool() { std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool() {
auto dtor = [](void* ptr) { auto dtor = [](void* ptr) {
static_cast<NS::AutoreleasePool*>(ptr)->release(); static_cast<NS::AutoreleasePool*>(ptr)->release();
}; };
return std::shared_ptr<void>(NS::AutoreleasePool::alloc()->init(), dtor); return std::unique_ptr<void, std::function<void(void*)>>(
NS::AutoreleasePool::alloc()->init(), dtor);
} }
void new_stream(Stream stream) { void new_stream(Stream stream) {
@@ -551,4 +556,15 @@ void new_stream(Stream stream) {
} }
} }
std::unordered_map<std::string, std::variant<std::string, size_t>>
device_info() {
auto raw_device = device(default_device()).mtl_device();
auto arch = std::string(raw_device->architecture()->name()->utf8String());
return {
{"architecture", arch},
{"max_buffer_length", raw_device->maxBufferLength()},
{"max_recommended_working_set_size",
raw_device->recommendedMaxWorkingSetSize()}};
}
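A possible caller-side sketch for the new device_info() (my own example, not from this change set), showing how the string/size_t variant values can be read:

#include <iostream>
#include <string>
#include <variant>

#include "mlx/backend/metal/metal.h" // assumed to declare metal::device_info()

void print_device_info() {
  for (auto& [key, value] : mlx::core::metal::device_info()) {
    std::cout << key << ": ";
    if (auto* s = std::get_if<std::string>(&value)) {
      std::cout << *s << "\n";
    } else {
      std::cout << std::get<size_t>(value) << "\n";
    }
  }
}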
} // namespace mlx::core::metal


@@ -1,4 +1,4 @@
// Copyright © 2023-24 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#pragma once #pragma once
@@ -7,10 +7,12 @@
#include <mutex> #include <mutex>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <dlfcn.h> #include <dlfcn.h>
#include <filesystem> #include <filesystem>
#include "mlx/array.h"
#include "mlx/device.h" #include "mlx/device.h"
namespace fs = std::filesystem; namespace fs = std::filesystem;
@@ -34,6 +36,70 @@ inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
using MTLFCList = using MTLFCList =
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>; std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
struct CommandEncoder {
CommandEncoder(MTL::ComputeCommandEncoder* enc)
: enc(enc), concurrent(false) {};
CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;
struct ConcurrentContext {
ConcurrentContext(CommandEncoder& enc) : enc(enc) {
enc.concurrent = true;
}
~ConcurrentContext() {
enc.concurrent = false;
enc.outputs.insert(
enc.concurrent_outputs.begin(), enc.concurrent_outputs.end());
enc.concurrent_outputs.clear();
}
private:
CommandEncoder& enc;
};
MTL::ComputeCommandEncoder* operator->() {
return enc;
}
void set_input_array(const array& a, int idx, int offset = 0) {
auto r_buf =
static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
if (auto it = outputs.find(r_buf); it != outputs.end()) {
// Insert a barrier
enc->memoryBarrier(&r_buf, 1);
// Remove the output
outputs.erase(it);
}
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto base_offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
base_offset += offset;
enc->setBuffer(a_buf, base_offset, idx);
}
void set_output_array(array& a, int idx, int offset = 0) {
// Add barriers before adding the output to the output set
set_input_array(a, idx, offset);
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
if (concurrent) {
concurrent_outputs.insert(buf);
} else {
outputs.insert(buf);
}
}
ConcurrentContext start_concurrent() {
return ConcurrentContext(*this);
}
private:
MTL::ComputeCommandEncoder* enc;
bool concurrent;
std::unordered_set<MTL::Resource*> outputs;
std::unordered_set<MTL::Resource*> concurrent_outputs;
};
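A hypothetical dispatch sketch (not part of the diff) showing the intended usage of CommandEncoder: buffers registered through set_output_array are tracked so that a later set_input_array on the same resource inserts a memory barrier before the read.

// All names of the surrounding objects (device, arrays, kernel) are assumed.
void dispatch_example(
    mlx::core::metal::Device& d,
    const mlx::core::array& a,
    mlx::core::array& b,
    int stream_index,
    MTL::ComputePipelineState* kernel,
    MTL::Size grid_dims,
    MTL::Size group_dims) {
  auto& enc = d.get_command_encoder(stream_index);
  enc->setComputePipelineState(kernel); // operator-> forwards to the MTL encoder
  enc.set_input_array(a, 0);  // adds a barrier if `a` was written by a prior kernel
  enc.set_output_array(b, 1); // records `b` for hazard tracking
  enc->dispatchThreads(grid_dims, group_dims);
}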
class Device {
 public:
  Device();
@@ -51,7 +117,7 @@ class Device {
  int get_command_buffer_ops(int index);
  void increment_command_buffer_ops(int index);
  void commit_command_buffer(int index);
-  MTL::ComputeCommandEncoder* get_command_encoder(int index);
+  CommandEncoder& get_command_encoder(int index);
  void end_encoding(int index);
  void register_library(
@@ -132,7 +198,7 @@ class Device {
  MTL::Device* device_;
  std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
  std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
-  std::unordered_map<int32_t, MTL::ComputeCommandEncoder*> encoder_map_;
+  std::unordered_map<int32_t, std::unique_ptr<CommandEncoder>> encoder_map_;
  std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
  std::unordered_map<std::string, MTL::Library*> library_map_;
  std::mutex mtx_;


@@ -0,0 +1,30 @@
// Copyright © 2024 Apple Inc.
#include "mlx/event.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal_impl.h"
namespace mlx::core {
Event::Event(const Stream& stream) : stream_(stream) {
auto dtor = [](void* ptr) {
auto p = metal::new_scoped_memory_pool();
static_cast<MTL::SharedEvent*>(ptr)->release();
};
auto p = metal::new_scoped_memory_pool();
event_ = std::shared_ptr<void>(
metal::device(stream.device).mtl_device()->newSharedEvent(), dtor);
}
void Event::wait() {
if (!static_cast<MTL::SharedEvent*>(raw_event().get())
->waitUntilSignaledValue(value(), -1)) {
throw std::runtime_error("[Event::wait] Timed out");
}
}
void Event::signal() {
static_cast<MTL::SharedEvent*>(raw_event().get())->setSignaledValue(value());
}
} // namespace mlx::core


@@ -1,12 +1,106 @@
// Copyright © 2023 Apple Inc.
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/mlx.h"
#include "mlx/primitives.h"
namespace mlx::core {
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& s = out.primitive().stream();
  auto& d = metal::device(s.device);
  auto& in = inputs[0];
-  throw std::runtime_error("[FFT] NYI for Metal backend.");
if (axes_.size() == 0 || axes_.size() > 1 || inverse_ ||
in.dtype() != complex64 || out.dtype() != complex64) {
// Could also fallback to CPU implementation here.
throw std::runtime_error(
"GPU FFT is only implemented for 1D, forward, complex FFTs.");
}
size_t n = in.shape(axes_[0]);
if (!is_power_of_2(n) || n > 2048 || n < 4) {
throw std::runtime_error(
"GPU FFT is only implemented for the powers of 2 from 4 -> 2048");
}
// Make sure that the array is contiguous and has stride 1 in the FFT dim
std::vector<array> copies;
auto check_input = [this, &copies, &s](const array& x) {
// TODO: Pass the strides to the kernel so
// we can avoid the copy when x is not contiguous.
bool no_copy = x.strides()[axes_[0]] == 1 && x.flags().row_contiguous ||
x.flags().col_contiguous;
if (no_copy) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
std::vector<size_t> strides;
size_t cur_stride = x.shape(axes_[0]);
for (int axis = 0; axis < x.ndim(); axis++) {
if (axis == axes_[0]) {
strides.push_back(1);
} else {
strides.push_back(cur_stride);
cur_stride *= x.shape(axis);
}
}
auto flags = x.flags();
size_t f_stride = 1;
size_t b_stride = 1;
flags.col_contiguous = true;
flags.row_contiguous = true;
for (int i = 0, ri = x.ndim() - 1; i < x.ndim(); ++i, --ri) {
flags.col_contiguous &= (strides[i] == f_stride || x.shape(i) == 1);
f_stride *= x.shape(i);
flags.row_contiguous &= (strides[ri] == b_stride || x.shape(ri) == 1);
b_stride *= x.shape(ri);
}
// This is probably over-conservative
flags.contiguous = false;
x_copy.set_data(
allocator::malloc_or_wait(x.nbytes()), x.data_size(), strides, flags);
copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);
copies.push_back(x_copy);
return x_copy;
}
};
const array& in_contiguous = check_input(inputs[0]);
// TODO: allow donation here
out.set_data(
allocator::malloc_or_wait(out.nbytes()),
in_contiguous.data_size(),
in_contiguous.strides(),
in_contiguous.flags());
// We use n / 4 threads by default since radix-4
// is the largest single threaded radix butterfly
// we currently implement.
size_t m = n / 4;
size_t batch = in.size() / in.shape(axes_[0]);
auto& compute_encoder = d.get_command_encoder(s.index);
{
std::ostringstream kname;
kname << "fft_" << n;
auto kernel = d.get_kernel(kname.str());
bool donated = in.data_shared_ptr() == nullptr;
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in_contiguous, 0);
compute_encoder.set_output_array(out, 1);
auto group_dims = MTL::Size(1, m, 1);
auto grid_dims = MTL::Size(batch, m, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
} // namespace mlx::core
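To make the dispatch sizing above concrete, a small standalone calculation for a hypothetical batch of length-1024 FFTs; the batch size of 8 is illustrative only.

#include <cstdio>
int main() {
  size_t n = 1024;  // power of two in [4, 2048], as required above
  size_t batch = 8; // in.size() / in.shape(axes_[0])
  size_t m = n / 4; // one thread per radix-4 butterfly column
  std::printf("group_dims = (1, %zu, 1), grid_dims = (%zu, %zu, 1)\n", m, batch, m);
  return 0;
}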


@@ -16,7 +16,7 @@ namespace mlx::core {
namespace {
- static constexpr int METAL_MAX_INDEX_ARRAYS = 10;
constexpr int METAL_MAX_INDEX_ARRAYS = 10;
} // namespace
@@ -49,7 +49,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
kname << "_" << idx_ndim;
}
- auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
@@ -81,8 +81,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
}
// Set all the buffers
- set_array_buffer(compute_encoder, src, 0);
- set_array_buffer(compute_encoder, out, 1);
compute_encoder.set_input_array(src, 0);
compute_encoder.set_output_array(out, 1);
// Set source info
compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2);
@@ -103,7 +103,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {
- set_array_buffer(compute_encoder, inputs[i], 20 + i);
compute_encoder.set_input_array(inputs[i], 20 + i);
}
// Launch grid
@@ -183,7 +183,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
}
kname << "_" << nidx;
- auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto& upd = inputs.back();
@@ -192,8 +192,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setComputePipelineState(kernel);
// Set all the buffers
- set_array_buffer(compute_encoder, upd, 1);
- set_array_buffer(compute_encoder, out, 2);
compute_encoder.set_input_array(upd, 1);
compute_encoder.set_output_array(out, 2);
// Set update info
uint upd_ndim = upd.ndim();
@@ -201,19 +201,16 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
for (int i = idx_ndim; i < upd.ndim(); ++i) {
upd_size *= upd.shape(i);
}
if (index_nd1_specialization) {
- bool upd_col_contiguous = upd.flags().col_contiguous;
compute_encoder->setBytes(
out.shape().data(), out.shape().size() * sizeof(int), 3);
compute_encoder->setBytes(
out.strides().data(), out.strides().size() * sizeof(size_t), 4);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);
- compute_encoder->setBytes(&upd_col_contiguous, sizeof(bool), 6);
// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {
- set_array_buffer(compute_encoder, inputs[i], 20 + i);
compute_encoder.set_input_array(inputs[i], 20 + i);
}
// Launch grid
@@ -283,7 +280,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {
- set_array_buffer(compute_encoder, inputs[i], 20 + i);
compute_encoder.set_input_array(inputs[i], 20 + i);
}
// Launch grid


@@ -7,6 +7,7 @@ set(
${CMAKE_CURRENT_SOURCE_DIR}/complex.h
${CMAKE_CURRENT_SOURCE_DIR}/defines.h
${CMAKE_CURRENT_SOURCE_DIR}/erf.h
${CMAKE_CURRENT_SOURCE_DIR}/expm1f.h
${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
${CMAKE_CURRENT_SOURCE_DIR}/unary.h
${CMAKE_CURRENT_SOURCE_DIR}/utils.h
@@ -20,9 +21,12 @@ set(
"binary_two"
"conv"
"copy"
"fft"
"gemv"
"quantized"
"random"
"rms_norm"
"layer_norm"
"rope"
"scan"
"scaled_dot_product_attention"
@@ -35,11 +39,17 @@ set(
)
function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS}
-gline-tables-only
-frecord-sources)
endif()
add_custom_command(
- COMMAND xcrun -sdk macosx metal -Wall -Wextra
- -fno-fast-math
COMMAND xcrun -sdk macosx metal
${METAL_FLAGS}
-c ${SRCFILE}
-I${PROJECT_SOURCE_DIR}
-o ${TARGET}.air
DEPENDS ${SRCFILE} ${DEPS}
OUTPUT ${TARGET}.air


@@ -11,22 +11,22 @@ template <typename T>
out[index] = start + index * step;
}
#define instantiate_arange(tname, type) \
template [[host_name("arange" #tname)]] [[kernel]] void arange<type>( \
constant const type& start, \
constant const type& step, \
device type* out, \
uint index [[thread_position_in_grid]]);
// clang-format off
instantiate_arange(uint8, uint8_t)
instantiate_arange(uint16, uint16_t)
instantiate_arange(uint32, uint32_t)
instantiate_arange(uint64, uint64_t)
instantiate_arange(int8, int8_t)
instantiate_arange(int16, int16_t)
instantiate_arange(int32, int32_t)
instantiate_arange(int64, int64_t)
instantiate_arange(float16, half)
instantiate_arange(float32, float)
instantiate_arange(bfloat16, bfloat16_t) // clang-format on
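For reference, instantiate_arange(float32, float) expands to roughly the explicit instantiation below; host_name simply concatenates the literal "arange" with the type tag.

template [[host_name("arangefloat32")]] [[kernel]] void arange<float>(
constant const float& start,
constant const float& step,
device float* out,
uint index [[thread_position_in_grid]]);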


@@ -18,7 +18,8 @@ struct ArgMin {
static constexpr constant U init = Limits<U>::max; static constexpr constant U init = Limits<U>::max;
IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) { IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
if (best.val > current.val || (best.val == current.val && best.index > current.index)) { if (best.val > current.val ||
(best.val == current.val && best.index > current.index)) {
return current; return current;
} else { } else {
return best; return best;
@@ -26,11 +27,12 @@ struct ArgMin {
} }
template <int N> template <int N>
IndexValPair<U> reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) { IndexValPair<U>
for (int i=0; i<N; i++) { reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
for (int i = 0; i < N; i++) {
if (vals[i] < best.val) { if (vals[i] < best.val) {
best.val = vals[i]; best.val = vals[i];
best.index = offset+i; best.index = offset + i;
} }
} }
return best; return best;
@@ -42,7 +44,8 @@ struct ArgMax {
static constexpr constant U init = Limits<U>::min; static constexpr constant U init = Limits<U>::min;
IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) { IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
if (best.val < current.val || (best.val == current.val && best.index > current.index)) { if (best.val < current.val ||
(best.val == current.val && best.index > current.index)) {
return current; return current;
} else { } else {
return best; return best;
@@ -50,11 +53,12 @@ struct ArgMax {
} }
template <int N> template <int N>
IndexValPair<U> reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) { IndexValPair<U>
for (int i=0; i<N; i++) { reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
for (int i = 0; i < N; i++) {
if (vals[i] > best.val) { if (vals[i] > best.val) {
best.val = vals[i]; best.val = vals[i];
best.index = offset+i; best.index = offset + i;
} }
} }
return best; return best;
@@ -64,19 +68,16 @@ struct ArgMax {
template <typename U> template <typename U>
IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) { IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
return IndexValPair<U>{ return IndexValPair<U>{
simd_shuffle_down(data.index, delta), simd_shuffle_down(data.index, delta), simd_shuffle_down(data.val, delta)};
simd_shuffle_down(data.val, delta)
};
} }
template <typename T, typename Op, int N_READS> template <typename T, typename Op, int N_READS>
[[kernel]] void arg_reduce_general( [[kernel]] void arg_reduce_general(
const device T *in [[buffer(0)]], const device T* in [[buffer(0)]],
device uint32_t *out [[buffer(1)]], device uint32_t* out [[buffer(1)]],
const device int *shape [[buffer(2)]], const device int* shape [[buffer(2)]],
const device size_t *in_strides [[buffer(3)]], const device size_t* in_strides [[buffer(3)]],
const device size_t *out_strides [[buffer(4)]], const device size_t* out_strides [[buffer(4)]],
const device size_t& ndim [[buffer(5)]], const device size_t& ndim [[buffer(5)]],
const device size_t& axis_stride [[buffer(6)]], const device size_t& axis_stride [[buffer(6)]],
const device size_t& axis_size [[buffer(7)]], const device size_t& axis_size [[buffer(7)]],
@@ -86,7 +87,6 @@ template <typename T, typename Op, int N_READS>
uint simd_size [[threads_per_simdgroup]], uint simd_size [[threads_per_simdgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) { uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// Shapes and strides *do not* contain the reduction axis. The reduction size // Shapes and strides *do not* contain the reduction axis. The reduction size
// and stride are provided in axis_stride and axis_size. // and stride are provided in axis_stride and axis_size.
// //
@@ -113,13 +113,13 @@ template <typename T, typename Op, int N_READS>
threadgroup IndexValPair<T> local_data[32]; threadgroup IndexValPair<T> local_data[32];
// Loop over the reduction axis in lsize*N_READS buckets // Loop over the reduction axis in lsize*N_READS buckets
for (uint r=0; r < ceildiv(axis_size, N_READS*lsize); r++) { for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
// Read the current value // Read the current value
uint32_t current_index = r*lsize*N_READS + lid*N_READS; uint32_t current_index = r * lsize * N_READS + lid * N_READS;
uint32_t offset = current_index; uint32_t offset = current_index;
const device T * current_in = in + in_idx + current_index * axis_stride; const device T* current_in = in + in_idx + current_index * axis_stride;
T vals[N_READS]; T vals[N_READS];
for (int i=0; i<N_READS; i++) { for (int i = 0; i < N_READS; i++) {
vals[i] = (current_index < axis_size) ? *current_in : T(Op::init); vals[i] = (current_index < axis_size) ? *current_in : T(Op::init);
current_index++; current_index++;
current_in += axis_stride; current_in += axis_stride;
@@ -130,7 +130,7 @@ template <typename T, typename Op, int N_READS>
// need to reduce across the thread group. // need to reduce across the thread group.
// First per simd reduction. // First per simd reduction.
for (uint offset=simd_size/2; offset>0; offset/=2) { for (uint offset = simd_size / 2; offset > 0; offset /= 2) {
IndexValPair<T> neighbor = simd_shuffle_down(best, offset); IndexValPair<T> neighbor = simd_shuffle_down(best, offset);
best = op.reduce(best, neighbor); best = op.reduce(best, neighbor);
} }
@@ -149,7 +149,7 @@ template <typename T, typename Op, int N_READS>
if (simd_lane_id < simd_groups) { if (simd_lane_id < simd_groups) {
best = local_data[simd_lane_id]; best = local_data[simd_lane_id];
} }
for (uint offset=simd_size/2; offset>0; offset/=2) { for (uint offset = simd_size / 2; offset > 0; offset /= 2) {
IndexValPair<T> neighbor = simd_shuffle_down(best, offset); IndexValPair<T> neighbor = simd_shuffle_down(best, offset);
best = op.reduce(best, neighbor); best = op.reduce(best, neighbor);
} }
@@ -161,24 +161,25 @@ template <typename T, typename Op, int N_READS>
} }
#define instantiate_arg_reduce_helper(name, itype, op) \ #define instantiate_arg_reduce_helper(name, itype, op) \
template [[host_name(name)]] \ template [[host_name(name)]] [[kernel]] void \
[[kernel]] void arg_reduce_general<itype, op<itype>, 4>( \ arg_reduce_general<itype, op<itype>, 4>( \
const device itype *in [[buffer(0)]], \ const device itype* in [[buffer(0)]], \
device uint32_t * out [[buffer(1)]], \ device uint32_t* out [[buffer(1)]], \
const device int *shape [[buffer(2)]], \ const device int* shape [[buffer(2)]], \
const device size_t *in_strides [[buffer(3)]], \ const device size_t* in_strides [[buffer(3)]], \
const device size_t *out_strides [[buffer(4)]], \ const device size_t* out_strides [[buffer(4)]], \
const device size_t& ndim [[buffer(5)]], \ const device size_t& ndim [[buffer(5)]], \
const device size_t& axis_stride [[buffer(6)]], \ const device size_t& axis_stride [[buffer(6)]], \
const device size_t& axis_size [[buffer(7)]], \ const device size_t& axis_size [[buffer(7)]], \
uint gid [[thread_position_in_grid]], \ uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \ uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \ uint lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]], \ uint simd_size [[threads_per_simdgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \ uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_arg_reduce(name, itype) \ // clang-format off
#define instantiate_arg_reduce(name, itype) \
instantiate_arg_reduce_helper("argmin_" #name , itype, ArgMin) \ instantiate_arg_reduce_helper("argmin_" #name , itype, ArgMin) \
instantiate_arg_reduce_helper("argmax_" #name , itype, ArgMax) instantiate_arg_reduce_helper("argmax_" #name , itype, ArgMax)
@@ -193,4 +194,4 @@ instantiate_arg_reduce(int32, int32_t)
instantiate_arg_reduce(int64, int64_t) instantiate_arg_reduce(int64, int64_t)
instantiate_arg_reduce(float16, half) instantiate_arg_reduce(float16, half)
instantiate_arg_reduce(float32, float) instantiate_arg_reduce(float32, float)
instantiate_arg_reduce(bfloat16, bfloat16_t) instantiate_arg_reduce(bfloat16, bfloat16_t) // clang-format on
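The reduction loops above (offset starting at simd_size / 2 and halving each step) are the standard shuffle-down tree. A scalar C++ analog of the same pattern, with a plain array standing in for the 32 simd lanes and max standing in for op.reduce; the input values are illustrative.

#include <algorithm>
#include <cstdio>
int main() {
  float lane[32];
  for (int i = 0; i < 32; ++i) lane[i] = float((7 * i) % 13); // per-lane partial results
  for (int offset = 16; offset > 0; offset /= 2) {
    for (int i = 0; i < offset; ++i) {
      lane[i] = std::max(lane[i], lane[i + offset]); // simd_shuffle_down + op.reduce
    }
  }
  std::printf("reduced value: %g\n", lane[0]); // lane 0 holds the result
  return 0;
}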


@@ -229,3 +229,38 @@ struct LogicalOr {
return x || y; return x || y;
}; };
}; };
struct BitwiseAnd {
template <typename T>
T operator()(T x, T y) {
return x & y;
};
};
struct BitwiseOr {
template <typename T>
T operator()(T x, T y) {
return x | y;
};
};
struct BitwiseXor {
template <typename T>
T operator()(T x, T y) {
return x ^ y;
};
};
struct LeftShift {
template <typename T>
T operator()(T x, T y) {
return x << y;
};
};
struct RightShift {
template <typename T>
T operator()(T x, T y) {
return x >> y;
};
};
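A small host-side illustration of the functors above, treating them as ordinary C++ and assuming the struct definitions are in scope; the concrete operand values are only an example.

#include <cstdint>
#include <cstdio>
int main() {
  uint8_t x = 0b1100, y = 0b1010;
  std::printf("%u %u %u %u %u\n",
      unsigned(BitwiseAnd{}(x, y)),          // 0b1000
      unsigned(BitwiseOr{}(x, y)),           // 0b1110
      unsigned(BitwiseXor{}(x, y)),          // 0b0110
      unsigned(LeftShift{}(x, uint8_t{2})),  // 0b110000
      unsigned(RightShift{}(x, uint8_t{1}))); // 0b0110
  return 0;
}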


@@ -77,7 +77,8 @@ template <typename T, typename U, typename Op>
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides); auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides); auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[a_idx], b[b_idx]); c[out_idx] = Op()(a[a_idx], b[b_idx]);
} }
@@ -92,7 +93,8 @@ template <typename T, typename U, typename Op, int DIM>
uint3 index [[thread_position_in_grid]], uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides); auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]); c[out_idx] = Op()(a[idx.x], b[idx.y]);
} }
@@ -112,111 +114,118 @@ template <typename T, typename U, typename Op>
c[out_idx] = Op()(a[idx.x], b[idx.y]); c[out_idx] = Op()(a[idx.x], b[idx.y]);
} }
#define instantiate_binary(name, itype, otype, op, bopt) \ #define instantiate_binary(name, itype, otype, op, bopt) \
template [[host_name(name)]] \ template \
[[kernel]] void binary_op_##bopt<itype, otype, op>( \ [[host_name(name)]] [[kernel]] void binary_op_##bopt<itype, otype, op>( \
device const itype* a, \ device const itype* a, \
device const itype* b, \ device const itype* b, \
device otype* c, \ device otype* c, \
uint index [[thread_position_in_grid]]); uint index [[thread_position_in_grid]]);
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \ #define instantiate_binary_g_dim(name, itype, otype, op, dims) \
template [[host_name(name "_" #dims)]] \ template [[host_name(name "_" #dims)]] [[kernel]] void \
[[kernel]] void binary_op_g_nd<itype, otype, op, dims>( \ binary_op_g_nd<itype, otype, op, dims>( \
device const itype* a, \ device const itype* a, \
device const itype* b, \ device const itype* b, \
device otype* c, \ device otype* c, \
constant const int shape[dims], \ constant const int shape[dims], \
constant const size_t a_strides[dims], \ constant const size_t a_strides[dims], \
constant const size_t b_strides[dims], \ constant const size_t b_strides[dims], \
uint3 index [[thread_position_in_grid]], \ uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); uint3 grid_dim [[threads_per_grid]]);
#define instantiate_binary_g_nd(name, itype, otype, op) \ #define instantiate_binary_g_nd(name, itype, otype, op) \
template [[host_name(name "_1")]] \ template [[host_name(name "_1")]] [[kernel]] void \
[[kernel]] void binary_op_g_nd1<itype, otype, op>( \ binary_op_g_nd1<itype, otype, op>( \
device const itype* a, \ device const itype* a, \
device const itype* b, \ device const itype* b, \
device otype* c, \ device otype* c, \
constant const size_t& a_stride, \ constant const size_t& a_stride, \
constant const size_t& b_stride, \ constant const size_t& b_stride, \
uint index [[thread_position_in_grid]]); \ uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] \ template [[host_name(name "_2")]] [[kernel]] void \
[[kernel]] void binary_op_g_nd2<itype, otype, op>( \ binary_op_g_nd2<itype, otype, op>( \
device const itype* a, \ device const itype* a, \
device const itype* b, \ device const itype* b, \
device otype* c, \ device otype* c, \
constant const size_t a_strides[2], \ constant const size_t a_strides[2], \
constant const size_t b_strides[2], \ constant const size_t b_strides[2], \
uint2 index [[thread_position_in_grid]], \ uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \ uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] \ template [[host_name(name "_3")]] [[kernel]] void \
[[kernel]] void binary_op_g_nd3<itype, otype, op>( \ binary_op_g_nd3<itype, otype, op>( \
device const itype* a, \ device const itype* a, \
device const itype* b, \ device const itype* b, \
device otype* c, \ device otype* c, \
constant const size_t a_strides[3], \ constant const size_t a_strides[3], \
constant const size_t b_strides[3], \ constant const size_t b_strides[3], \
uint3 index [[thread_position_in_grid]], \ uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \ uint3 grid_dim [[threads_per_grid]]); \
instantiate_binary_g_dim(name, itype, otype, op, 4) \ instantiate_binary_g_dim(name, itype, otype, op, 4) \
instantiate_binary_g_dim(name, itype, otype, op, 5) instantiate_binary_g_dim(name, itype, otype, op, 5)
#define instantiate_binary_g(name, itype, otype, op) \
#define instantiate_binary_g(name, itype, otype, op) \ template [[host_name(name)]] [[kernel]] void binary_op_g<itype, otype, op>( \
template [[host_name(name)]] \ device const itype* a, \
[[kernel]] void binary_op_g<itype, otype, op>( \ device const itype* b, \
device const itype* a, \ device otype* c, \
device const itype* b, \ constant const int* shape, \
device otype* c, \ constant const size_t* a_strides, \
constant const int* shape, \ constant const size_t* b_strides, \
constant const size_t* a_strides, \ constant const int& ndim, \
constant const size_t* b_strides, \ uint3 index [[thread_position_in_grid]], \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); uint3 grid_dim [[threads_per_grid]]);
// clang-format off
#define instantiate_binary_all(name, tname, itype, otype, op) \ #define instantiate_binary_all(name, tname, itype, otype, op) \
instantiate_binary("ss" #name #tname, itype, otype, op, ss) \ instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
instantiate_binary("sv" #name #tname, itype, otype, op, sv) \ instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
instantiate_binary("vs" #name #tname, itype, otype, op, vs) \ instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
instantiate_binary("vv" #name #tname, itype, otype, op, vv) \ instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
instantiate_binary_g("g" #name #tname, itype, otype, op) \ instantiate_binary_g("g" #name #tname, itype, otype, op) \
instantiate_binary_g_nd("g" #name #tname, itype, otype, op) instantiate_binary_g_nd("g" #name #tname, itype, otype, op) // clang-format on
#define instantiate_binary_float(name, op) \ // clang-format off
instantiate_binary_all(name, float16, half, half, op) \ #define instantiate_binary_integer(name, op) \
instantiate_binary_all(name, float32, float, float, op) \ instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
#define instantiate_binary_types(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \ instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \ instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \ instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \
instantiate_binary_all(name, int8, int8_t, int8_t, op) \ instantiate_binary_all(name, int8, int8_t, int8_t, op) \
instantiate_binary_all(name, int16, int16_t, int16_t, op) \ instantiate_binary_all(name, int16, int16_t, int16_t, op) \
instantiate_binary_all(name, int32, int32_t, int32_t, op) \ instantiate_binary_all(name, int32, int32_t, int32_t, op) \
instantiate_binary_all(name, int64, int64_t, int64_t, op) \ instantiate_binary_all(name, int64, int64_t, int64_t, op) // clang-format on
// clang-format off
#define instantiate_binary_float(name, op) \
instantiate_binary_all(name, float16, half, half, op) \
instantiate_binary_all(name, float32, float, float, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op) // clang-format on
// clang-format off
#define instantiate_binary_types(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_integer(name, op) \
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \ instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
instantiate_binary_float(name, op) instantiate_binary_float(name, op) // clang-format on
#define instantiate_binary_types_bool(name, op) \ // clang-format off
instantiate_binary_all(name, bool_, bool, bool, op) \ #define instantiate_binary_types_bool(name, op) \
instantiate_binary_all(name, uint8, uint8_t, bool, op) \ instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_all(name, uint16, uint16_t, bool, op) \ instantiate_binary_all(name, uint8, uint8_t, bool, op) \
instantiate_binary_all(name, uint32, uint32_t, bool, op) \ instantiate_binary_all(name, uint16, uint16_t, bool, op) \
instantiate_binary_all(name, uint64, uint64_t, bool, op) \ instantiate_binary_all(name, uint32, uint32_t, bool, op) \
instantiate_binary_all(name, int8, int8_t, bool, op) \ instantiate_binary_all(name, uint64, uint64_t, bool, op) \
instantiate_binary_all(name, int16, int16_t, bool, op) \ instantiate_binary_all(name, int8, int8_t, bool, op) \
instantiate_binary_all(name, int32, int32_t, bool, op) \ instantiate_binary_all(name, int16, int16_t, bool, op) \
instantiate_binary_all(name, int64, int64_t, bool, op) \ instantiate_binary_all(name, int32, int32_t, bool, op) \
instantiate_binary_all(name, float16, half, bool, op) \ instantiate_binary_all(name, int64, int64_t, bool, op) \
instantiate_binary_all(name, float32, float, bool, op) \ instantiate_binary_all(name, float16, half, bool, op) \
instantiate_binary_all(name, float32, float, bool, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \ instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \
instantiate_binary_all(name, complex64, complex64_t, bool, op) instantiate_binary_all(name, complex64, complex64_t, bool, op) // clang-format on
// clang-format off
instantiate_binary_types(add, Add) instantiate_binary_types(add, Add)
instantiate_binary_types(div, Divide) instantiate_binary_types(div, Divide)
instantiate_binary_types_bool(eq, Equal) instantiate_binary_types_bool(eq, Equal)
@@ -241,3 +250,13 @@ instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
instantiate_binary_all(lor, bool_, bool, bool, LogicalOr) instantiate_binary_all(lor, bool_, bool, bool, LogicalOr)
instantiate_binary_all(land, bool_, bool, bool, LogicalAnd) instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)
// Bitwise ops only need integer types and bool (except for l/r shift)
instantiate_binary_integer(bitwise_and, BitwiseAnd)
instantiate_binary_all(bitwise_and, bool_, bool, bool, BitwiseAnd)
instantiate_binary_integer(bitwise_or, BitwiseOr)
instantiate_binary_all(bitwise_or, bool_, bool, bool, BitwiseOr)
instantiate_binary_integer(bitwise_xor, BitwiseXor)
instantiate_binary_all(bitwise_xor, bool_, bool, bool, BitwiseXor)
instantiate_binary_integer(left_shift, LeftShift)
instantiate_binary_integer(right_shift, RightShift) // clang-format on


@@ -3,28 +3,42 @@
#include <metal_integer> #include <metal_integer>
#include <metal_math> #include <metal_math>
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
struct FloorDivide { struct FloorDivide {
template <typename T> T operator()(T x, T y) { return x / y; } template <typename T>
template <> float operator()(float x, float y) { return trunc(x / y); } T operator()(T x, T y) {
template <> half operator()(half x, half y) { return trunc(x / y); } return x / y;
template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return trunc(x / y); } }
template <>
float operator()(float x, float y) {
return trunc(x / y);
}
template <>
half operator()(half x, half y) {
return trunc(x / y);
}
template <>
bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
return trunc(x / y);
}
}; };
struct Remainder { struct Remainder {
template <typename T> template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T> operator()(T x, T y) { metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
operator()(T x, T y) {
return x % y; return x % y;
} }
template <typename T> template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T> operator()(T x, T y) { metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
operator()(T x, T y) {
auto r = x % y; auto r = x % y;
if (r != 0 && (r < 0 != y < 0)) { if (r != 0 && (r < 0 != y < 0)) {
r += y; r += y;
} }
return r; return r;
} }
template <typename T> template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) { metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
@@ -32,10 +46,11 @@ struct Remainder {
if (r != 0 && (r < 0 != y < 0)) { if (r != 0 && (r < 0 != y < 0)) {
r += y; r += y;
} }
return r; return r;
} }
template <> complex64_t operator()(complex64_t x, complex64_t y) { template <>
return x % y; complex64_t operator()(complex64_t x, complex64_t y) {
return x % y;
} }
}; };
@@ -50,7 +65,6 @@ template <typename T, typename U, typename Op1, typename Op2>
d[index] = Op2()(a[0], b[0]); d[index] = Op2()(a[0], b[0]);
} }
template <typename T, typename U, typename Op1, typename Op2> template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_ss( [[kernel]] void binary_op_ss(
device const T* a, device const T* a,
@@ -139,7 +153,8 @@ template <typename T, typename U, typename Op1, typename Op2>
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides); auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides); auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op1()(a[a_idx], b[b_idx]); c[out_idx] = Op1()(a[a_idx], b[b_idx]);
d[out_idx] = Op2()(a[a_idx], b[b_idx]); d[out_idx] = Op2()(a[a_idx], b[b_idx]);
} }
@@ -156,7 +171,8 @@ template <typename T, typename U, typename Op1, typename Op2, int DIM>
uint3 index [[thread_position_in_grid]], uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides); auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op1()(a[idx.x], b[idx.y]); c[out_idx] = Op1()(a[idx.x], b[idx.y]);
d[out_idx] = Op2()(a[idx.x], b[idx.y]); d[out_idx] = Op2()(a[idx.x], b[idx.y]);
} }
@@ -180,99 +196,102 @@ template <typename T, typename U, typename Op1, typename Op2>
} }
#define instantiate_binary(name, itype, otype, op1, op2, bopt) \ #define instantiate_binary(name, itype, otype, op1, op2, bopt) \
template [[host_name(name)]] \ template [[host_name(name)]] [[kernel]] void \
[[kernel]] void binary_op_##bopt<itype, otype, op1, op2>( \ binary_op_##bopt<itype, otype, op1, op2>( \
device const itype* a, \ device const itype* a, \
device const itype* b, \ device const itype* b, \
device otype* c, \ device otype* c, \
device otype* d, \ device otype* d, \
uint index [[thread_position_in_grid]]); uint index [[thread_position_in_grid]]);
#define instantiate_binary_g_dim(name, itype, otype, op1, op2, dims) \ #define instantiate_binary_g_dim(name, itype, otype, op1, op2, dims) \
template [[host_name(name "_" #dims)]] \ template [[host_name(name "_" #dims)]] [[kernel]] void \
[[kernel]] void binary_op_g_nd<itype, otype, op1, op2, dims>( \ binary_op_g_nd<itype, otype, op1, op2, dims>( \
device const itype* a, \ device const itype* a, \
device const itype* b, \ device const itype* b, \
device otype* c, \ device otype* c, \
device otype* d, \ device otype* d, \
constant const int shape[dims], \ constant const int shape[dims], \
constant const size_t a_strides[dims], \ constant const size_t a_strides[dims], \
constant const size_t b_strides[dims], \ constant const size_t b_strides[dims], \
uint3 index [[thread_position_in_grid]], \ uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); uint3 grid_dim [[threads_per_grid]]);
// clang-format off
#define instantiate_binary_g_nd(name, itype, otype, op1, op2) \ #define instantiate_binary_g_nd(name, itype, otype, op1, op2) \
template [[host_name(name "_1")]] \ template [[host_name(name "_1")]] [[kernel]] void \
[[kernel]] void binary_op_g_nd1<itype, otype, op1, op2>( \ binary_op_g_nd1<itype, otype, op1, op2>( \
device const itype* a, \ device const itype* a, \
device const itype* b, \ device const itype* b, \
device otype* c, \ device otype* c, \
device otype* d, \ device otype* d, \
constant const size_t& a_stride, \ constant const size_t& a_stride, \
constant const size_t& b_stride, \ constant const size_t& b_stride, \
uint index [[thread_position_in_grid]]); \ uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] \ template [[host_name(name "_2")]] [[kernel]] void \
[[kernel]] void binary_op_g_nd2<itype, otype, op1, op2>( \ binary_op_g_nd2<itype, otype, op1, op2>( \
device const itype* a, \ device const itype* a, \
device const itype* b, \ device const itype* b, \
device otype* c, \ device otype* c, \
device otype* d, \ device otype* d, \
constant const size_t a_strides[2], \ constant const size_t a_strides[2], \
constant const size_t b_strides[2], \ constant const size_t b_strides[2], \
uint2 index [[thread_position_in_grid]], \ uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \ uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] \ template [[host_name(name "_3")]] [[kernel]] void \
[[kernel]] void binary_op_g_nd3<itype, otype, op1, op2>( \ binary_op_g_nd3<itype, otype, op1, op2>( \
device const itype* a, \ device const itype* a, \
device const itype* b, \ device const itype* b, \
device otype* c, \ device otype* c, \
device otype* d, \ device otype* d, \
constant const size_t a_strides[3], \ constant const size_t a_strides[3], \
constant const size_t b_strides[3], \ constant const size_t b_strides[3], \
uint3 index [[thread_position_in_grid]], \ uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \ uint3 grid_dim [[threads_per_grid]]); \
instantiate_binary_g_dim(name, itype, otype, op1, op2, 4) \ instantiate_binary_g_dim(name, itype, otype, op1, op2, 4) \
instantiate_binary_g_dim(name, itype, otype, op1, op2, 5) instantiate_binary_g_dim(name, itype, otype, op1, op2, 5) // clang-format on
#define instantiate_binary_g(name, itype, otype, op1, op2) \ #define instantiate_binary_g(name, itype, otype, op1, op2) \
template [[host_name(name)]] \ template [[host_name(name)]] [[kernel]] void \
[[kernel]] void binary_op_g<itype, otype, op2, op2>( \ binary_op_g<itype, otype, op2, op2>( \
device const itype* a, \ device const itype* a, \
device const itype* b, \ device const itype* b, \
device otype* c, \ device otype* c, \
device otype* d, \ device otype* d, \
constant const int* shape, \ constant const int* shape, \
constant const size_t* a_strides, \ constant const size_t* a_strides, \
constant const size_t* b_strides, \ constant const size_t* b_strides, \
constant const int& ndim, \ constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \ uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); uint3 grid_dim [[threads_per_grid]]);
// clang-format off
#define instantiate_binary_all(name, tname, itype, otype, op1, op2) \ #define instantiate_binary_all(name, tname, itype, otype, op1, op2) \
instantiate_binary("ss" #name #tname, itype, otype, op1, op2, ss) \ instantiate_binary("ss" #name #tname, itype, otype, op1, op2, ss) \
instantiate_binary("sv" #name #tname, itype, otype, op1, op2, sv) \ instantiate_binary("sv" #name #tname, itype, otype, op1, op2, sv) \
instantiate_binary("vs" #name #tname, itype, otype, op1, op2, vs) \ instantiate_binary("vs" #name #tname, itype, otype, op1, op2, vs) \
instantiate_binary("vv" #name #tname, itype, otype, op1, op2, vv) \ instantiate_binary("vv" #name #tname, itype, otype, op1, op2, vv) \
instantiate_binary_g("g" #name #tname, itype, otype, op1, op2) \ instantiate_binary_g("g" #name #tname, itype, otype, op1, op2) \
instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2) instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2) // clang-format on
#define instantiate_binary_float(name, op1, op2) \ // clang-format off
instantiate_binary_all(name, float16, half, half, op1, op2) \ #define instantiate_binary_float(name, op1, op2) \
instantiate_binary_all(name, float16, half, half, op1, op2) \
instantiate_binary_all(name, float32, float, float, op1, op2) \ instantiate_binary_all(name, float32, float, float, op1, op2) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2) instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2) // clang-format on
#define instantiate_binary_types(name, op1, op2) \ // clang-format off
instantiate_binary_all(name, bool_, bool, bool, op1, op2) \ #define instantiate_binary_types(name, op1, op2) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op1, op2) \ instantiate_binary_all(name, bool_, bool, bool, op1, op2) \
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op1, op2) \ instantiate_binary_all(name, uint8, uint8_t, uint8_t, op1, op2) \
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op1, op2) \ instantiate_binary_all(name, uint16, uint16_t, uint16_t, op1, op2) \
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op1, op2) \ instantiate_binary_all(name, uint32, uint32_t, uint32_t, op1, op2) \
instantiate_binary_all(name, int8, int8_t, int8_t, op1, op2) \ instantiate_binary_all(name, uint64, uint64_t, uint64_t, op1, op2) \
instantiate_binary_all(name, int16, int16_t, int16_t, op1, op2) \ instantiate_binary_all(name, int8, int8_t, int8_t, op1, op2) \
instantiate_binary_all(name, int32, int32_t, int32_t, op1, op2) \ instantiate_binary_all(name, int16, int16_t, int16_t, op1, op2) \
instantiate_binary_all(name, int64, int64_t, int64_t, op1, op2) \ instantiate_binary_all(name, int32, int32_t, int32_t, op1, op2) \
instantiate_binary_all(name, int64, int64_t, int64_t, op1, op2) \
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op1, op2) \ instantiate_binary_all(name, complex64, complex64_t, complex64_t, op1, op2) \
instantiate_binary_float(name, op1, op2) instantiate_binary_float(name, op1, op2)
instantiate_binary_types(divmod, FloorDivide, Remainder) instantiate_binary_types(divmod, FloorDivide, Remainder) // clang-format on
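The signed-integer Remainder above folds the result into the sign of the divisor (Python-style), which is where it diverges from C's % operator. A scalar check of that branch, with illustrative values:

#include <cstdio>
int main() {
  int x = -7, y = 3;
  int r = x % y;                              // -1 under C semantics
  if (r != 0 && ((r < 0) != (y < 0))) r += y; // the kernel's adjustment, giving 2
  std::printf("C %% gives %d, adjusted remainder gives %d\n", x % y, r);
  return 0;
}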


@@ -22,7 +22,7 @@ struct complex64_t {
float imag;
// Constructors
constexpr complex64_t(float real, float imag) : real(real), imag(imag) {};
// Conversions to complex64_t
template <


@@ -1,13 +1,11 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <metal_stdlib>
#include <metal_simdgroup> #include <metal_simdgroup>
#include <metal_simdgroup_matrix> #include <metal_simdgroup_matrix>
#include <metal_stdlib> #include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#define MLX_MTL_CONST static constant constexpr const #define MLX_MTL_CONST static constant constexpr const
@@ -23,17 +21,18 @@ template <typename T, int N>
device T* out [[buffer(1)]],
const constant MLXConvParams<N>* params [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
int filter_size = params->C;
for (short i = 0; i < N; i++)
filter_size *= params->wS[i];
int out_pixels = 1;
for (short i = 0; i < N; i++)
out_pixels *= params->oS[i];
// Set out
out += gid.z * filter_size + gid.y * (params->C);
// Coordinates in input
int is[N] = {0};
// gid.z: N oS (Batch and row in unfolded output)
@@ -46,11 +45,11 @@ template <typename T, int N>
bool valid = n < params->N; bool valid = n < params->N;
// Unroll dimensions // Unroll dimensions
for (int i = N - 1; i >= 0; --i) { for (int i = N - 1; i >= 0; --i) {
int os_ = (oS % params->oS[i]); int os_ = (oS % params->oS[i]);
int ws_ = (wS % params->wS[i]); int ws_ = (wS % params->wS[i]);
ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_; ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_;
int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i]; int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i];
@@ -64,10 +63,10 @@ template <typename T, int N>
wS /= params->wS[i]; wS /= params->wS[i];
} }
if(valid) { if (valid) {
size_t in_offset = n * params->in_strides[0]; size_t in_offset = n * params->in_strides[0];
for(int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
in_offset += is[i] * params->in_strides[i + 1]; in_offset += is[i] * params->in_strides[i + 1];
} }
@@ -75,21 +74,91 @@ template <typename T, int N>
} else { } else {
out[gid.x] = T(0); out[gid.x] = T(0);
} }
} }
#define instantiate_naive_unfold_nd(name, itype, n) \ // This kernel unfolds the input array of size (N, *spatial_dims, C)
template [[host_name("naive_unfold_nd_" #name "_" #n)]] \ // into an array of size (N x *spatial_dims, C x *kernel_dims).
[[kernel]] void naive_unfold_Nd( \ template <typename T, int N>
const device itype* in [[buffer(0)]], \ [[kernel]] void naive_unfold_transpose_Nd(
device itype* out [[buffer(1)]], \ const device T* in [[buffer(0)]],
const constant MLXConvParams<n>* params [[buffer(2)]], \ device T* out [[buffer(1)]],
uint3 gid [[thread_position_in_grid]]); const constant MLXConvParams<N>* params [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
int filter_size = params->C;
for (short i = 0; i < N; i++)
filter_size *= params->wS[i];
#define instantiate_naive_unfold_nd_dims(name, itype) \ int out_pixels = 1;
instantiate_naive_unfold_nd(name, itype, 1) \ for (short i = 0; i < N; i++)
instantiate_naive_unfold_nd(name, itype, 2) \ out_pixels *= params->oS[i];
instantiate_naive_unfold_nd(name, itype, 3)
// Set out
out += gid.z * filter_size + gid.x * (filter_size / params->C);
// Coordinates in input
int is[N] = {0};
// gid.z: N oS (Batch and row in unfolded output)
// gid.y: wS (Filter location to unfold input)
// gid.x: C (channel)
int n = (gid.z) / out_pixels;
int oS = (gid.z) % out_pixels;
int wS = gid.y;
bool valid = n < params->N;
// Unroll dimensions
for (int i = N - 1; i >= 0; --i) {
int os_ = (oS % params->oS[i]);
int ws_ = (wS % params->wS[i]);
ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_;
int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i];
int is_max = 1 + params->idil[i] * (params->iS[i] - 1);
valid &= is_ >= 0 && is_ < is_max && (is_ % params->idil[i] == 0);
is[i] = is_ / params->idil[i];
oS /= params->oS[i];
wS /= params->wS[i];
out += ws_ * params->str[i];
}
if (valid) {
size_t in_offset = n * params->in_strides[0];
for (int i = 0; i < N; ++i) {
in_offset += is[i] * params->in_strides[i + 1];
}
out[0] = in[in_offset + gid.x];
} else {
out[0] = T(0);
}
}
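For intuition about the unfold shapes, a standalone calculation with illustrative parameters (N = 1 image, C = 16 channels, 3x3 filter, 32x32 output), matching the filter_size and out_pixels computed in the kernels above:

#include <cstdio>
int main() {
  int N = 1, C = 16;
  int wS[2] = {3, 3};   // kernel spatial dims
  int oS[2] = {32, 32}; // output spatial dims
  int filter_size = C;
  for (int i = 0; i < 2; i++) filter_size *= wS[i]; // C * prod(wS) = 144
  int out_pixels = 1;
  for (int i = 0; i < 2; i++) out_pixels *= oS[i];  // prod(oS) = 1024
  std::printf("unfolded matrix: %d x %d\n", N * out_pixels, filter_size);
  return 0;
}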
#define instantiate_naive_unfold_nd(name, itype, n) \
template [[host_name("naive_unfold_nd_" #name "_" #n)]] [[kernel]] void \
naive_unfold_Nd( \
const device itype* in [[buffer(0)]], \
device itype* out [[buffer(1)]], \
const constant MLXConvParams<n>* params [[buffer(2)]], \
uint3 gid [[thread_position_in_grid]]); \
template \
[[host_name("naive_unfold_transpose_nd_" #name "_" #n)]] [[kernel]] void \
naive_unfold_transpose_Nd( \
const device itype* in [[buffer(0)]], \
device itype* out [[buffer(1)]], \
const constant MLXConvParams<n>* params [[buffer(2)]], \
uint3 gid [[thread_position_in_grid]]);
#define instantiate_naive_unfold_nd_dims(name, itype) \
instantiate_naive_unfold_nd(name, itype, 1) instantiate_naive_unfold_nd( \
name, itype, 2) instantiate_naive_unfold_nd(name, itype, 3)
instantiate_naive_unfold_nd_dims(float32, float); instantiate_naive_unfold_nd_dims(float32, float);
instantiate_naive_unfold_nd_dims(float16, half); instantiate_naive_unfold_nd_dims(float16, half);
@@ -99,12 +168,13 @@ instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t);
/// Slow and naive conv2d kernels /// Slow and naive conv2d kernels
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
template <typename T, template <
const int BM, /* Threadgroup rows (in threads) */ typename T,
const int BN, /* Threadgroup cols (in threads) */ const int BM, /* Threadgroup rows (in threads) */
const int TM, /* Thread rows (in elements) */ const int BN, /* Threadgroup cols (in threads) */
const int TN, /* Thread cols (in elements) */ const int TM, /* Thread rows (in elements) */
const int BC = 16> const int TN, /* Thread cols (in elements) */
const int BC = 16>
[[kernel]] void naive_conv_2d( [[kernel]] void naive_conv_2d(
const device T* in [[buffer(0)]], const device T* in [[buffer(0)]],
const device T* wt [[buffer(1)]], const device T* wt [[buffer(1)]],
@@ -114,7 +184,6 @@ template <typename T,
uint3 lid [[thread_position_in_threadgroup]], uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
(void)simd_gid; (void)simd_gid;
(void)simd_lid; (void)simd_lid;
@@ -123,80 +192,82 @@ template <typename T,
int out_o = tid.y * BN * TN + lid.y * TN; int out_o = tid.y * BN * TN + lid.y * TN;
int out_hw = tid.x * BM * TM + lid.x * TM; int out_hw = tid.x * BM * TM + lid.x * TM;
int out_h[TM]; int out_h[TM];
int out_w[TN]; int out_w[TN];
for(int m = 0; m < TM; ++m) { for (int m = 0; m < TM; ++m) {
int mm = (out_hw + m); int mm = (out_hw + m);
out_h[m] = mm / params.oS[1]; out_h[m] = mm / params.oS[1];
out_w[m] = mm % params.oS[1]; out_w[m] = mm % params.oS[1];
} }
T in_local[TM]; T in_local[TM];
T wt_local[TN]; T wt_local[TN];
T out_local[TM * TN] = {T(0)}; T out_local[TM * TN] = {T(0)};
for(int h = 0; h < params.wS[0]; ++h) { for (int h = 0; h < params.wS[0]; ++h) {
for(int w = 0; w < params.wS[1]; ++w) { for (int w = 0; w < params.wS[1]; ++w) {
for(int c = 0; c < params.C; ++c) { for (int c = 0; c < params.C; ++c) {
// Local in // Local in
for(int m = 0; m < TM; m++) { for (int m = 0; m < TM; m++) {
int i = out_h[m] * params.str[0] - params.pad[0] + h * params.kdil[0]; int i = out_h[m] * params.str[0] - params.pad[0] + h * params.kdil[0];
int j = out_w[m] * params.str[1] - params.pad[1] + w * params.kdil[1]; int j = out_w[m] * params.str[1] - params.pad[1] + w * params.kdil[1];
bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1]; bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1];
in_local[m] = valid ? in[i * params.in_strides[1] + j * params.in_strides[2] + c] : T(0); in_local[m] = valid
? in[i * params.in_strides[1] + j * params.in_strides[2] + c]
: T(0);
} }
// Load weight // Load weight
for (int n = 0; n < TN; ++n) { for (int n = 0; n < TN; ++n) {
int o = out_o + n; int o = out_o + n;
wt_local[n] = o < params.O ? wt[o * params.wt_strides[0] + wt_local[n] = o < params.O
h * params.wt_strides[1] + ? wt[o * params.wt_strides[0] + h * params.wt_strides[1] +
w * params.wt_strides[2] + c] : T(0); w * params.wt_strides[2] + c]
: T(0);
} }
// Accumulate // Accumulate
for(int m = 0; m < TM; ++m) { for (int m = 0; m < TM; ++m) {
for(int n = 0; n < TN; ++n) { for (int n = 0; n < TN; ++n) {
out_local[m * TN + n] += in_local[m] * wt_local[n]; out_local[m * TN + n] += in_local[m] * wt_local[n];
} }
} }
} }
} }
} }
for(int m = 0; m < TM; ++m) { for (int m = 0; m < TM; ++m) {
for(int n = 0; n < TN; ++n) { for (int n = 0; n < TN; ++n) {
if(out_h[m] < params.oS[0] && out_w[m] < params.oS[1] && (out_o + n) < params.O) if (out_h[m] < params.oS[0] && out_w[m] < params.oS[1] &&
out[out_h[m] * params.out_strides[1] + (out_o + n) < params.O)
out_w[m] * params.out_strides[2] + out_o + n] = out_local[m * TN + n]; out[out_h[m] * params.out_strides[1] +
out_w[m] * params.out_strides[2] + out_o + n] =
out_local[m * TN + n];
} }
} }
} }
// Instantiations // Instantiations
#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \ #define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \
template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \ template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm \
[[kernel]] void naive_conv_2d<itype, bm, bn, tm, tn>( \ "_tn" #tn)]] [[kernel]] void \
const device itype* in [[buffer(0)]], \ naive_conv_2d<itype, bm, bn, tm, tn>( \
const device itype* wt [[buffer(1)]], \ const device itype* in [[buffer(0)]], \
device itype* out [[buffer(2)]], \ const device itype* wt [[buffer(1)]], \
const constant MLXConvParams<2>& params [[buffer(3)]], \ device itype* out [[buffer(2)]], \
uint3 tid [[threadgroup_position_in_grid]], \ const constant MLXConvParams<2>& params [[buffer(3)]], \
uint3 lid [[thread_position_in_threadgroup]], \ uint3 tid [[threadgroup_position_in_grid]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \ uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]); uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_naive_conv_2d_blocks(name, itype) \ #define instantiate_naive_conv_2d_blocks(name, itype) \
instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \ instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \
instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4) instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4)
instantiate_naive_conv_2d_blocks(float32, float); instantiate_naive_conv_2d_blocks(float32, float);
instantiate_naive_conv_2d_blocks(float16, half); instantiate_naive_conv_2d_blocks(float16, half);
@@ -207,9 +278,7 @@ instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
template <int M, int R, int S> template <int M, int R, int S>
struct WinogradTransforms { struct WinogradTransforms {};
};
template <> template <>
struct WinogradTransforms<6, 3, 8> { struct WinogradTransforms<6, 3, 8> {
@@ -218,36 +287,36 @@ struct WinogradTransforms<6, 3, 8> {
  MLX_MTL_CONST int IN_TILE_SIZE = OUT_TILE_SIZE + FILTER_SIZE - 1;
  MLX_MTL_CONST int SIMD_MATRIX_SIZE = 8;
  MLX_MTL_CONST float in_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
      {1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
      {0.00f, 1.00f, -1.00f, 0.50f, -0.50f, 2.00f, -2.00f, -1.00f},
      {-5.25f, 1.00f, 1.00f, 0.25f, 0.25f, 4.00f, 4.00f, 0.00f},
      {0.00f, -4.25f, 4.25f, -2.50f, 2.50f, -2.50f, 2.50f, 5.25f},
      {5.25f, -4.25f, -4.25f, -1.25f, -1.25f, -5.00f, -5.00f, 0.00f},
      {0.00f, 1.00f, -1.00f, 2.00f, -2.00f, 0.50f, -0.50f, -5.25f},
      {-1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 0.00f},
      {0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
  };

  MLX_MTL_CONST float out_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
      {1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
      {1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f},
      {1.00f, -1.00f, 1.00f, -1.00f, 1.00f, -1.00f},
      {1.00f, 2.00f, 4.00f, 8.00f, 16.00f, 32.00f},
      {1.00f, -2.00f, 4.00f, -8.00f, 16.00f, -32.00f},
      {1.00f, 0.50f, 0.25f, 0.125f, 0.0625f, 0.03125f},
      {1.00f, -0.50f, 0.25f, -0.125f, 0.0625f, -0.03125f},
      {0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
  };

  MLX_MTL_CONST float wt_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
      {1.00, 0.00, 0.00},
      {-2.0 / 9.00, -2.0 / 9.00, -2.0 / 9.00},
      {-2.0 / 9.00, 2.0 / 9.00, -2.0 / 9.00},
      {1.0 / 90.0, 1.0 / 45.0, 2.0 / 45.0},
      {1.0 / 90.0, -1.0 / 45.0, 2.0 / 45.0},
      {32.0 / 45.0, 16.0 / 45.0, 8.0 / 45.0},
      {32.0 / 45.0, -16.0 / 45.0, 8.0 / 45.0},
      {0.00, 0.00, 1.00},
  };
};
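For orientation, these constants are the transform matrices of a Winograd F(6x6, 3x3) convolution, the 8x8-tile case selected by WinogradTransforms<6, 3, 8>. In the standard formulation (notation introduced here, not taken from the source: g is a 3x3 filter tile, d an 8x8 input tile, and G, B, A the filter, input, and output transforms, which correspond to wt_transform, in_transform, and out_transform up to transposition), the 6x6 output tile is

\[ Y = A^{\top}\bigl[(G\,g\,G^{\top}) \odot (B^{\top} d\,B)\bigr]A. \]

The kernels that follow each handle one piece of this with 8x8 simdgroup_matrix multiplies: the weight transform forms G g G^T, the input transform forms B^T d B, and the output transform maps the accumulated products back to the 6x6 output tile.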
@@ -255,12 +324,9 @@ constant constexpr const float WinogradTransforms<6, 3, 8>::wt_transform[8][8];
constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8];
constant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8];

template <typename T, int BC = 32, int BO = 4, int M = 6, int R = 3>
[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void
winograd_conv_2d_weight_transform(
    const device T* wt_in [[buffer(0)]],
    device T* wt_out [[buffer(1)]],
    const constant int& C [[buffer(2)]],
@@ -268,7 +334,6 @@ template <typename T,
    uint tid [[threadgroup_position_in_grid]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]]) {
  using WGT = WinogradTransforms<M, R, 8>;

  // Get lane position in simdgroup
@@ -288,35 +353,37 @@ template <typename T,
  // Move to the correct output filter
  size_t ko = BO * tid + simd_group_id;
  wt_in += ko * R * R * C;

  // wt_out is stored transposed (A x A x C x O)
  short ohw_0 = sm * 8 + sn;
  short ohw_1 = sm * 8 + sn + 1;
  device T* wt_out_0 = wt_out + ohw_0 * C * O + ko;
  device T* wt_out_1 = wt_out + ohw_1 * C * O + ko;

  // Prepare shared memory
  threadgroup T Ws[BO][R][R][BC];

  // Loop over C
  for (int bc = 0; bc < C; bc += BC) {
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Read into shared memory
    for (int kh = 0; kh < R; ++kh) {
      for (int kw = 0; kw < R; ++kw) {
        for (int kc = simd_lane_id; kc < BC; kc += 32) {
          Ws[simd_group_id][kh][kw][kc] = wt_in[kh * R * C + kw * C + kc];
        }
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Do transform and store the result
    for (int c = 0; c < BC; ++c) {
      simdgroup_matrix<T, 8, 8> g;
      g.thread_elements()[0] =
          sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0);
      g.thread_elements()[1] =
          sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);

      simdgroup_matrix<T, 8, 8> g_out = (G * g) * Gt;
      wt_out_0[c * O] = g_out.thread_elements()[0];
@@ -327,27 +394,23 @@ template <typename T,
    wt_out_0 += BC * O;
    wt_out_1 += BC * O;
  }
}

#define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \
  template [[host_name("winograd_conv_2d_weight_transform_" #name \
                       "_bc" #bc)]] [[kernel]] void \
  winograd_conv_2d_weight_transform<itype, bc>( \
      const device itype* wt_in [[buffer(0)]], \
      device itype* wt_out [[buffer(1)]], \
      const constant int& C [[buffer(2)]], \
      const constant int& O [[buffer(3)]], \
      uint tid [[threadgroup_position_in_grid]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]]);

template <typename T, int BC, int WM, int WN, int M = 6, int R = 3>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
winograd_conv_2d_input_transform(
    const device T* inp_in [[buffer(0)]],
    device T* inp_out [[buffer(1)]],
    const constant MLXConvParams<2>& params [[buffer(2)]],
@@ -356,7 +419,6 @@ template <typename T,
    uint3 tgp_per_grid [[threadgroups_per_grid]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]]) {
  (void)lid;

  using WGT = WinogradTransforms<M, R, 8>;
@@ -387,46 +449,48 @@ template <typename T,
  int bw = M * tid.x + kw;

  // Move to the correct input tile
  inp_in += tid.z * params.in_strides[0] + bh * params.in_strides[1] +
      bw * params.in_strides[2];

  // Pre compute strides
  int jump_in[TH][TW];

  for (int h = 0; h < TH; h++) {
    for (int w = 0; w < TW; w++) {
      jump_in[h][w] = h * params.in_strides[1] + w * params.in_strides[2];
    }
  }

  // inp_out is stored interleaved (A x A x tiles x C)
  size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
  size_t tile_id =
      tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
  size_t ohw_0 = sm * 8 + sn;
  size_t ohw_1 = sm * 8 + sn + 1;
  device T* inp_out_0 =
      inp_out + ohw_0 * N_TILES * params.C + tile_id * params.C;
  device T* inp_out_1 =
      inp_out + ohw_1 * N_TILES * params.C + tile_id * params.C;

  // Prepare shared memory
  threadgroup T Is[A][A][BC];

  // Loop over C
  for (int bc = 0; bc < params.C; bc += BC) {
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Read into shared memory
    for (int h = 0; h < TH; h++) {
      for (int w = 0; w < TW; w++) {
        const device T* in_ptr = inp_in + jump_in[h][w];
        for (int c = simd_lane_id; c < BC; c += 32) {
          Is[kh + h][kw + w][c] = in_ptr[c];
        }
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Do transform and store the result
    for (int c = simd_group_id; c < BC; c += N_SIMD_GROUPS) {
      simdgroup_matrix<T, 8, 8> I;
      I.thread_elements()[0] = Is[sm][sn][c];
      I.thread_elements()[1] = Is[sm][sn + 1][c];
@@ -440,28 +504,24 @@ template <typename T,
    inp_out_0 += BC;
    inp_out_1 += BC;
  }
}

#define instantiate_winograd_conv_2d_input_transform(name, itype, bc) \
  template [[host_name("winograd_conv_2d_input_transform_" #name \
                       "_bc" #bc)]] [[kernel]] void \
  winograd_conv_2d_input_transform<itype, bc, 2, 2>( \
      const device itype* inp_in [[buffer(0)]], \
      device itype* inp_out [[buffer(1)]], \
      const constant MLXConvParams<2>& params [[buffer(2)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 tgp_per_grid [[threadgroups_per_grid]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]]);

template <typename T, int BO, int WM, int WN, int M = 6, int R = 3>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
winograd_conv_2d_output_transform(
    const device T* out_in [[buffer(0)]],
    device T* out_out [[buffer(1)]],
    const constant MLXConvParams<2>& params [[buffer(2)]],
@@ -470,7 +530,6 @@ template <typename T,
    uint3 tgp_per_grid [[threadgroups_per_grid]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]]) {
  (void)lid;

  using WGT = WinogradTransforms<M, R, 8>;
@@ -503,57 +562,59 @@ template <typename T,
  int bw = M * tid.x + kw;

  // Move to the correct input tile
  out_out += tid.z * params.out_strides[0] + bh * params.out_strides[1] +
      bw * params.out_strides[2];

  // Pre compute strides
  int jump_in[TH][TW];

  for (int h = 0; h < TH; h++) {
    for (int w = 0; w < TW; w++) {
      bool valid = ((bh + h) < params.oS[0]) && ((bw + w) < params.oS[1]);
      jump_in[h][w] =
          valid ? h * params.out_strides[1] + w * params.out_strides[2] : -1;
    }
  }

  // out_in is stored interleaved (A x A x tiles x O)
  size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
  size_t tile_id =
      tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
  size_t ohw_0 = sm * 8 + sn;
  size_t ohw_1 = sm * 8 + sn + 1;
  const device T* out_in_0 =
      out_in + ohw_0 * N_TILES * params.O + tile_id * params.O;
  const device T* out_in_1 =
      out_in + ohw_1 * N_TILES * params.O + tile_id * params.O;

  // Prepare shared memory
  threadgroup T Os[M][M][BO];

  // Loop over O
  for (int bo = 0; bo < params.O; bo += BO) {
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Do transform and store the result
    for (int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) {
      simdgroup_matrix<T, 8, 8> O_mat;
      O_mat.thread_elements()[0] = out_in_0[c];
      O_mat.thread_elements()[1] = out_in_1[c];

      simdgroup_matrix<T, 8, 8> O_out = (Bt * (O_mat * B));
      if ((sm < M) && (sn < M)) {
        Os[sm][sn][c] = O_out.thread_elements()[0];
      }
      if ((sm < M) && ((sn + 1) < M)) {
        Os[sm][sn + 1][c] = O_out.thread_elements()[1];
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Read out from shared memory
    for (int h = 0; h < TH; h++) {
      for (int w = 0; w < TW; w++) {
        if (jump_in[h][w] >= 0) {
          device T* out_ptr = out_out + jump_in[h][w];
          for (int c = simd_lane_id; c < BO; c += 32) {
            out_ptr[c] = Os[kh + h][kw + w][c];
          }
        }
@@ -564,25 +625,27 @@ template <typename T,
    out_in_0 += BO;
    out_in_1 += BO;
  }
}

#define instantiate_winograd_conv_2d_output_transform(name, itype, bo) \
  template [[host_name("winograd_conv_2d_output_transform_" #name \
                       "_bo" #bo)]] [[kernel]] void \
  winograd_conv_2d_output_transform<itype, bo, 2, 2>( \
      const device itype* out_in [[buffer(0)]], \
      device itype* out_out [[buffer(1)]], \
      const constant MLXConvParams<2>& params [[buffer(2)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 tgp_per_grid [[threadgroups_per_grid]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]]);

// clang-format off
#define instantiate_winograd_conv_2d(name, itype) \
  instantiate_winograd_conv_2d_weight_transform_base(name, itype, 32) \
  instantiate_winograd_conv_2d_input_transform(name, itype, 32) \
  instantiate_winograd_conv_2d_output_transform(name, itype, 32) // clang-format on

// clang-format off
instantiate_winograd_conv_2d(float32, float);
instantiate_winograd_conv_2d(float16, half); // clang-format on


@@ -1,29 +1,29 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"

template <typename T, typename U>
[[kernel]] void copy_s(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  dst[index] = static_cast<U>(src[0]);
}

template <typename T, typename U>
[[kernel]] void copy_v(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  dst[index] = static_cast<U>(src[index]);
}

template <typename T, typename U>
[[kernel]] void copy_g_nd1(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t& src_stride [[buffer(3)]],
    uint index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_1(index, src_stride);
  dst[index] = static_cast<U>(src[src_idx]);
@@ -31,61 +31,64 @@ template <typename T, typename U>
template <typename T, typename U>
[[kernel]] void copy_g_nd2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto src_idx = elem_to_loc_2(index, src_strides);
  int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_g_nd3(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto src_idx = elem_to_loc_3(index, src_strides);
  int64_t dst_idx =
      index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U, int DIM>
[[kernel]] void copy_g_nd(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
  int64_t dst_idx =
      index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_g(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int& ndim [[buffer(5)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
  int64_t dst_idx =
      index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_gg_nd1(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t& src_stride [[buffer(3)]],
    constant const int64_t& dst_stride [[buffer(4)]],
    uint index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_1(index, src_stride);
  auto dst_idx = elem_to_loc_1(index, dst_stride);
@@ -94,10 +97,10 @@ template <typename T, typename U>
template <typename T, typename U>
[[kernel]] void copy_gg_nd2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint2 index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_2(index, src_strides);
  auto dst_idx = elem_to_loc_2(index, dst_strides);
@@ -106,10 +109,10 @@ template <typename T, typename U>
template <typename T, typename U>
[[kernel]] void copy_gg_nd3(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint3 index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_3(index, src_strides);
  auto dst_idx = elem_to_loc_3(index, dst_strides);
@@ -118,11 +121,11 @@ template <typename T, typename U>
template <typename T, typename U, int DIM>
[[kernel]] void copy_gg_nd(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint3 index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
  auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
@@ -131,128 +134,122 @@ template <typename T, typename U, int DIM>
template <typename T, typename U>
[[kernel]] void copy_gg(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    constant const int& ndim [[buffer(5)]],
    uint3 index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
  auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

#define instantiate_copy(name, itype, otype, ctype) \
  template [[host_name(name)]] [[kernel]] void copy_##ctype<itype, otype>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      uint index [[thread_position_in_grid]]);

#define instantiate_copy_g_dim(name, itype, otype, dims) \
  template [[host_name(name "_" #dims)]] [[kernel]] void \
  copy_g_nd<itype, otype, dims>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int* src_shape [[buffer(2)]], \
      constant const int64_t* src_strides [[buffer(3)]], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
  template [[host_name("g" name "_" #dims)]] [[kernel]] void \
  copy_gg_nd<itype, otype, dims>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int* src_shape [[buffer(2)]], \
      constant const int64_t* src_strides [[buffer(3)]], \
      constant const int64_t* dst_strides [[buffer(4)]], \
      uint3 index [[thread_position_in_grid]]);

#define instantiate_copy_g_nd(name, itype, otype) \
  template [[host_name(name "_1")]] [[kernel]] void copy_g_nd1<itype, otype>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int64_t& src_stride [[buffer(3)]], \
      uint index [[thread_position_in_grid]]); \
  template [[host_name(name "_2")]] [[kernel]] void copy_g_nd2<itype, otype>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int64_t* src_strides [[buffer(3)]], \
      uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]); \
  template [[host_name(name "_3")]] [[kernel]] void copy_g_nd3<itype, otype>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int64_t* src_strides [[buffer(3)]], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
  template [[host_name("g" name "_1")]] [[kernel]] void \
  copy_gg_nd1<itype, otype>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int64_t& src_stride [[buffer(3)]], \
      constant const int64_t& dst_stride [[buffer(4)]], \
      uint index [[thread_position_in_grid]]); \
  template [[host_name("g" name "_2")]] [[kernel]] void \
  copy_gg_nd2<itype, otype>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int64_t* src_strides [[buffer(3)]], \
      constant const int64_t* dst_strides [[buffer(4)]], \
      uint2 index [[thread_position_in_grid]]); \
  template [[host_name("g" name "_3")]] [[kernel]] void \
  copy_gg_nd3<itype, otype>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int64_t* src_strides [[buffer(3)]], \
      constant const int64_t* dst_strides [[buffer(4)]], \
      uint3 index [[thread_position_in_grid]]); \
  instantiate_copy_g_dim(name, itype, otype, 4) \
  instantiate_copy_g_dim(name, itype, otype, 5)

#define instantiate_copy_g(name, itype, otype) \
  template [[host_name(name)]] [[kernel]] void copy_g<itype, otype>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int* src_shape [[buffer(2)]], \
      constant const int64_t* src_strides [[buffer(3)]], \
      constant const int& ndim [[buffer(5)]], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
  template [[host_name("g" name)]] [[kernel]] void copy_gg<itype, otype>( \
      device const itype* src [[buffer(0)]], \
      device otype* dst [[buffer(1)]], \
      constant const int* src_shape [[buffer(2)]], \
      constant const int64_t* src_strides [[buffer(3)]], \
      constant const int64_t* dst_strides [[buffer(4)]], \
      constant const int& ndim [[buffer(5)]], \
      uint3 index [[thread_position_in_grid]]);

// clang-format off
#define instantiate_copy_all(tname, itype, otype) \
  instantiate_copy("scopy" #tname, itype, otype, s) \
  instantiate_copy("vcopy" #tname, itype, otype, v) \
  instantiate_copy_g("gcopy" #tname, itype, otype) \
  instantiate_copy_g_nd("gcopy" #tname, itype, otype) // clang-format on

// clang-format off
#define instantiate_copy_itype(itname, itype) \
  instantiate_copy_all(itname ##bool_, itype, bool) \
  instantiate_copy_all(itname ##uint8, itype, uint8_t) \
  instantiate_copy_all(itname ##uint16, itype, uint16_t) \
  instantiate_copy_all(itname ##uint32, itype, uint32_t) \
  instantiate_copy_all(itname ##uint64, itype, uint64_t) \
  instantiate_copy_all(itname ##int8, itype, int8_t) \
  instantiate_copy_all(itname ##int16, itype, int16_t) \
  instantiate_copy_all(itname ##int32, itype, int32_t) \
  instantiate_copy_all(itname ##int64, itype, int64_t) \
  instantiate_copy_all(itname ##float16, itype, half) \
  instantiate_copy_all(itname ##float32, itype, float) \
  instantiate_copy_all(itname ##bfloat16, itype, bfloat16_t) \
  instantiate_copy_all(itname ##complex64, itype, complex64_t)
@@ -268,4 +265,4 @@ instantiate_copy_itype(int64, int64_t)
instantiate_copy_itype(float16, half)
instantiate_copy_itype(float32, float)
instantiate_copy_itype(bfloat16, bfloat16_t)
instantiate_copy_itype(complex64, complex64_t) // clang-format on


@@ -14,3 +14,5 @@ static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
static MTL_CONST constexpr int REDUCE_N_READS = 16;
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096;
static MTL_CONST constexpr int RMS_N_READS = 4;
static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;


@@ -0,0 +1,89 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <metal_math>
// Original license copied below:
// Copyright (c) 2015-2023 Norbert Juffa
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
//
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
/* Compute exponential base e minus 1. Maximum ulp error = 0.997458
i = rint(a/log(2)), f = a-i*log(2). Then expm1(a) = 2**i * (expm1(f)+1) - 1.
Compute r = expm1(f). Then expm1(a)= 2 * (0.5 * 2**i * r + 0.5 * 2**i - 0.5).
With t = 0.5*2**i, expm1(a) = 2*(r * t + t-0.5). However, for best accuracy,
when i == 1, expm1(a)= 2*(r + 0.5), and when i == 0, expm1(a) = r.
NOTE: Scale factor b is only applied if i < 0 or i > 1 (should be power of 2)
*/
float expm1f_scaled_unchecked(float a, float b) {
  float f, j, r, s, t, u, v, x, y;
  int i;

  // exp(a) = 2**i * exp(f); i = rintf (a / log(2))
  j = fma(1.442695f, a, 12582912.f); // 0x1.715476p0, 0x1.8p23
  j = j - 12582912.0f; // 0x1.8p23
  i = (int)j;
  f = fma(j, -6.93145752e-1f, a);

  // approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2]
  s = f * f;
  if (a == 0.0f)
    s = a; // ensure -0 is passed through
  // err = 0.997458  ulp1 = 11081805
  r = 1.97350979e-4f; // 0x1.9de000p-13
  r = fma(r, f, 1.39309070e-3f); // 0x1.6d30bcp-10
  r = fma(r, f, 8.33343994e-3f); // 0x1.1111f6p-7
  r = fma(r, f, 4.16668020e-2f); // 0x1.55559ep-5
  r = fma(r, f, 1.66666716e-1f); // 0x1.55555cp-3
  r = fma(r, f, 4.99999970e-1f); // 0x1.fffffep-2
  u = (j == 1) ? (f + 0.5f) : f;
  v = fma(r, s, u);
  s = 0.5f * b;
  t = ldexp(s, i);
  y = t - s;
  x = (t - y) - s; // double-float canonicalization of difference
  r = fma(v, t, x) + y;
  r = r + r;
  if (j == 0)
    r = v;
  if (j == 1)
    r = v + v;
  return r;
}

/* Compute exponential base e minus 1. max ulp err = 0.99746 */
float expm1f(float a) {
  float r;

  r = expm1f_scaled_unchecked(a, 1.0f);
  /* handle severe overflow and underflow */
  if (abs(a - 1.0f) > 88.0f) {
    r = fma(r, r, -1.0f);
  }
  return r;
}
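Restating the range reduction from the header comment as a worked equation (the symbols i, f, r, and t are the same quantities the code computes in `i`, `f`, `r`, and `t`):

\[ i = \mathrm{rint}\!\left(\frac{a}{\ln 2}\right), \qquad f = a - i\ln 2, \]
\[ \mathrm{expm1}(a) = 2^{i}\bigl(\mathrm{expm1}(f) + 1\bigr) - 1 = 2\bigl(r\,t + (t - \tfrac{1}{2})\bigr), \qquad r = \mathrm{expm1}(f),\ t = \tfrac{1}{2}\,2^{i}. \]

The `j == 0` and `j == 1` branches handle the two cases called out in the comment where the scaled form is less accurate: for i = 0 the result is simply r, and for i = 1 it is 2(r + 1/2).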
