Compare commits


54 Commits

Author SHA1 Message Date
Awni Hannun
d07e295c62 bumpity bump (#987) 2024-04-11 12:48:52 -07:00
Angelos Katharopoulos
dce4bd74a4 Add ArrayDesc destructor to avoid possible stack overflow (#982) 2024-04-11 11:37:02 -07:00
Nripesh Niketan
ffff671273 Update pre-commit hooks (#984) 2024-04-11 07:27:53 -07:00
Awni Hannun
12d4507ee3 Explicit barriers with concurrent dispatch (#977) 2024-04-10 21:45:31 -07:00
Awni Hannun
8580d997ff Try a stack-based DFS for eval (#980)
* rebase

* nit

* fix eval in vmap
2024-04-10 17:05:13 -07:00
Shiyu
061cf9a4ce Upsample with bicubic interpolation (#967) 2024-04-10 15:47:22 -07:00
Awni Hannun
99abb9eff4 Async eval (#972) 2024-04-09 18:34:00 -07:00
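A minimal usage sketch of the new asynchronous evaluation (the binding name mx.async_eval is assumed from the commit title):

    import mlx.core as mx

    a = mx.random.uniform(shape=(1024, 1024))
    b = a @ a
    mx.async_eval(b)  # start evaluating b without blocking
    # ... overlap other work here ...
    mx.eval(b)        # block until b is ready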
Luca Arnaboldi
fffe072028 Implementation of mlx.random.multivariate_normal (#502) (#877)
* Implementation of mlx.random.multivariate_normal (#502)

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Updated typo in docstring

* Restricted multivariate_normal to  float32

* Generic mean and variance shapes

* Review edits

* Update mlx/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Test for ndim of mean and cov

* nits

* smaller size for test

* fix broadcasted sampling

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-09 13:50:12 -07:00
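A usage sketch (the signature is assumed to follow numpy's multivariate_normal; float32 only, per this change):

    import mlx.core as mx

    mean = mx.zeros((3,))
    cov = mx.eye(3)
    # assumed signature: multivariate_normal(mean, cov, shape=...)
    # batch dimensions of mean/cov broadcast with shape
    samples = mx.random.multivariate_normal(mean, cov, shape=(100,))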
Abe Leininger
a1a31eed27 Add mx.meshgrid (#961) 2024-04-09 11:43:08 -07:00
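A usage sketch (numpy-style 'xy' indexing is assumed as the default):

    import mlx.core as mx

    x = mx.arange(3)
    y = mx.arange(2)
    X, Y = mx.meshgrid(x, y)  # with 'xy' indexing, X and Y have shape (2, 3)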
Awni Hannun
ae812350f9 use string (#976) 2024-04-09 11:22:00 -07:00
Awni Hannun
b63ef10a7f Extensions (#962)
* start to fix extensions

* mostly fixed extensions

* fix extension build

* couple more nits
2024-04-09 08:50:36 -07:00
Awni Hannun
42afe27e12 std and expm1 (#973)
* std and expm1

* actually add expm1

* fix linux

* fix vjp

* relax tol for linux test

* Add it to the compilable primitives

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-04-08 14:26:01 -07:00
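A usage sketch of the two new ops:

    import mlx.core as mx

    a = mx.array([1e-4, 1e-3, 1e-2])
    mx.std(a)    # standard deviation over all elements
    mx.expm1(a)  # exp(a) - 1, accurate for small a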
Awni Hannun
76e63212ff Enable bfloat scan (#974)
* enable bfloat scan
* fix tests
2024-04-08 12:29:19 -07:00
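A usage sketch (scans such as cumsum now accept bfloat16 input):

    import mlx.core as mx

    a = mx.arange(8).astype(mx.bfloat16)
    mx.cumsum(a)  # previously unsupported for bfloat16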
Awni Hannun
aac2f9fb61 Improve profiling with gpu tracing (#969)
* improve profiling with gpu tracing

* fix for linux

* nit

* doc fix

* fix example
2024-04-07 21:47:43 -07:00
Awni Hannun
bddf23f175 patch bump (#956) 2024-04-04 11:56:37 -07:00
Awni Hannun
039da779d1 No quant reshape (#957)
* precise option on cpu

* remove print

* remove reshape in quant matmul

* no quant reshape
2024-04-04 11:52:12 -07:00
Awni Hannun
d88d2124b5 segfault layer norm grad (#955) 2024-04-04 10:59:15 -07:00
Awni Hannun
e142aaf8a1 Option for precise softmax (#953)
* precise softmax

* Add an equivalency check

* Make the threadgroup memory definition fixed

* precise cpu softmax

* precise option on cpu

* remove print

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-04-04 08:32:35 -07:00
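A usage sketch (assuming the option is exposed as a precise keyword on mx.softmax):

    import mlx.core as mx

    x = mx.random.uniform(shape=(16, 1024)).astype(mx.float16)
    # 'precise' keyword assumed from the commit title; accumulates in
    # higher precision internally for numerical stability
    y = mx.softmax(x, axis=-1, precise=True)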
AmirHossein_Razlighi
0caf35f4b8 Better exceptions in case of invalid operations on mlx.core.array (#910) (#926)
* Nicer exceptions for ops on non-arrays
2024-04-02 21:11:24 -07:00
Angelos Katharopoulos
3fc993f82d Properly handle negative axes in python vmap (#944) 2024-04-02 18:07:23 -07:00
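A usage sketch of vmap with a negative axis:

    import mlx.core as mx

    f = mx.vmap(lambda row: 2 * row, in_axes=-1, out_axes=-1)
    y = f(mx.ones((4, 5)))  # maps over the last axis; y.shape == (4, 5)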
Awni Hannun
741eb28443 fix a couple bugs (#952) 2024-04-02 12:07:41 -07:00
Angelos Katharopoulos
1a87dc5ea8 Fix compile fusion for multi-output edge cases (#950)
* Fix compile fusion for multi-output edge cases

* Add a test for multi-output compile
2024-04-02 08:42:31 -07:00
Awni Hannun
2427fa171e Fix cpu compile (#934)
* fix one cpu bug, test for another

* format hooks

* simplify contiguity check for cpu compile

* fix

* add back donation

* comment
2024-04-01 17:37:12 -07:00
Jagrit Digani
639e06e1f3 Indexing bug fix (#947)
* Fix axes accounting

* Add tests
2024-04-01 12:18:50 -07:00
Angelos Katharopoulos
02fedbf1da Fix array initialization from list (#942)
* Fix array initialization from list

* Change the error message in the test
2024-04-01 06:27:52 -07:00
Angelos Katharopoulos
110d9b149d Layer norm grad fix donation bug (#941)
* add layer norm grad test

* Fix donation bug in layernorm vjp

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-01 06:15:50 -07:00
Angelos Katharopoulos
9cbff5ec1d Fix typo in qmm check (#940) 2024-03-31 19:15:44 -07:00
Suvan Kumar
433c0206b0 Update saving_and_loading.rst (#929)
Update saving / load docs.
2024-03-30 14:30:06 -07:00
Awni Hannun
8915901966 Donation bug (#933)
* donation

* buf

* fix bug in softmax

* comment

* remove print
2024-03-30 10:08:54 -07:00
AmirHossein_Razlighi
f48bc496c7 Comparing python objects (such as list/tuple) with mlx.core.array (#920)
* add implicit conversion of list to array for equality constraint

* add tests for array equality

* add test for tuple and array equality

* return False if __eq__ arg is list or tuple

* write tests for equality

* update the rule of comparison for __ge__/__gt__/__lt__/__le__

* add a helper function for detecting mlx.core.array

* return true in case of inequality

* debug minor issue regarding detecting mlx array

* add tests for inequality comparisons

* add name for contribution

* reformat files using pre-commit

* update tests for float

* update tests for inequality

* raise exception in case of invalid comparisons

* use isinstance instead of string comparison

* replace "is_convirtable_to_array" with previous logic

* remove throwing exceptions for other operations

* just a comment

* minor changes for efficiency

* optimize a utils function

* change the function name

* Update ACKNOWLEDGMENTS.md

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-29 06:52:30 -07:00
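A usage sketch of the new comparison behavior (list/tuple operands are converted to arrays; the exact semantics are per the discussion in the commit above):

    import mlx.core as mx

    a = mx.array([1, 2, 3])
    (a == [1, 2, 3])       # element-wise comparison against a list
    (a < (2, 3, 4)).all()  # inequalities work with tuples too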
Cheng
913b19329c Add missing && when forwarding args (#925)
Without the &&, args would be copied and perfect forwarding wouldn't work.
2024-03-29 06:48:29 -07:00
Awni Hannun
d8cb3128f6 bump (#924)
* bump

* fix version
2024-03-28 16:14:55 -07:00
Angelos Katharopoulos
5f9ba3019f Fix qmm_t for unaligned cases (#923) 2024-03-28 15:34:57 -07:00
Cheng
46caf0bef0 Remove unnecessary string copies (#891)
1. Use string_view instead of string when there is no need for copy.
2. Otherwise move string when possible.
2024-03-28 13:14:59 -07:00
Jack Mousseau
45f636e759 Add Metal debug option and capture functions (#707)
* Add Metal debug option and capture functions

* Add brief Metal debugger documentation

* doc nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-28 09:40:31 -07:00
Cheng
a7b404ff53 Use uintptr_t instead of size_t to store function id (#916)
Also does some small cleanup of the compile cache code.
2024-03-28 06:37:59 -07:00
Angelos Katharopoulos
c4fd0e5ede Fixes #918 bug in compile_tests (#919) 2024-03-27 22:37:37 -07:00
Cheng
bab5386306 Make ops aware of rvalues: astype/as_strided/copy/full (#895)
When composing transforms, lots of temporary arrays are created and passed
to the next primitive; by making ops accept args by value we can avoid many
copies of those temporaries.
2024-03-27 22:35:55 -07:00
Angelos Katharopoulos
aca7584635 Fix OOB read in qmv when non-divisible by blocksize (#917) 2024-03-27 22:18:35 -07:00
AmirHossein_Razlighi
d611251502 Support Chaining for some of functionalities of nn.Module (#885) (#897)
* add chaining support for some of the functionalities of "nn.Module"

* reformat

* change the return types

* remove return types

* add return type with forward referencing

* add tests for chaining

* add name to contributors

* Update python/mlx/nn/layers/base.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/nn/layers/base.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* update docstring

* update docstrings

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-27 19:58:29 -07:00
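A usage sketch (assuming freeze and train are among the methods that now return the module):

    import mlx.nn as nn

    # calls can be chained because each method returns the module itself
    model = nn.Linear(4, 4).freeze().train(False)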
Cheng
f30b659291 Make MLX build on x64 macOS (#901)
The arm64 MacBook Pros are heavy and I usually carry my Intel one when
mobile; it would be nice if I could play with MLX on it.

To build with x64, user must pass `MLX_ENABLE_X64_MAC` to cmake:
CMAKE_ARGS='-DMLX_ENABLE_X64_MAC=ON' python setup.py
2024-03-27 06:14:29 -07:00
Cheng
90dfa43ff1 Don't use make_unique to create shared_ptr (#902)
The code compiled because shared_ptr's constructor actually accepts
unique_ptr.
2024-03-27 06:13:29 -07:00
Awni Hannun
dc175f08d3 Fix race in multi-stream eval (#911)
* maybe fix race

* comment
2024-03-26 16:36:36 -07:00
Angelos Katharopoulos
29221fa238 Implement vjps for some primitives in the fast namespace (#883)
* Implement rope vjp in terms of rope
* RMSNormVJP primitive and kernel
* Add LayerNormVJP primitive and kernel
2024-03-26 16:35:34 -07:00
Cheng
a789685c63 Remove duplicate defines of StreamOrDevice and is_big_endian (#892) 2024-03-26 15:15:11 -07:00
Jagrit Digani
240d10699c Implement negative padding in conv with slicing (#907)
* Implement negative padding with slicing

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni@apple.com>

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-26 14:59:19 -07:00
Jagrit Digani
925014b661 Fix multiblock sort limits (#906)
* Fix multiblock sort limits

* Fix metal validation error
2024-03-26 14:00:00 -07:00
Abdussamet Türker
5611e1a95e Fix unsqueeze with None (#899)
* Fix unsqueeze with None

* Clean unnecessary files
2024-03-26 13:59:44 -07:00
Awni Hannun
570f2bf29e pick up previously set attributes (#905) 2024-03-26 11:19:59 -07:00
Angelos Katharopoulos
9948eddf11 Fix nan and improve speed for qvm (#903) 2024-03-26 10:41:45 -07:00
Luca Arnaboldi
a3ee03da01 Fixing random.normal for half-precision dtype #642 (#904)
* Fixing random.normal for half-precision dtype #642

* Update python/tests/test_random.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-26 09:58:27 -07:00
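A usage sketch of sampling directly in half precision:

    import mlx.core as mx

    x = mx.random.normal(shape=(2, 2), dtype=mx.float16)  # fixed for half-precision dtypes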
Cheng
28fcd2b519 Add missing && when forwarding args (#894)
Without the &&, args would be copied and perfect forwarding wouldn't work.

Also add template utils to make sure the function only forwards array
and not vector<array>.
2024-03-25 14:55:54 -07:00
Jack Mousseau
8e686764ac Ensure shape dimensions are within supported integer range (#566) (#704)
* Ensure shape dimensions are within supported integer range (#566)

* fix build

* fix rebase bug

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-25 13:29:45 -07:00
Daniel Strobusch
479051ce1c add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)
* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate

The numeric type hierarchy and issubdtype are compatible with the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)).

Closes #285.

* nits in docs

* unify type category checking

* nits in docs

* nits in docs

* more docs nits

* fix callable type

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-25 12:32:59 -07:00
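A usage sketch of the new Module.set_dtype method (the optional predicate and its default behavior are assumptions based on the commit title):

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Linear(4, 4)
    # assumed default: only floating-point parameters are cast
    model.set_dtype(mx.float16)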
141 changed files with 5518 additions and 1549 deletions

View File

@@ -1,11 +1,11 @@
repos:
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v17.0.6
rev: v18.1.3
hooks:
- id: clang-format
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.2.0
rev: 24.3.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort

View File

@@ -15,6 +15,8 @@ MLX was developed with contributions from the following individuals:
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non-array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
</a>

View File

@@ -15,31 +15,33 @@ option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.8.1)
set(MLX_VERSION 0.10.0)
endif()
# --------------------- Processor tests -------------------------
message(STATUS "Building MLX for ${CMAKE_HOST_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
message(STATUS "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
set(MLX_BUILD_ARM OFF)
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64" AND ${CMAKE_HOST_APPLE})
message(FATAL_ERROR
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, check the build"
" documentation for possible fixes: "
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
elseif (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64")
message(WARNING
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, "
" make sure you are building for arm64.")
elseif(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "arm64")
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
if(NOT MLX_ENABLE_X64_MAC)
message(FATAL_ERROR
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, check the build"
" documentation for possible fixes: "
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
else()
message(WARNING "Building for x86_64 arch is not officially supported.")
endif()
set(MLX_BUILD_METAL OFF)
elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
set(MLX_BUILD_ARM ON)
endif()
@@ -64,8 +66,14 @@ endif()
if (MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF)
set(MLX_METAL_DEBUG OFF)
elseif (MLX_BUILD_METAL)
message(STATUS "Building METAL sources")
if (MLX_METAL_DEBUG)
add_compile_definitions(MLX_METAL_DEBUG)
endif()
# Throw an error if xcrun not found
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_VERSION
@@ -108,7 +116,27 @@ if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
else()
message(STATUS "Accelerate or arm neon not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
#set(BLA_VENDOR Generic)
if(${CMAKE_HOST_APPLE})
# The blas shipped in macOS SDK is not supported, search homebrew for
# openblas instead.
set(BLA_VENDOR OpenBLAS)
set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
endif()
# Search and link with lapack.
find_package(LAPACK REQUIRED)
if (NOT LAPACK_FOUND)
message(FATAL_ERROR "Must have LAPACK installed")
endif()
find_path(LAPACK_INCLUDE_DIRS lapacke.h
/usr/include
/usr/local/include
/usr/local/opt/openblas/include)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx ${LAPACK_LIBRARIES})
# List blas after lapack otherwise we may accidentally include an old version
# of lapack.h from the include dirs of blas.
find_package(BLAS REQUIRED)
if (NOT BLAS_FOUND)
message(FATAL_ERROR "Must have BLAS installed")
@@ -122,17 +150,6 @@ else()
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx ${BLAS_LIBRARIES})
find_package(LAPACK REQUIRED)
if (NOT LAPACK_FOUND)
message(FATAL_ERROR "Must have LAPACK installed")
endif()
find_path(LAPACK_INCLUDE_DIRS lapacke.h
/usr/include
/usr/local/include)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx ${LAPACK_LIBRARIES})
endif()
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)

View File

@@ -17,14 +17,13 @@
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
<< std::endl;
#define TIMEM(MSG, FUNC, ...) \
std::cout << "Timing " \
<< "(" << MSG << ") " << #FUNC << " ... " << std::flush \
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
<< std::endl;
#define TIMEM(MSG, FUNC, ...) \
std::cout << "Timing " << "(" << MSG << ") " << #FUNC << " ... " \
<< std::flush << std::setprecision(5) \
<< time_fn(FUNC, ##__VA_ARGS__) << " msec" << std::endl;
template <typename F, typename... Args>
double time_fn(F fn, Args... args) {
double time_fn(F fn, Args&&... args) {
// warmup
for (int i = 0; i < 5; ++i) {
eval(fn(std::forward<Args>(args)...));

View File

@@ -0,0 +1,41 @@
# Copyright © 2023-2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
def layer_norm(x, w, b, eps):
ot = x.dtype
x = x.astype(mx.float32)
mu = mx.mean(x, -1, keepdims=True)
v = mx.var(x, -1, keepdims=True)
return (x - mu) * mx.rsqrt(v + eps) * w + b
def time_layer_norm():
f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0, 1, 2))
g2 = mx.grad(f2, argnums=(0, 1, 2))
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
mx.eval(x, w, b, y)
def layer_norm_loop(g, x, w, b):
gx, gw, gb = x, w, b
for _ in range(32):
gx, gw, gb = g(gx, gw, gb, y)
return gx, gw, gb
time_fn(layer_norm_loop, g1, x, w, b)
time_fn(layer_norm_loop, g2, x, w, b)
time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
if __name__ == "__main__":
time_layer_norm()

View File

@@ -0,0 +1,39 @@
# Copyright © 2023-2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
def rms_norm(x, w, eps):
ot = x.dtype
x = x.astype(mx.float32)
n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
return (x * n).astype(ot) * w
def time_rms_norm():
f1 = lambda x, w, y: (rms_norm(x, w, 1e-5) * y).sum()
f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0, 1))
g2 = mx.grad(f2, argnums=(0, 1))
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
mx.eval(x, w, y)
def rms_norm_loop(g, x, w):
gx, gw = x, w
for _ in range(32):
gx, gw = g(gx, gw, y)
return gx, gw
time_fn(rms_norm_loop, g1, x, w)
time_fn(rms_norm_loop, g2, x, w)
time_fn(rms_norm_loop, mx.compile(g1), x, w)
time_fn(rms_norm_loop, mx.compile(g2), x, w)
if __name__ == "__main__":
time_rms_norm()

Binary file added (preview not shown; size: 1.2 MiB)

Binary file added (preview not shown; size: 746 KiB)

View File

@@ -1,24 +1,16 @@
Developer Documentation
=======================
MLX provides an open and flexible backend to which users may add operations
and specialized implementations without much hassle. While the library supplies
efficient operations that can be used and composed for any number of
applications, there may arise cases where new functionalities or highly
optimized implementations are needed. For such cases, you may design and
implement your own operations that link to and build on top of :mod:`mlx.core`.
We will introduce the inner-workings of MLX and go over a simple example to
learn the steps involved in adding new operations to MLX with your own CPU
and GPU implementations.
You can extend MLX with custom operations on the CPU or GPU. This guide
explains how to do that with a simple example.
Introducing the Example
-----------------------
Let's say that you would like an operation that takes in two arrays,
``x`` and ``y``, scales them both by some coefficients ``alpha`` and ``beta``
respectively, and then adds them together to get the result
``z = alpha * x + beta * y``. Well, you can very easily do that by just
writing out a function as follows:
Let's say you would like an operation that takes in two arrays, ``x`` and
``y``, scales them both by coefficients ``alpha`` and ``beta`` respectively,
and then adds them together to get the result ``z = alpha * x + beta * y``.
You can do that in MLX directly:
.. code-block:: python
@@ -27,44 +19,35 @@ writing out a function as follows:
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
return alpha * x + beta * y
This function performs that operation while leaving the implementations and
differentiation to MLX.
This function performs that operation while leaving the implementation and
function transformations to MLX.
However, you work with vector math libraries often and realize that the
``axpby`` routine defines the same operation ``Y = (alpha * X) + (beta * Y)``.
You would really like the part of your applications that does this operation
on the CPU to be very fast - so you decide that you want it to rely on the
``axpby`` routine provided by the Accelerate_ framework. Continuing to impose
our assumptions on to you, let's also assume that you want to learn how to add
your own implementation for the gradients of your new operation while going
over the ins-and-outs of the MLX framework.
However you may need to customize the underlying implementation, perhaps to
make it faster or for custom differentiation. In this tutorial we will go
through adding custom extensions. It will cover:
Well, what a coincidence! You are in the right place. Over the course of this
example, we will learn:
* The structure of the MLX library from the frontend API to the backend implementations.
* How to implement your own CPU backend that redirects to Accelerate_ when appropriate (and a fallback if needed).
* How to implement your own GPU implementation using metal.
* How to add your own ``vjp`` and ``jvp``.
* How to build your implementations, link them to MLX, and bind them to python.
* The structure of the MLX library.
* Implementing a CPU operation that redirects to Accelerate_ when appropriate.
* Implementing a GPU operation using metal.
* Adding the ``vjp`` and ``jvp`` function transformation.
* Building a custom extension and binding it to python.
Operations and Primitives
-------------------------
In one sentence, operations in MLX build the computation graph, and primitives
provide the rules for evaluation and transformations of said graph. Let's start
by discussing operations in more detail.
Operations in MLX build the computation graph. Primitives provide the rules for
evaluating and transforming the graph. Let's start by discussing operations in
more detail.
Operations
^^^^^^^^^^^
Operations are the frontend functions that operate on arrays. They are defined
in the C++ API (:ref:`cpp_ops`) and then we provide bindings to these
operations in the Python API (:ref:`ops`).
Operations are the front-end functions that operate on arrays. They are defined
in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.
We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and ``y``,
and two scalars, ``alpha`` and ``beta``. This is how we would define it in the
C++ API:
We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and
``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
C++:
.. code-block:: C++
@@ -83,10 +66,7 @@ C++ API:
StreamOrDevice s = {} // Stream on which to schedule the operation
);
This operation itself can call other operations within it if needed. So, the
simplest way to go about implementing this operation would be do so in terms
of existing operations.
The simplest way to implement this operation is in terms of existing operations:
.. code-block:: C++
@@ -100,25 +80,23 @@ of existing operations.
// Scale x and y on the provided stream
auto ax = multiply(array(alpha), x, s);
auto by = multiply(array(beta), y, s);
// Add and return
return add(ax, by, s);
}
However, as we discussed earlier, this is not our goal. The operations themselves
do not contain the implementations that act on the data, nor do they contain the
rules of transformations. Rather, they are an easy to use interface that build
on top of the building blocks we call :class:`Primitive`.
The operations themselves do not contain the implementations that act on the
data, nor do they contain the rules of transformations. Rather, they are an
easy to use interface that use :class:`Primitive` building blocks.
Primitives
^^^^^^^^^^^
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create an output given a set of input :class:`array` . Further,
a :class:`Primitive` is a class that contains rules on how it is evaluated
on the CPU or GPU, and how it acts under transformations such as ``vjp`` and
``jvp``. These words on their own can be a bit abstract, so let's take a step
back and go to our example to give ourselves a more concrete image.
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create output arrays given input arrays. Further, a
:class:`Primitive` has methods to run on the CPU or GPU and for function
transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
more concrete:
.. code-block:: C++
@@ -134,11 +112,15 @@ back and go to our example to give ourselves a more concrete image.
* To avoid unnecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
void eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) override;
void eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) override;
/** The Jacobian-vector product. */
array jvp(
std::vector<array> jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) override;
@@ -147,7 +129,8 @@ back and go to our example to give ourselves a more concrete image.
std::vector<array> vjp(
const std::vector<array>& primals,
const array& cotan,
const std::vector<int>& argnums) override;
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
/**
* The primitive must know how to vectorize itself across
@@ -155,7 +138,7 @@ back and go to our example to give ourselves a more concrete image.
* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
*/
std::pair<array, int> vmap(
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
@@ -175,22 +158,22 @@ back and go to our example to give ourselves a more concrete image.
void eval(const std::vector<array>& inputs, array& out);
};
The :class:`Axpby` class derives from the base :class:`Primitive` class and
follows the above demonstrated interface. :class:`Axpby` treats ``alpha`` and
``beta`` as parameters. It then provides implementations of how the array ``out``
is produced given ``inputs`` through :meth:`Axpby::eval_cpu` and
:meth:`Axpby::eval_gpu`. Further, it provides rules of transformations in
:meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and :meth:`Axpby::vmap`.
The :class:`Axpby` class derives from the base :class:`Primitive` class. The
:class:`Axpby` treats ``alpha`` and ``beta`` as parameters. It then provides
implementations of how the output array is produced given the inputs through
:meth:`Axpby::eval_cpu` and :meth:`Axpby::eval_gpu`. It also provides rules
of transformations in :meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and
:meth:`Axpby::vmap`.
Using the Primitives
^^^^^^^^^^^^^^^^^^^^^
Using the Primitive
^^^^^^^^^^^^^^^^^^^
Operations can use this :class:`Primitive` to add a new :class:`array` to
the computation graph. An :class:`array` can be constructed by providing its
data type, shape, the :class:`Primitive` that computes it, and the
:class:`array` inputs that are passed to the primitive.
Operations can use this :class:`Primitive` to add a new :class:`array` to the
computation graph. An :class:`array` can be constructed by providing its data
type, shape, the :class:`Primitive` that computes it, and the :class:`array`
inputs that are passed to the primitive.
Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
.. code-block:: C++
@@ -223,7 +206,7 @@ Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
/* const std::vector<int>& shape = */ out_shape,
/* Dtype dtype = */ out_dtype,
/* std::unique_ptr<Primitive> primitive = */
std::make_unique<Axpby>(to_stream(s), alpha, beta),
std::make_shared<Axpby>(to_stream(s), alpha, beta),
/* const std::vector<array>& inputs = */ broadcasted_inputs);
}
@@ -238,27 +221,26 @@ This operation now handles the following:
Implementing the Primitive
--------------------------
No computation happens when we call the operation alone. In effect, the
operation only builds the computation graph. When we evaluate the output
array, MLX schedules the execution of the computation graph, and calls
:meth:`Axpby::eval_cpu` or :meth:`Axpby::eval_gpu` depending on the
stream/device specified by the user.
No computation happens when we call the operation alone. The operation only
builds the computation graph. When we evaluate the output array, MLX schedules
the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
:meth:`Axpby::eval_gpu` depending on the stream/device specified by the user.
.. warning::
When :meth:`Primitive::eval_cpu` or :meth:`Primitive::eval_gpu` are called,
no memory has been allocated for the output array. It falls on the implementation
of these functions to allocate memory as needed
of these functions to allocate memory as needed.
Implementing the CPU Backend
Implementing the CPU Back-end
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Let's start by trying to implement a naive and generic version of
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
:class:`Axpby` earlier called :meth:`Axpby::eval`.
Let's start by implementing a naive and generic version of
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
:class:`Axpby` earlier called :meth:`Axpby::eval`.
Our naive method will go over each element of the output array, find the
corresponding input elements of ``x`` and ``y`` and perform the operation
pointwise. This is captured in the templated function :meth:`axpby_impl`.
Our naive method will go over each element of the output array, find the
corresponding input elements of ``x`` and ``y`` and perform the operation
point-wise. This is captured in the templated function :meth:`axpby_impl`.
.. code-block:: C++
@@ -296,19 +278,19 @@ pointwise. This is captured in the templated function :meth:`axpby_impl`.
}
}
Now, we would like our implementation to be able to do this pointwise operation
for all incoming floating point arrays. Accordingly, we add dispatches for
``float32``, ``float16``, ``bfloat16`` and ``complex64``. We throw an error
if we encounter an unexpected type.
Our implementation should work for all incoming floating point arrays.
Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
``complex64``. We throw an error if we encounter an unexpected type.
.. code-block:: C++
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(const std::vector<array>& inputs, array& out) {
// Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
void Axpby::eval(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Dispatch to the correct dtype
if (out.dtype() == float32) {
@@ -321,28 +303,26 @@ if we encounter an unexpected type.
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
} else {
throw std::runtime_error(
"Axpby is only supported for floating point types.");
"[Axpby] Only supports floating point types.");
}
}
We have a fallback implementation! Now, to do what we are really here to do.
Remember we wanted to use the ``axpby`` routine provided by the Accelerate_
framework? Well, there are 3 complications to keep in mind:
This is good as a fallback implementation. We can use the ``axpby`` routine
provided by the Accelerate_ framework for a faster implementation in certain
cases:
#. Accelerate does not provide implementations of ``axpby`` for half precision
floats. We can only direct to it for ``float32`` types
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all elements
have fixed strides between them. Possibly due to broadcasts and transposes,
we aren't guaranteed that the inputs fit this requirement. We can
only direct to Accelerate if both ``x`` and ``y`` are row contiguous or
column contiguous.
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` inplace.
MLX expects to write out the answer to a new array. We must copy the elements
of ``y`` into the output array and use that as an input to ``axpby``
floats. We can only use it for ``float32`` types.
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all
elements have fixed strides between them. We only direct to Accelerate
if both ``x`` and ``y`` are row contiguous or column contiguous.
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` in-place.
MLX expects to write the output to a new array. We must copy the elements
of ``y`` into the output and use that as an input to ``axpby``.
Let's write out an implementation that uses Accelerate in the right conditions.
It must simply allocate data for the output, copy elements of ``y`` into it,
and then call the :meth:`catlas_saxpby` from accelerate.
Let's write an implementation that uses Accelerate in the right conditions.
It allocates data for the output, copies ``y`` into it, and then calls the
:func:`catlas_saxpby` from accelerate.
.. code-block:: C++
@@ -356,17 +336,7 @@ and then call the :meth:`catlas_saxpby` from accelerate.
// Accelerate library provides catlas_saxpby which does
// Y = (alpha * X) + (beta * Y) in place
// To use it, we first copy the data in y over to the output array
// This specialization requires both x and y be contiguous in the same mode
// i.e: corresponding linear indices in both point to corresponding elements
// The data in the output array is allocated to match the strides in y
// such that x, y, and out are contiguous in the same mode and
// no transposition is needed
out.set_data(
allocator::malloc_or_wait(y.data_size() * out.itemsize()),
y.data_size(),
y.strides(),
y.flags());
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, CopyType::Vector);
@@ -389,18 +359,20 @@ and then call the :meth:`catlas_saxpby` from accelerate.
/* INCY = */ 1);
}
Great! But what about the inputs that do not fit the criteria for accelerate?
Luckily, we can always just direct back to :meth:`Axpby::eval`.
With this in mind, lets finally implement our :meth:`Axpby::eval_cpu`.
For inputs that do not fit the criteria for accelerate, we fall back to
:meth:`Axpby::eval`. With this in mind, let's finish our
:meth:`Axpby::eval_cpu`.
.. code-block:: C++
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
void Axpby::eval_cpu(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Accelerate specialization for contiguous single precision float arrays
if (out.dtype() == float32 &&
@@ -410,35 +382,33 @@ With this in mind, lets finally implement our :meth:`Axpby::eval_cpu`.
return;
}
// Fall back to common backend if specializations are not available
eval(inputs, out);
// Fall back to common back-end if specializations are not available
eval(inputs, outputs);
}
We have now hit a milestone! Just this much is enough to run the operation
:meth:`axpby` on a CPU stream!
Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
you do not plan on running the operation on the GPU or using transforms on
computation graphs that contain :class:`Axpby`, you can stop implementing the
primitive here and enjoy the speed-ups you get from the Accelerate library.
If you do not plan on running the operation on the GPU or using transforms on
computation graphs that contain :class:`Axpby`, you can stop implementing the
primitive here and enjoy the speed-ups you get from the Accelerate library.
Implementing the GPU Backend
Implementing the GPU Back-end
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Apple silicon devices address their GPUs using the Metal_ shading language, and
all GPU kernels in MLX are written using metal.
Apple silicon devices address their GPUs using the Metal_ shading language, and
GPU kernels in MLX are written using Metal.
.. note::
Here are some helpful resources if you are new to metal!
Here are some helpful resources if you are new to Metal:
* A walkthrough of the metal compute pipeline: `Metal Example`_
* Documentation for metal shading language: `Metal Specification`_
* Using metal from C++: `Metal-cpp`_
Let's keep the GPU algorithm simple. We will launch exactly as many threads
as there are elements in the output. Each thread will pick the element it needs
from ``x`` and ``y``, do the pointwise operation, and then update its assigned
element in the output.
Let's keep the GPU kernel simple. We will launch exactly as many threads as
there are elements in the output. Each thread will pick the element it needs
from ``x`` and ``y``, do the point-wise operation, and update its assigned
element in the output.
.. code-block:: C++
@@ -457,15 +427,14 @@ element in the output.
// Convert linear indices to offsets in array
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
// Do the operation and update the output
out[index] =
out[index] =
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
}
We then need to instantiate this template for all floating point types and give
each instantiation a unique host name so we can identify the right kernel for
each data type.
each instantiation a unique host name so we can identify it.
.. code-block:: C++
@@ -488,29 +457,21 @@ each data type.
instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);
This kernel will be compiled into a metal library ``mlx_ext.metallib`` as we
will see later in :ref:`Building with CMake`. In the following example, we
assume that the library ``mlx_ext.metallib`` will always be co-located with
the executable/ shared-library calling the :meth:`register_library` function.
The :meth:`register_library` function takes the library's name and potential
path (or in this case, a function that can produce the path of the metal
library) and tries to load that library if it hasn't already been registered
by the relevant static :class:`mlx::core::metal::Device` object. This is why,
it is important to package your C++ library with the metal library. We will
go over this process in more detail later.
The logic to determine the kernel, set the inputs, resolve the grid dimensions
and dispatch it to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
The logic to determine the kernel, set the inputs, resolve the grid dimensions,
and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
below.
.. code-block:: C++
/** Evaluate primitive on GPU */
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
void Axpby::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
// Prepare inputs
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Each primitive carries the stream it should execute on
// and each stream carries its device identifiers
@@ -518,10 +479,10 @@ below.
// We get the needed metal device using the stream
auto& d = metal::device(s.device);
// Allocate output memory
// Allocate output memory
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Resolve name of kernel (corresponds to axpby.metal)
// Resolve name of kernel
std::ostringstream kname;
kname << "axpby_" << "general_" << type_to_name(out);
@@ -552,7 +513,7 @@ below.
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
compute_encoder->setBytes(&beta_, sizeof(float), 4);
// Encode shape, strides and ndim
// Encode shape, strides and ndim
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
@@ -575,28 +536,25 @@ below.
We can now call the :meth:`axpby` operation on both the CPU and the GPU!
A few things to note about MLX and metal before moving on. MLX keeps track
of the active ``compute_encoder``. We rely on :meth:`d.get_command_encoder`
to give us the active metal compute command encoder instead of building a
new one and calling :meth:`compute_encoder->end_encoding` at the end.
MLX keeps adding kernels (compute pipelines) to the active command encoder
until some specified limit is hit or the compute encoder needs to be flushed
for synchronization. MLX also handles enqueuing and committing the associated
command buffers as needed. We suggest taking a deeper dive into
:class:`metal::Device` if you would like to study this routine further.
A few things to note about MLX and Metal before moving on. MLX keeps track of
the active ``command_buffer`` and the ``MTLCommandBuffer`` to which it is
associated. We rely on :meth:`d.get_command_encoder` to give us the active
metal compute command encoder instead of building a new one and calling
:meth:`compute_encoder->end_encoding` at the end. MLX adds kernels (compute
pipelines) to the active command buffer until some specified limit is hit or
the command buffer needs to be flushed for synchronization.
Primitive Transforms
^^^^^^^^^^^^^^^^^^^^^
Now that we have come this far, let's also learn how to add implementations to
transformations in a :class:`Primitive`. These transformations can be built on
top of our operations, including the one we just defined now. Which then gives
us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
Next, let's add implementations for transformations in a :class:`Primitive`.
These transformations can be built on top of other operations, including the
one we just defined:
.. code-block:: C++
/** The Jacobian-vector product. */
array Axpby::jvp(
std::vector<array> Axpby::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
@@ -611,12 +569,12 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
if (argnums.size() > 1) {
auto scale = argnums[0] == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, tangents[0].dtype());
return multiply(scale_arr, tangents[0], stream());
return {multiply(scale_arr, tangents[0], stream())};
}
// If, argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta
else {
return axpby(tangents[0], tangents[1], alpha_, beta_, stream());
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
}
}
@@ -625,34 +583,35 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
/** The vector-Jacobian product. */
std::vector<array> Axpby::vjp(
const std::vector<array>& primals,
const array& cotan,
const std::vector<int>& argnums) {
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<int>& /* unused */) {
// Reverse mode diff
std::vector<array> vjps;
for (auto arg : argnums) {
auto scale = arg == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, cotan.dtype());
vjps.push_back(multiply(scale_arr, cotan, stream()));
auto scale_arr = array(scale, cotangents[0].dtype());
vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
}
return vjps;
}
Finally, you need not have a transformation fully defined to start using your
own :class:`Primitive`.
Note, a transformation does not need to be fully defined to start using
the :class:`Primitive`.
.. code-block:: C++
/** Vectorize primitive along given axis */
std::pair<array, int> Axpby::vmap(
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
throw std::runtime_error("Axpby has no vmap implementation.");
throw std::runtime_error("[Axpby] vmap not implemented.");
}
Building and Binding
--------------------
Let's look at the overall directory structure first.
Let's look at the overall directory structure first.
| extensions
| ├── axpby
@@ -666,40 +625,39 @@ Let's look at the overall directory structure first.
| └── setup.py
* ``extensions/axpby/`` defines the C++ extension library
* ``extensions/mlx_sample_extensions`` sets out the structure for the
associated python package
* ``extensions/bindings.cpp`` provides python bindings for our operation
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
python bindings
* ``extensions/mlx_sample_extensions`` sets out the structure for the
associated Python package
* ``extensions/bindings.cpp`` provides Python bindings for our operation
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
Python bindings
* ``extensions/setup.py`` holds the ``setuptools`` rules to build and install
the python package
the Python package
Binding to Python
^^^^^^^^^^^^^^^^^^
We use PyBind11_ to build a Python API for the C++ library. Since bindings for
We use nanobind_ to build a Python API for the C++ library. Since bindings for
components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are
already provided, adding our :meth:`axpby` is simple!
already provided, adding our :meth:`axpby` is simple.
.. code-block:: C++
PYBIND11_MODULE(mlx_sample_extensions, m) {
m.doc() = "Sample C++ and metal extensions for MLX";
NB_MODULE(_ext, m) {
m.doc() = "Sample extension for MLX";
m.def(
"axpby",
&axpby,
"x"_a,
"y"_a,
py::pos_only(),
"alpha"_a,
"beta"_a,
py::kw_only(),
"stream"_a = py::none(),
R"pbdoc(
nb::kw_only(),
"stream"_a = nb::none(),
R"(
Scale and sum two vectors element-wise
``z = alpha * x + beta * y``
Follows numpy style broadcasting between ``x`` and ``y``
Inputs are upcasted to floats if needed
@@ -711,17 +669,17 @@ already provided, adding our :meth:`axpby` is simple!
Returns:
array: ``alpha * x + beta * y``
)pbdoc");
)");
}
Most of the complexity in the above example comes from additional bells and
Most of the complexity in the above example comes from additional bells and
whistles such as the literal names and doc-strings.
.. warning::
:mod:`mlx.core` needs to be imported before importing
:mod:`mlx_sample_extensions` as defined by the pybind11 module above to
ensure that the casters for :mod:`mlx.core` components like
:mod:`mlx.core` must be imported before importing
:mod:`mlx_sample_extensions` as defined by the nanobind module above to
ensure that the casters for :mod:`mlx.core` components like
:class:`mlx.core.array` are available.
.. _Building with CMake:
@@ -729,8 +687,8 @@ whistles such as the literal names and doc-strings.
Building with CMake
^^^^^^^^^^^^^^^^^^^^
Building the C++ extension library itself is simple, it only requires that you
``find_package(MLX CONFIG)`` and then link it to your library.
Building the C++ extension library only requires that you ``find_package(MLX
CONFIG)`` and then link it to your library.
.. code-block:: cmake
@@ -752,12 +710,12 @@ Building the C++ extension library itself is simple, it only requires that you
# Link to mlx
target_link_libraries(mlx_ext PUBLIC mlx)
We also need to build the attached metal library. For convenience, we provide a
:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
automatically imported with MLX package).
We also need to build the attached Metal library. For convenience, we provide a
:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
automatically imported with MLX package).
Here is what that looks like in practice!
Here is what that looks like in practice:
.. code-block:: cmake
@@ -779,27 +737,29 @@ Here is what that looks like in practice!
endif()
Finally, we build the Pybind11_ bindings
Finally, we build the nanobind_ bindings
.. code-block:: cmake
pybind11_add_module(
mlx_sample_extensions
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
nanobind_add_module(
_ext
NB_STATIC STABLE_ABI LTO NOMINSIZE
NB_DOMAIN mlx
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
)
target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
target_link_libraries(_ext PRIVATE mlx_ext)
if(BUILD_SHARED_LIBS)
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
endif()
Building with ``setuptools``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Once we have set out the CMake build rules as described above, we can use the
build utilities defined in :mod:`mlx.extension` for a simple build process.
build utilities defined in :mod:`mlx.extension`:
.. code-block:: python
.. code-block:: python
from mlx import extension
from setuptools import setup
@@ -809,48 +769,50 @@ build utilities defined in :mod:`mlx.extension` for a simple build process.
name="mlx_sample_extensions",
version="0.0.0",
description="Sample C++ and Metal extensions for MLX primitives.",
ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
cmdclass={"build_ext": extension.CMakeBuild},
packages = ["mlx_sample_extensions"],
package_dir = {"": "mlx_sample_extensions"},
package_data = {"mlx_sample_extensions" : ["*.so", "*.dylib", "*.metallib"]},
packages=["mlx_sample_extensions"],
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
extras_require={"dev":[]},
zip_safe=False,
python_requires=">=3.7",
python_requires=">=3.8",
)
.. note::
We treat ``extensions/mlx_sample_extensions`` as the package directory
even though it only contains a ``__init__.py`` to ensure the following:
* :mod:`mlx.core` is always imported before importing :mod:`mlx_sample_extensions`
* The C++ extension library and the metal library are co-located with the python
bindings and copied together if the package is installed
You can build inplace for development using
* :mod:`mlx.core` must be imported before importing :mod:`_ext`
* The C++ extension library and the metal library are co-located with the python
bindings and copied together if the package is installed
To build the package, first install the build dependencies with ``pip install
-r requirements.txt``. You can then build inplace for development using
``python setup.py build_ext -j8 --inplace`` (in ``extensions/``)
This will result in a directory structure as follows:
This results in the directory structure:
| extensions
| ├── mlx_sample_extensions
| │ ├── __init__.py
| │ ├── libmlx_ext.dylib # C++ extension library
| │ ├── mlx_ext.metallib # Metal library
| │ └── mlx_sample_extensions.cpython-3x-darwin.so # Python Binding
| │ └── _ext.cpython-3x-darwin.so # Python Binding
| ...
When you try to install using the command ``python -m pip install .``
(in ``extensions/``), the package will be installed with the same structure as
``extensions/mlx_sample_extensions`` and the C++ and metal library will be
copied along with the python binding since they are specified as ``package_data``.
When you try to install using the command ``python -m pip install .`` (in
``extensions/``), the package will be installed with the same structure as
``extensions/mlx_sample_extensions`` and the C++ and Metal library will be
copied along with the Python binding since they are specified as
``package_data``.
Usage
-----
After installing the extension as described above, you should be able to simply
import the python package and play with it as you would any other MLX operation!
After installing the extension as described above, you should be able to simply
import the Python package and play with it as you would any other MLX operation.
Let's looks at a simple script and it's results!
Let's look at a simple script and its results:
.. code-block:: python
@@ -874,12 +836,12 @@ Output:
c correctness: True
Results
^^^^^^^^^^^^^^^^
^^^^^^^
Let's run a quick benchmark and see how our new ``axpby`` operation compares
with the naive :meth:`simple_axpby` we defined at first on the CPU.
Let's run a quick benchmark and see how our new ``axpby`` operation compares
with the naive :meth:`simple_axpby` we first defined on the CPU.
.. code-block:: python
.. code-block:: python
import mlx.core as mx
from mlx_sample_extensions import axpby
@@ -898,7 +860,7 @@ with the naive :meth:`simple_axpby` we defined at first on the CPU.
alpha = 4.0
beta = 2.0
mx.eval((x, y))
mx.eval(x, y)
def bench(f):
# Warm up
@@ -919,30 +881,23 @@ with the naive :meth:`simple_axpby` we defined at first on the CPU.
print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")
Results:
.. code-block::
Simple axpby: 0.114 s | Custom axpby: 0.109 s
We see some modest improvements right away!
The results are ``Simple axpby: 0.114 s | Custom axpby: 0.109 s``. We see
modest improvements right away!
This operation can now be used to build other operations, in
:class:`mlx.nn.Module` calls, and also as a part of graph transformations like
:meth:`grad`!
:meth:`grad`.
Scripts
-------
.. admonition:: Download the code
The full example code is available in `mlx <code>`_.
.. code: `https://github.com/ml-explore/mlx/tree/main/examples/extensions/`_
The full example code is available in `mlx <https://github.com/ml-explore/mlx/tree/main/examples/extensions/>`_.
.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
.. _Metal: https://developer.apple.com/documentation/metal?language=objc
.. _Metal-cpp: https://developer.apple.com/metal/cpp/
.. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
.. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc
.. _PyBind11: https://pybind11.readthedocs.io/en/stable/
.. _nanobind: https://nanobind.readthedocs.io/en/latest/

View File

@@ -0,0 +1,69 @@
Metal Debugger
==============
.. currentmodule:: mlx.core
Profiling is a key step for performance optimization. You can build MLX with
the ``MLX_METAL_DEBUG`` option to improve the Metal debugging and
optimization workflow. The ``MLX_METAL_DEBUG`` debug option:
* Records source during Metal compilation, for later inspection while
debugging.
* Labels Metal objects such as command queues, improving capture readability.
To build with debugging enabled in Python prepend
``CMAKE_ARGS="-DMLX_METAL_DEBUG=ON"`` to the build call.
The :func:`metal.start_capture` function initiates a capture of all MLX GPU
work.
.. note::
To capture a GPU trace you must run the application with
``MTL_CAPTURE_ENABLED=1``.
.. code-block:: python
import mlx.core as mx
a = mx.random.uniform(shape=(512, 512))
b = mx.random.uniform(shape=(512, 512))
mx.eval(a, b)
trace_file = "mlx_trace.gputrace"
if not mx.metal.start_capture(trace_file):
print("Make sure to run with MTL_CAPTURE_ENABLED=1 and "
f"that the path {trace_file} does not already exist.")
exit(1)
for _ in range(10):
mx.eval(mx.add(a, b))
mx.metal.stop_capture()
You can open and replay the GPU trace in Xcode. The ``Dependencies`` view
has a great overview of all operations. Check out the `Metal debugger
documentation`_ for more information.
.. image:: ../_static/metal_debugger/capture.png
:class: dark-light
Xcode Workflow
--------------
You can skip saving to a path by running within Xcode. First, generate an
Xcode project using CMake.
.. code-block::
mkdir build && cd build
cmake .. -DMLX_METAL_DEBUG=ON -G Xcode
open mlx.xcodeproj
Select the ``metal_capture`` example scheme and run.
.. image:: ../_static/metal_debugger/schema.png
:class: dark-light
.. _`Metal debugger documentation`: https://developer.apple.com/documentation/xcode/metal-debugger

View File

@@ -58,6 +58,7 @@ are the CPU and GPU.
:maxdepth: 1
python/array
python/data_types
python/devices_and_streams
python/ops
python/random
@@ -81,3 +82,4 @@ are the CPU and GPU.
:maxdepth: 1
dev/extensions
dev/metal_debugger

View File

@@ -155,6 +155,8 @@ should point to the path to the built metal library.
- ON
* - MLX_BUILD_PYTHON_BINDINGS
- OFF
* - MLX_METAL_DEBUG
- OFF
.. note::

View File

@@ -19,7 +19,6 @@ Array
array.ndim
array.shape
array.size
Dtype
array.abs
array.all
array.any
@@ -32,7 +31,6 @@ Array
array.cumsum
array.diag
array.diagonal
array.dtype
array.exp
array.flatten
array.log

View File

@@ -1,7 +1,5 @@
.. _data_types:
:orphan:
Data Types
==========
@@ -56,3 +54,15 @@ The default floating point type is ``float32`` and the default integer type is
* - ``complex64``
- 8
- 64-bit complex float
Data types are arranged in a hierarchy. See the :obj:`DtypeCategory` object
documentation for more information. Use :func:`issubdtype` to determine if one
``dtype`` (or category) is a subtype of another category.
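For example (a small sketch; the category names are the ones used with
:func:`issubdtype` elsewhere in this change set, e.g. ``floating`` and
``unsignedinteger``):

.. code-block:: python

   import mlx.core as mx

   mx.issubdtype(mx.float32, mx.floating)       # True
   mx.issubdtype(mx.uint8, mx.unsignedinteger)  # True
   mx.issubdtype(mx.float16, mx.integer)        # False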
.. autosummary::
:toctree: _autosummary
Dtype
DtypeCategory
issubdtype

View File

@@ -3,7 +3,7 @@ Metal
.. currentmodule:: mlx.core.metal
.. autosummary::
:toctree: _autosummary
is_available
@@ -12,3 +12,5 @@ Metal
get_cache_memory
set_memory_limit
set_cache_limit
start_capture
stop_capture

View File

@@ -30,6 +30,7 @@ Module
Module.named_modules
Module.parameters
Module.save_weights
Module.set_dtype
Module.train
Module.trainable_parameters
Module.unfreeze

View File

@@ -5,13 +5,13 @@ Operations
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
abs
add
all
allclose
any
arange
arccos
@@ -51,6 +51,7 @@ Operations
erf
erfinv
exp
expm1
expand_dims
eye
flatten
@@ -62,10 +63,10 @@ Operations
identity
inner
isclose
isnan
isposinf
isneginf
isinf
isnan
isneginf
isposinf
less
less_equal
linspace
@@ -83,6 +84,7 @@ Operations
max
maximum
mean
meshgrid
min
minimum
moveaxis
@@ -117,6 +119,7 @@ Operations
square
squeeze
stack
std
stop_gradient
subtract
sum

View File

@@ -38,6 +38,7 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
gumbel
key
normal
multivariate_normal
randint
seed
split

View File

@@ -49,7 +49,7 @@ it will be added. You can load the array with:
.. code-block:: shell
>>> mx.load("array.npy", a)
>>> mx.load("array.npy")
array([1], dtype=float32)
Here's an example of saving several arrays to a single file:
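The example itself is elided by this hunk; a minimal sketch using
``mx.savez`` (the file name and arrays are illustrative) would be:

.. code-block:: python

   import mlx.core as mx

   a = mx.array([1.0])
   b = mx.array([2.0, 3.0])
   # Keyword arguments keep their names as keys in the archive
   mx.savez("arrays.npz", a=a, b=b)
   d = mx.load("arrays.npz")  # dict with keys "a" and "b"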

View File

@@ -8,3 +8,4 @@ endfunction(build_example)
build_example(tutorial.cpp)
build_example(linear_regression.cpp)
build_example(logistic_regression.cpp)
build_example(metal_capture.cpp)

View File

@@ -0,0 +1,31 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include <iostream>
#include "mlx/mlx.h"
using namespace mlx::core;
int main() {
// To use Metal debugging and profiling:
// 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
// 2. Run with MTL_CAPTURE_ENABLED=1.
assert(metal::start_capture("mlx_trace.gputrace"));
// Start at index two because the default GPU and CPU streams have indices
// zero and one, respectively. This naming matches the label assigned to each
// stream's command queue.
auto s2 = new_stream(Device::gpu);
auto s3 = new_stream(Device::gpu);
auto a = arange(1.f, 10.f, 1.f, float32, s2);
auto b = arange(1.f, 10.f, 1.f, float32, s3);
auto x = add(a, a, s2);
auto y = add(b, b, s3);
// The multiply will happen on the default stream.
std::cout << multiply(x, y) << std::endl;
metal::stop_capture();
}

View File

@@ -1,6 +1,6 @@
cmake_minimum_required(VERSION 3.27)
project(mlx_sample_extensions LANGUAGES CXX)
project(_ext LANGUAGES CXX)
# ----------------------------- Setup -----------------------------
set(CMAKE_CXX_STANDARD 17)
@@ -11,8 +11,12 @@ option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
# ----------------------------- Dependencies -----------------------------
find_package(MLX CONFIG REQUIRED)
find_package(Python COMPONENTS Interpreter Development)
find_package(pybind11 CONFIG REQUIRED)
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED)
# ----------------------------- Extensions -----------------------------
@@ -38,7 +42,6 @@ target_link_libraries(mlx_ext PUBLIC mlx)
# Build metallib
if(MLX_BUILD_METAL)
mlx_build_metallib(
TARGET mlx_ext_metallib
TITLE mlx_ext
@@ -54,13 +57,15 @@ if(MLX_BUILD_METAL)
endif()
# ----------------------------- Pybind -----------------------------
pybind11_add_module(
mlx_sample_extensions
# ----------------------------- Python Bindings -----------------------------
nanobind_add_module(
_ext
NB_STATIC STABLE_ABI LTO NOMINSIZE
NB_DOMAIN mlx
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
)
target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
target_link_libraries(_ext PRIVATE mlx_ext)
if(BUILD_SHARED_LIBS)
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
endif()

View File

@@ -0,0 +1,18 @@
## Build the extensions
```
pip install -e .
```
For faster builds during development, you can also pre-install the requirements:
```
pip install -r requirements.txt
```
And then run:
```
python setup.py build_ext -j8 --inplace
```

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <iostream>
@@ -43,7 +43,7 @@ array axpby(
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
// Upcast to float32 for non-floating point inputs x and y
auto out_dtype = is_floating_point(promoted_dtype)
auto out_dtype = issubdtype(promoted_dtype, float32)
? promoted_dtype
: promote_types(promoted_dtype, float32);
@@ -61,7 +61,7 @@ array axpby(
/* const std::vector<int>& shape = */ out_shape,
/* Dtype dtype = */ out_dtype,
/* std::unique_ptr<Primitive> primitive = */
std::make_unique<Axpby>(to_stream(s), alpha, beta),
std::make_shared<Axpby>(to_stream(s), alpha, beta),
/* const std::vector<array>& inputs = */ broadcasted_inputs);
}
@@ -106,12 +106,12 @@ void axpby_impl(
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(
const std::vector<array>& inputs,
std::vector<array>& out_arr) {
auto out = out_arr[0];
std::vector<array>& outputs) {
// Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Dispatch to the correct dtype
if (out.dtype() == float32) {
@@ -150,11 +150,7 @@ void axpby_impl_accelerate(
// The data in the output array is allocated to match the strides in y
// such that x, y, and out are contiguous in the same mode and
// no transposition is needed
out.set_data(
allocator::malloc_or_wait(y.data_size() * out.itemsize()),
y.data_size(),
y.strides(),
y.flags());
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, CopyType::Vector);
@@ -180,11 +176,11 @@ void axpby_impl_accelerate(
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outarr) {
auto out = outarr[0];
std::vector<array>& outputs) {
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Accelerate specialization for contiguous single precision float arrays
if (out.dtype() == float32 &&
@@ -195,7 +191,7 @@ void Axpby::eval_cpu(
}
// Fall back to common backend if specializations are not available
eval(inputs, outarr);
eval(inputs, outputs);
}
#else // Accelerate not available
@@ -203,8 +199,8 @@ void Axpby::eval_cpu(
/** Evaluate primitive on CPU falling back to common backend */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& out) {
eval(inputs, out);
const std::vector<array>& outputs) {
eval(inputs, outputs);
}
#endif
@@ -218,12 +214,12 @@ void Axpby::eval_cpu(
/** Evaluate primitive on GPU */
void Axpby::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outarr) {
std::vector<array>& outputs) {
// Prepare inputs
auto out = outarr[0];
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Each primitive carries the stream it should execute on
// and each stream carries its device identifiers
@@ -372,4 +368,4 @@ bool Axpby::is_equivalent(const Primitive& other) const {
return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
}
} // namespace mlx::core

View File

@@ -42,9 +42,9 @@ class Axpby : public Primitive {
* To avoid unnecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& out)
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& out)
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
/** The Jacobian-vector product. */
@@ -83,7 +83,7 @@ class Axpby : public Primitive {
float beta_;
/** Fall back implementation for evaluation on CPU */
void eval(const std::vector<array>& inputs, std::vector<array>& out);
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
};
} // namespace mlx::core

View File

@@ -1,31 +1,31 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <nanobind/nanobind.h>
#include <nanobind/stl/variant.h>
#include "axpby/axpby.h"
namespace py = pybind11;
using namespace py::literals;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
PYBIND11_MODULE(mlx_sample_extensions, m) {
m.doc() = "Sample C++ and metal extensions for MLX";
NB_MODULE(_ext, m) {
m.doc() = "Sample extension for MLX";
m.def(
"axpby",
&axpby,
"x"_a,
"y"_a,
py::pos_only(),
"alpha"_a,
"beta"_a,
py::kw_only(),
"stream"_a = py::none(),
R"pbdoc(
nb::kw_only(),
"stream"_a = nb::none(),
R"(
Scale and sum two vectors element-wise
``z = alpha * x + beta * y``
Follows numpy style broadcasting between ``x`` and ``y``
Inputs are upcasted to floats if needed
@@ -37,5 +37,5 @@ PYBIND11_MODULE(mlx_sample_extensions, m) {
Returns:
array: ``alpha * x + beta * y``
)pbdoc");
}
)");
}

View File

@@ -1,3 +1,8 @@
[build-system]
requires = ["setuptools>=42", "pybind11>=2.10", "cmake>=3.24", "mlx @ git+https://github.com/mlx-explore/mlx@main"]
build-backend = "setuptools.build_meta"
requires = [
"setuptools>=42",
"cmake>=3.24",
"mlx>=0.9.0",
"nanobind@git+https://github.com/wjakob/nanobind.git#egg=4148debcf91f5ccab0c3b8d67b5c3cabd61f407f",
]
build-backend = "setuptools.build_meta"

View File

@@ -0,0 +1,4 @@
setuptools>=42
cmake>=3.24
mlx>=0.9.0
nanobind@git+https://github.com/wjakob/nanobind.git#egg=4148debcf91f5ccab0c3b8d67b5c3cabd61f407f

View File

@@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
from setuptools import setup
@@ -9,11 +9,11 @@ if __name__ == "__main__":
name="mlx_sample_extensions",
version="0.0.0",
description="Sample C++ and Metal extensions for MLX primitives.",
ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
cmdclass={"build_ext": extension.CMakeBuild},
packages=["mlx_sample_extensions"],
package_dir={"": "."},
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
extras_require={"dev": []},
zip_safe=False,
python_requires=">=3.8",
)

View File

@@ -93,7 +93,9 @@ void array::detach() {
}
void array::eval() {
mlx::core::eval({*this});
if (!is_evaled()) {
mlx::core::eval({*this});
}
}
bool array::is_tracer() const {
@@ -190,6 +192,36 @@ array::ArrayDesc::ArrayDesc(
init();
}
array::ArrayDesc::~ArrayDesc() {
// When an array description is destroyed it will delete a bunch of arrays
// that may also destroy their corresponding descriptions and so on and so
// forth.
//
// This recursively calls the destructor and can result in a stack overflow.
// We instead put the descriptions in a vector and destroy them one at a
// time, resulting in a max stack depth of 2.
std::vector<std::shared_ptr<ArrayDesc>> for_deletion;
for (array& a : inputs) {
if (a.array_desc_.use_count() == 1) {
for_deletion.push_back(std::move(a.array_desc_));
}
}
while (!for_deletion.empty()) {
// top is going to be deleted at the end of the block *after* the arrays
// with inputs have been moved into the vector
auto top = std::move(for_deletion.back());
for_deletion.pop_back();
for (array& a : top->inputs) {
if (a.array_desc_.use_count() == 1) {
for_deletion.push_back(std::move(a.array_desc_));
}
}
}
}
array::ArrayIterator::ArrayIterator(const array& arr, int idx)
: arr(arr), idx(idx) {
if (arr.ndim() == 0) {

View File

@@ -256,6 +256,17 @@ class array {
array_desc_->position = position;
}
/** The i-th output of the array's primitive. */
const array& output(int i) const {
if (i == array_desc_->position) {
return *this;
} else if (i < array_desc_->position) {
return siblings()[i];
} else {
return siblings()[i + 1];
}
};
/** The outputs of the array's primitive (i.e. this array and
* its siblings) in the order the primitive expects. */
std::vector<array> outputs() const {
@@ -393,6 +404,8 @@ class array {
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs);
~ArrayDesc();
private:
// Initialize size, strides, and other metadata
void init();
@@ -510,4 +523,15 @@ void array::init(It src) {
}
}
/* Utilities for determining whether a template parameter is array. */
template <typename T>
inline constexpr bool is_array_v =
std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>, array>;
template <typename... T>
inline constexpr bool is_arrays_v = (is_array_v<T> && ...);
template <typename... T>
using enable_for_arrays_t = typename std::enable_if_t<is_arrays_v<T...>>;
} // namespace mlx::core

View File

@@ -301,7 +301,7 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
set_unary_output_data(in, out);
auto size = in.data_size();
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (is_floating_point(out.dtype())) {
} else if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, [](auto x) { return std::exp(x); });
} else {
throw std::invalid_argument(
@@ -310,6 +310,19 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
vvexpm1f(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else {
eval(inputs, out);
}
}
void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
@@ -355,7 +368,7 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
auto size = in.data_size();
vvlog1pf(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (is_floating_point(out.dtype())) {
} else if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, [](auto x) { return std::log1p(x); });
} else {
throw std::invalid_argument(

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <limits>
@@ -201,7 +201,7 @@ struct NeonFp16SimdOps {
}
};
template <typename T, typename VT, typename Ops, int N>
template <typename T, typename AccT, typename VT, typename Ops, int N>
void softmax(const array& in, array& out) {
Ops ops;
@@ -218,13 +218,21 @@ void softmax(const array& in, array& out) {
VT vmaximum = ops.init(-std::numeric_limits<float>::infinity());
size_t s = M;
while (s >= N) {
vmaximum = ops.max(ops.load(current_in_ptr), vmaximum);
VT vals;
if constexpr (std::is_same<T, AccT>::value) {
vals = ops.load(current_in_ptr);
} else {
for (int i = 0; i < N; ++i) {
vals[i] = static_cast<AccT>(current_in_ptr[i]);
}
}
vmaximum = ops.max(vals, vmaximum);
current_in_ptr += N;
s -= N;
}
T maximum = ops.reduce_max(vmaximum);
AccT maximum = ops.reduce_max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, *current_in_ptr);
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
current_in_ptr++;
}
@@ -234,18 +242,29 @@ void softmax(const array& in, array& out) {
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
VT vexp = ops.exp(ops.sub(*(VT*)current_in_ptr, maximum));
ops.store(current_out_ptr, vexp);
*(VT*)current_out_ptr = vexp;
VT vexp;
if constexpr (std::is_same<T, AccT>::value) {
vexp = ops.load(current_in_ptr);
} else {
for (int i = 0; i < N; ++i) {
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
}
}
vexp = ops.exp(ops.sub(vexp, maximum));
if constexpr (std::is_same<T, AccT>::value) {
ops.store(current_out_ptr, vexp);
}
vnormalizer = ops.add(vnormalizer, vexp);
current_in_ptr += N;
current_out_ptr += N;
s -= N;
}
T normalizer = ops.reduce_add(vnormalizer);
AccT normalizer = ops.reduce_add(vnormalizer);
while (s-- > 0) {
T _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = _exp;
AccT _exp = std::exp(*current_in_ptr - maximum);
if (std::is_same<T, AccT>::value) {
*current_out_ptr = _exp;
}
normalizer += _exp;
current_in_ptr++;
current_out_ptr++;
@@ -254,14 +273,33 @@ void softmax(const array& in, array& out) {
// Normalize
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
if constexpr (std::is_same<T, AccT>::value) {
ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
} else {
VT vexp;
for (int i = 0; i < N; ++i) {
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
}
vexp = ops.mul(ops.exp(ops.sub(vexp, maximum)), normalizer);
for (int i = 0; i < N; ++i) {
current_out_ptr[i] = vexp[i];
}
current_in_ptr += N;
}
current_out_ptr += N;
s -= N;
}
while (s-- > 0) {
*current_out_ptr *= normalizer;
if constexpr (std::is_same<T, AccT>::value) {
*current_out_ptr *= normalizer;
} else {
AccT _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = static_cast<T>(_exp * normalizer);
current_in_ptr++;
}
current_out_ptr++;
}
}
@@ -308,15 +346,29 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
"Softmax is defined only for floating point types");
break;
case float32:
softmax<float, simd_float16, AccelerateSimdOps<float, simd_float16>, 16>(
in, out);
softmax<
float,
float,
simd_float16,
AccelerateSimdOps<float, simd_float16>,
16>(in, out);
break;
case float16:
softmax<
float16_t,
float16x8_t,
NeonFp16SimdOps<float16_t, float16x8_t>,
8>(in, out);
if (precise_) {
softmax<
float16_t,
float,
simd_float16,
AccelerateSimdOps<float, simd_float16>,
16>(in, out);
} else {
softmax<
float16_t,
float16_t,
float16x8_t,
NeonFp16SimdOps<float16_t, float16x8_t>,
8>(in, out);
}
break;
case bfloat16:
eval(inputs, out);

View File

@@ -179,18 +179,16 @@ void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (is_floating_point(out.dtype())) {
if (out.dtype() == float32) {
binary_op<float>(a, b, out, detail::LogAddExp());
} else if (out.dtype() == float16) {
binary_op<float16_t>(a, b, out, detail::LogAddExp());
} else if (out.dtype() == bfloat16) {
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
} else {
std::ostringstream err;
err << "[logaddexp] Does not support " << out.dtype();
throw std::invalid_argument(err.str());
}
if (out.dtype() == float32) {
binary_op<float>(a, b, out, detail::LogAddExp());
} else if (out.dtype() == float16) {
binary_op<float16_t>(a, b, out, detail::LogAddExp());
} else if (out.dtype() == bfloat16) {
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
} else if (issubdtype(out.dtype(), inexact)) {
std::ostringstream err;
err << "[logaddexp] Does not support " << out.dtype();
throw std::invalid_argument(err.str());
} else {
throw std::invalid_argument(
"[logaddexp] Cannot compute logaddexp for arrays with"

View File

@@ -126,4 +126,102 @@ std::string build_lib_name(
return os.str();
}
bool compiled_check_contiguity(
const std::vector<array>& inputs,
const std::vector<int>& shape) {
bool contiguous = true;
bool all_contig = true;
bool all_row_contig = true;
bool all_col_contig = true;
int non_scalar_inputs = 0;
for (const auto& x : inputs) {
if (is_scalar(x)) {
continue;
}
non_scalar_inputs++;
bool shape_eq = x.shape() == shape;
all_contig &= (x.flags().contiguous && shape_eq);
all_row_contig &= (x.flags().row_contiguous && shape_eq);
all_col_contig &= (x.flags().col_contiguous && shape_eq);
}
if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) {
contiguous = false;
} else if (non_scalar_inputs == 1 && !all_contig) {
contiguous = false;
} else if (non_scalar_inputs == 0 && !shape.empty()) {
contiguous = false;
}
return contiguous;
}
void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::vector<array>& inputs_,
const std::unordered_set<uintptr_t>& constant_ids_,
bool contiguous,
bool move_buffers /* = false */) {
if (contiguous) {
int o = 0;
std::vector<size_t> strides;
size_t data_size;
array::Flags flags;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Correct size
// - Not a scalar
// - Donatable
// - Not a constant
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
if (move_buffers) {
outputs[o++].move_shared_buffer(in);
} else {
outputs[o++].copy_shared_buffer(in);
}
}
// Get representative input flags to properly set non-donated outputs
if (strides.empty() && in.size() == outputs[0].size()) {
strides = in.strides();
flags = in.flags();
data_size = in.data_size();
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
data_size,
strides,
flags);
}
} else {
int o = 0;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Row contiguous
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
if (move_buffers) {
outputs[o].move_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size());
} else {
outputs[o].copy_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size());
}
o++;
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
}
}
}
} // namespace mlx::core

View File

@@ -53,4 +53,18 @@ inline bool is_scalar(const array& x) {
return x.ndim() == 0;
}
// Check if we can use a contiguous operation given inputs and the output shape
bool compiled_check_contiguity(
const std::vector<array>& inputs,
const std::vector<int>& shape);
// Allocate space for the outputs possibly with input donation
void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::vector<array>& inputs_,
const std::unordered_set<uintptr_t>& constant_ids_,
bool contiguous,
bool move_buffers = false);
} // namespace mlx::core

View File

@@ -52,8 +52,25 @@ void* compile(
return nullptr;
}
std::string kernel_file_name;
// Deal with long kernel names. Maximum length for files on macOS is 255
// characters. Clip file name with a little extra room and append a 16
// character hash.
constexpr int max_file_name_length = 245;
if (kernel_name.size() > max_file_name_length) {
std::ostringstream file_name;
file_name
<< std::string_view(kernel_name).substr(0, max_file_name_length - 16);
auto file_id = std::hash<std::string>{}(kernel_name);
file_name << "_" << std::hex << std::setw(16) << file_id << std::dec;
kernel_file_name = file_name.str();
} else {
kernel_file_name = kernel_name;
}
std::ostringstream shared_lib_name;
shared_lib_name << "lib" << kernel_name << ".so";
shared_lib_name << "lib" << kernel_file_name << ".so";
auto shared_lib_path = get_temp_file(shared_lib_name.str());
bool lib_exists = false;
{
@@ -64,7 +81,7 @@ void* compile(
if (!lib_exists) {
// Open source file and write source code to it
std::ostringstream source_file_name;
source_file_name << kernel_name << ".cpp";
source_file_name << kernel_file_name << ".cpp";
auto source_file_path = get_temp_file(source_file_name.str());
std::ofstream source_file(source_file_path);
@@ -248,28 +265,7 @@ void Compiled::eval_cpu(
// Figure out which kernel we are using
auto& shape = outputs[0].shape();
bool contiguous = true;
{
bool all_contig = true;
bool all_row_contig = true;
bool all_col_contig = true;
int non_scalar_inputs = 0;
for (auto& x : inputs) {
if (is_scalar(x)) {
continue;
}
non_scalar_inputs++;
bool shape_eq = x.shape() == shape;
all_contig &= (x.flags().contiguous && shape_eq);
all_row_contig &= (x.flags().row_contiguous && shape_eq);
all_col_contig &= (x.flags().col_contiguous && shape_eq);
}
if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) {
contiguous = false;
} else if (non_scalar_inputs == 1 && !all_contig) {
contiguous = false;
}
}
bool contiguous = compiled_check_contiguity(inputs, shape);
// Handle all broadcasting and collect function input arguments
std::vector<void*> args;
@@ -342,58 +338,8 @@ void Compiled::eval_cpu(
fn_ptr = compile(kernel_name, kernel.str());
}
// Allocate space for the outputs possibly with input donation
if (contiguous) {
int o = 0;
std::vector<size_t> strides;
size_t data_size;
array::Flags flags;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Contiguous
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().contiguous && !is_scalar(in) && in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
outputs[o++].copy_shared_buffer(in);
}
// Get representative input flags to properly set non-donated outputs
if (strides.empty() && in.size() == outputs[0].size()) {
strides = in.strides();
flags = in.flags();
data_size = in.data_size();
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
data_size,
strides,
flags);
}
} else {
int o = 0;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Row contiguous
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
outputs[o].copy_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size());
o++;
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
}
}
compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous, false);
for (auto& x : outputs) {
args.push_back(x.data<void>());

View File

@@ -272,7 +272,7 @@ inline void copy_general_general(const array& src, array& dst) {
}
template <typename SrcT, typename DstT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype, Args... args) {
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
switch (ctype) {
case CopyType::Scalar:
copy_single<SrcT, DstT>(src, dst);
@@ -281,54 +281,54 @@ void copy(const array& src, array& dst, CopyType ctype, Args... args) {
copy_vector<SrcT, DstT>(src, dst);
return;
case CopyType::General:
copy_general<SrcT, DstT>(src, dst, args...);
copy_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
return;
case CopyType::GeneralGeneral:
copy_general_general<SrcT, DstT>(src, dst, args...);
copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
}
}
template <typename SrcT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype, Args... args) {
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
switch (dst.dtype()) {
case bool_:
copy<SrcT, bool>(src, dst, ctype, args...);
copy<SrcT, bool>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint8:
copy<SrcT, uint8_t>(src, dst, ctype, args...);
copy<SrcT, uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint16:
copy<SrcT, uint16_t>(src, dst, ctype, args...);
copy<SrcT, uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint32:
copy<SrcT, uint32_t>(src, dst, ctype, args...);
copy<SrcT, uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint64:
copy<SrcT, uint64_t>(src, dst, ctype, args...);
copy<SrcT, uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int8:
copy<SrcT, int8_t>(src, dst, ctype, args...);
copy<SrcT, int8_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int16:
copy<SrcT, int16_t>(src, dst, ctype, args...);
copy<SrcT, int16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int32:
copy<SrcT, int32_t>(src, dst, ctype, args...);
copy<SrcT, int32_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int64:
copy<SrcT, int64_t>(src, dst, ctype, args...);
copy<SrcT, int64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float16:
copy<SrcT, float16_t>(src, dst, ctype, args...);
copy<SrcT, float16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float32:
copy<SrcT, float>(src, dst, ctype, args...);
copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
break;
case bfloat16:
copy<SrcT, bfloat16_t>(src, dst, ctype, args...);
copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case complex64:
copy<SrcT, complex64_t>(src, dst, ctype, args...);
copy<SrcT, complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
}
}
@@ -338,46 +338,46 @@ inline void copy_inplace_dispatch(
const array& src,
array& dst,
CopyType ctype,
Args... args) {
Args&&... args) {
switch (src.dtype()) {
case bool_:
copy<bool>(src, dst, ctype, args...);
copy<bool>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint8:
copy<uint8_t>(src, dst, ctype, args...);
copy<uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint16:
copy<uint16_t>(src, dst, ctype, args...);
copy<uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint32:
copy<uint32_t>(src, dst, ctype, args...);
copy<uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint64:
copy<uint64_t>(src, dst, ctype, args...);
copy<uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int8:
copy<int8_t>(src, dst, ctype, args...);
copy<int8_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int16:
copy<int16_t>(src, dst, ctype, args...);
copy<int16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int32:
copy<int32_t>(src, dst, ctype, args...);
copy<int32_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int64:
copy<int64_t>(src, dst, ctype, args...);
copy<int64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float16:
copy<float16_t>(src, dst, ctype, args...);
copy<float16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float32:
copy<float>(src, dst, ctype, args...);
copy<float>(src, dst, ctype, std::forward<Args>(args)...);
break;
case bfloat16:
copy<bfloat16_t>(src, dst, ctype, args...);
copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case complex64:
copy<complex64_t>(src, dst, ctype, args...);
copy<complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
}
}

View File

@@ -57,6 +57,7 @@ DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(Exp)
DEFAULT(Expm1)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Full)

View File

@@ -241,6 +241,13 @@ struct Exp {
}
};
struct Expm1 {
template <typename T>
T operator()(T x) {
return expm1(x);
};
};
struct Floor {
template <typename T>
T operator()(T x) {

View File

@@ -22,7 +22,7 @@ namespace mlx::core {
void Abs::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (is_unsigned(in.dtype())) {
if (issubdtype(in.dtype(), unsignedinteger)) {
// No-op for unsigned types
out.copy_shared_buffer(in);
} else {
@@ -37,7 +37,7 @@ void Arange::eval(const std::vector<array>& inputs, array& out) {
void ArcCos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcCos());
} else {
throw std::invalid_argument(
@@ -49,7 +49,7 @@ void ArcCos::eval(const std::vector<array>& inputs, array& out) {
void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcCosh());
} else {
throw std::invalid_argument(
@@ -61,7 +61,7 @@ void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
void ArcSin::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcSin());
} else {
throw std::invalid_argument(
@@ -73,7 +73,7 @@ void ArcSin::eval(const std::vector<array>& inputs, array& out) {
void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcSinh());
} else {
throw std::invalid_argument(
@@ -85,7 +85,7 @@ void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
void ArcTan::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcTan());
} else {
throw std::invalid_argument(
@@ -97,7 +97,7 @@ void ArcTan::eval(const std::vector<array>& inputs, array& out) {
void ArcTanh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcTanh());
} else {
throw std::invalid_argument(
@@ -171,7 +171,7 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
void Ceil::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (not is_integral(in.dtype())) {
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Ceil());
} else {
// No-op integer types
@@ -211,7 +211,7 @@ void Copy::eval(const std::vector<array>& inputs, array& out) {
void Cos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Cos());
} else {
throw std::invalid_argument(
@@ -223,7 +223,7 @@ void Cos::eval(const std::vector<array>& inputs, array& out) {
void Cosh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Cosh());
} else {
throw std::invalid_argument(
@@ -350,7 +350,7 @@ void ErfInv::eval(const std::vector<array>& inputs, array& out) {
void Exp::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Exp());
} else {
throw std::invalid_argument(
@@ -359,10 +359,22 @@ void Exp::eval(const std::vector<array>& inputs, array& out) {
}
}
void Expm1::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Expm1());
} else {
throw std::invalid_argument(
"[expm1] Cannot exponentiate elements in array"
" with non floating point type.");
}
}
void Floor::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (not is_integral(in.dtype())) {
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Floor());
} else {
// No-op integer types
@@ -388,7 +400,7 @@ void Full::eval(const std::vector<array>& inputs, array& out) {
void Log::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
switch (base_) {
case Base::e:
unary_fp(in, out, detail::Log());
@@ -410,7 +422,7 @@ void Log::eval(const std::vector<array>& inputs, array& out) {
void Log1p::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Log1p());
} else {
throw std::invalid_argument(
@@ -597,7 +609,7 @@ void Reshape::eval(const std::vector<array>& inputs, array& out) {
void Round::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (not is_integral(in.dtype())) {
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Round());
} else {
// No-op integer types
@@ -608,7 +620,7 @@ void Round::eval(const std::vector<array>& inputs, array& out) {
void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Sigmoid());
} else {
throw std::invalid_argument(
@@ -630,7 +642,7 @@ void Sign::eval(const std::vector<array>& inputs, array& out) {
void Sin::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Sin());
} else {
throw std::invalid_argument(
@@ -642,7 +654,7 @@ void Sin::eval(const std::vector<array>& inputs, array& out) {
void Sinh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Sinh());
} else {
throw std::invalid_argument(
@@ -850,7 +862,7 @@ void StopGradient::eval(const std::vector<array>& inputs, array& out) {
void Tan::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Tan());
} else {
throw std::invalid_argument(
@@ -862,7 +874,7 @@ void Tan::eval(const std::vector<array>& inputs, array& out) {
void Tanh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Tanh());
} else {
throw std::invalid_argument(

View File

@@ -6,8 +6,6 @@
namespace mlx::core {
namespace {
enum ReductionOpType {
// Self-explanatory. Read everything and produce 1 output.
ContiguousAllReduce,
@@ -38,6 +36,21 @@ enum ReductionOpType {
GeneralReduce
};
struct ReductionPlan {
ReductionOpType type;
std::vector<int> shape;
std::vector<size_t> strides;
ReductionPlan(
ReductionOpType type_,
std::vector<int> shape_,
std::vector<size_t> strides_)
: type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
ReductionPlan(ReductionOpType type_) : type(type_) {}
};
namespace {
// Helper for the ndimensional strided loop
// Should this be in utils?
inline void nd_loop(
@@ -110,19 +123,6 @@ struct DefaultContiguousReduce {
}
};
struct ReductionPlan {
ReductionOpType type;
std::vector<int> shape;
std::vector<size_t> strides;
ReductionPlan(
ReductionOpType type_,
std::vector<int> shape_,
std::vector<size_t> strides_)
: type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
ReductionPlan(ReductionOpType type_) : type(type_) {}
};
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&

View File

@@ -222,7 +222,7 @@ void scan_dispatch(
}
case Scan::Min: {
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *x : *y; };
auto init = (is_floating_point(input.dtype()))
auto init = (issubdtype(input.dtype(), floating))
? static_cast<U>(std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::max();
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
@@ -232,7 +232,7 @@ void scan_dispatch(
}
case Scan::Max: {
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
auto init = (is_floating_point(input.dtype()))
auto init = (issubdtype(input.dtype(), floating))
? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::min();
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <cmath>
@@ -10,7 +10,7 @@ namespace mlx::core {
namespace {
template <typename T>
template <typename T, typename AccT>
void softmax(const array& in, array& out) {
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
@@ -22,26 +22,36 @@ void softmax(const array& in, array& out) {
for (int i = 0; i < M; i++, in_ptr += N, out_ptr += N) {
// Find the maximum
current_in_ptr = in_ptr;
T maximum = *current_in_ptr;
AccT maximum = *current_in_ptr;
for (int j = 0; j < N; j++, current_in_ptr++) {
maximum = (maximum < *current_in_ptr) ? *current_in_ptr : maximum;
maximum = (maximum < *current_in_ptr) ? static_cast<AccT>(*current_in_ptr)
: maximum;
}
// Compute the normalizer and the exponentials
T normalizer = 0;
AccT normalizer = 0;
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
for (int j = 0; j < N; j++, current_out_ptr++, current_in_ptr++) {
T expv = std::exp(*current_in_ptr - maximum);
AccT expv = std::exp(*current_in_ptr - maximum);
normalizer += expv;
*current_out_ptr = expv;
if constexpr (std::is_same<T, AccT>::value) {
*current_out_ptr = expv;
}
}
normalizer = 1 / normalizer;
// Normalize
current_in_ptr = in_ptr;
current_out_ptr = out_ptr;
for (int j = 0; j < N; j++, current_out_ptr++) {
*current_out_ptr *= normalizer;
if constexpr (std::is_same<T, AccT>::value) {
*current_out_ptr *= normalizer;
} else {
auto v = std::exp(*current_in_ptr - maximum);
*current_out_ptr = static_cast<T>(v * normalizer);
current_in_ptr++;
}
}
}
}
@@ -91,13 +101,21 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
"Softmax is defined only for floating point types");
break;
case float32:
softmax<float>(in, out);
softmax<float, float>(in, out);
break;
case float16:
softmax<float16_t>(in, out);
if (precise_) {
softmax<float16_t, float>(in, out);
} else {
softmax<float16_t, float16_t>(in, out);
}
break;
case bfloat16:
softmax<bfloat16_t>(in, out);
if (precise_) {
softmax<bfloat16_t, float>(in, out);
} else {
softmax<bfloat16_t, bfloat16_t>(in, out);
}
break;
case complex64:
throw std::invalid_argument(

View File

@@ -89,9 +89,8 @@ collapse_contiguous_dims(const std::vector<array>& xs) {
return collapse_contiguous_dims(xs[0].shape(), strides);
}
template <typename... Arrays>
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(Arrays... xs) {
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
inline auto collapse_contiguous_dims(Arrays&&... xs) {
return collapse_contiguous_dims(
std::vector<array>{std::forward<Arrays>(xs)...});
}

View File

@@ -1,6 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/allocator.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
#include <mach/vm_page_size.h>
#include <unistd.h>

View File

@@ -229,14 +229,7 @@ void Compiled::eval_gpu(
// Figure out which kernel we are using
auto& output_shape = outputs[0].shape();
bool contiguous = true;
for (auto& x : inputs) {
if ((!x.flags().row_contiguous || x.shape() != output_shape) &&
!is_scalar(x)) {
contiguous = false;
break;
}
}
bool contiguous = compiled_check_contiguity(inputs, output_shape);
// Collapse contiguous dims to route to a faster kernel if possible. Also
// handle all broadcasting.
@@ -296,7 +289,7 @@ void Compiled::eval_gpu(
}
}
auto kernel = d.get_kernel(kernel_name, lib);
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
// Put the inputs in
@@ -307,7 +300,7 @@ void Compiled::eval_gpu(
continue;
}
auto& x = inputs[i];
set_array_buffer(compute_encoder, x, cnt++);
compute_encoder.set_input_array(x, cnt++);
if (!contiguous && !is_scalar(x)) {
compute_encoder->setBytes(
strides[stride_idx].data(),
@@ -317,32 +310,12 @@ void Compiled::eval_gpu(
}
}
// Allocate space for the outputs possibly with input donation
{
int o = 0;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Row contiguous
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
outputs[o].move_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size());
o++;
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
}
}
compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous, true);
// Put the outputs in
for (auto& x : outputs) {
set_array_buffer(compute_encoder, x, cnt++);
compute_encoder.set_output_array(x, cnt++);
}
// Put the output shape and strides in

View File

@@ -41,12 +41,12 @@ void explicit_gemm_conv_ND_gpu(
// Prepare unfolding kernel
std::ostringstream kname;
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, in_unfolded, 1);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(in_unfolded, 1);
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
@@ -140,7 +140,7 @@ void slow_conv_2D_gpu(
<< "_tm" << tm << "_tn" << tn;
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
@@ -153,9 +153,9 @@ void slow_conv_2D_gpu(
MTL::Size group_dims = MTL::Size(bm, bn, 1);
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, wt, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
@@ -241,7 +241,7 @@ void implicit_gemm_conv_2D_gpu(
<< "_filter_" << (small_filter ? 's' : 'l');
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
@@ -254,9 +254,9 @@ void implicit_gemm_conv_2D_gpu(
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);
// Encode arrays
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, wt, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
// Encode params
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
@@ -394,7 +394,7 @@ void implicit_gemm_conv_2D_general_gpu(
<< "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
@@ -408,9 +408,9 @@ void implicit_gemm_conv_2D_general_gpu(
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
// Encode arrays
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, wt, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
// Encode params
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
@@ -511,12 +511,12 @@ void winograd_conv_2D_gpu(
std::ostringstream kname;
kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
<< bc;
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, wt, 0);
set_array_buffer(compute_encoder, filt_wg, 1);
compute_encoder.set_input_array(wt, 0);
compute_encoder.set_output_array(filt_wg, 1);
compute_encoder->setBytes(&C_c, sizeof(int), 2);
compute_encoder->setBytes(&O_c, sizeof(int), 3);
@@ -539,12 +539,12 @@ void winograd_conv_2D_gpu(
std::ostringstream kname;
kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc"
<< bc;
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in_padded, 0);
set_array_buffer(compute_encoder, inp_wg, 1);
compute_encoder.set_input_array(in_padded, 0);
compute_encoder.set_output_array(inp_wg, 1);
compute_encoder->setBytes(
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
@@ -587,12 +587,12 @@ void winograd_conv_2D_gpu(
std::ostringstream kname;
kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo"
<< bc;
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, out_wg, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder.set_input_array(out_wg, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(
&conv_params_updated, sizeof(MLXConvParams<2>), 2);

View File

@@ -83,15 +83,15 @@ void copy_gpu_inplace(
kname << "_" << shape.size();
}
auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
bool donate_in = in.data_shared_ptr() == nullptr;
inp_offset *= size_of(in.dtype());
out_offset *= size_of(out.dtype());
set_array_buffer(compute_encoder, donate_in ? out : in, inp_offset, 0);
set_array_buffer(compute_encoder, out, out_offset, 1);
compute_encoder.set_input_array(donate_in ? out : in, 0, inp_offset);
compute_encoder.set_output_array(out, 1, out_offset);
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
int ndim = shape.size();

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-24 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <dlfcn.h>
#include <cstdlib>
@@ -11,7 +11,9 @@
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h"
namespace fs = std::filesystem;
@@ -145,6 +147,7 @@ void Device::new_queue(int index) {
// We lock this as a critical section for safety
const std::lock_guard<std::mutex> lock(mtx_);
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
debug_set_stream_queue_label(q, index);
if (!q) {
throw std::runtime_error(
"[metal::Device] Failed to make new command queue.");
@@ -203,14 +206,15 @@ void Device::end_encoding(int index) {
}
}
MTL::ComputeCommandEncoder* Device::get_command_encoder(int index) {
CommandEncoder& Device::get_command_encoder(int index) {
auto eit = encoder_map_.find(index);
if (eit == encoder_map_.end()) {
auto cb = get_command_buffer(index);
auto compute_encoder = cb->computeCommandEncoder();
auto compute_encoder =
cb->computeCommandEncoder(MTL::DispatchTypeConcurrent);
// Increment ref count so the buffer is not garbage collected
compute_encoder->retain();
eit = encoder_map_.insert({index, compute_encoder}).first;
eit = encoder_map_.emplace(index, CommandEncoder{compute_encoder}).first;
}
return eit->second;
}

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-24 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once
@@ -7,10 +7,12 @@
#include <mutex>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <dlfcn.h>
#include <filesystem>
#include "mlx/array.h"
#include "mlx/device.h"
namespace fs = std::filesystem;
@@ -34,6 +36,69 @@ inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
using MTLFCList =
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
struct CommandEncoder {
CommandEncoder(MTL::ComputeCommandEncoder* enc)
: enc(enc), concurrent(false){};
CommandEncoder& operator=(const CommandEncoder&) = delete;
struct ConcurrentContext {
ConcurrentContext(CommandEncoder& enc) : enc(enc) {
enc.concurrent = true;
}
~ConcurrentContext() {
enc.concurrent = false;
enc.outputs.insert(
enc.concurrent_outputs.begin(), enc.concurrent_outputs.end());
enc.concurrent_outputs.clear();
}
private:
CommandEncoder& enc;
};
MTL::ComputeCommandEncoder* operator->() {
return enc;
}
void set_input_array(const array& a, int idx, int offset = 0) {
auto r_buf =
static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
if (auto it = outputs.find(r_buf); it != outputs.end()) {
// Insert a barrier
enc->memoryBarrier(&r_buf, 1);
// Remove the output
outputs.erase(it);
}
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto base_offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
base_offset += offset;
enc->setBuffer(a_buf, base_offset, idx);
}
void set_output_array(array& a, int idx, int offset = 0) {
// Add barriers before adding the output to the output set
set_input_array(a, idx, offset);
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
if (concurrent) {
concurrent_outputs.insert(buf);
} else {
outputs.insert(buf);
}
}
ConcurrentContext start_concurrent() {
return ConcurrentContext(*this);
}
private:
MTL::ComputeCommandEncoder* enc;
bool concurrent;
std::unordered_set<MTL::Resource*> outputs;
std::unordered_set<MTL::Resource*> concurrent_outputs;
};
class Device {
public:
Device();
@@ -51,7 +116,7 @@ class Device {
int get_command_buffer_ops(int index);
void increment_command_buffer_ops(int index);
void commit_command_buffer(int index);
MTL::ComputeCommandEncoder* get_command_encoder(int index);
CommandEncoder& get_command_encoder(int index);
void end_encoding(int index);
void register_library(
@@ -132,7 +197,7 @@ class Device {
MTL::Device* device_;
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
std::unordered_map<int32_t, MTL::ComputeCommandEncoder*> encoder_map_;
std::unordered_map<int32_t, CommandEncoder> encoder_map_;
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
std::unordered_map<std::string, MTL::Library*> library_map_;
std::mutex mtx_;
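
Note: the CommandEncoder added above tracks which buffers were written by previously encoded kernels so that, under MTL::DispatchTypeConcurrent, a memory barrier is inserted only when a later kernel actually reads (or rewrites) one of them; ConcurrentContext defers that bookkeeping until a whole batch of independent kernels has been encoded. A minimal host-side sketch of the tracking idea, using a stand-in Resource type and illustrative names rather than the real Metal/MLX API:

#include <cstdio>
#include <unordered_set>

struct Resource { int id; };  // stand-in for MTL::Resource

struct TrackingEncoder {
  std::unordered_set<const Resource*> outputs;  // written by earlier kernels

  void set_input(const Resource* r) {
    // A previously encoded kernel wrote this buffer: fence before reading.
    if (auto it = outputs.find(r); it != outputs.end()) {
      std::printf("memoryBarrier on resource %d\n", r->id);
      outputs.erase(it);
    }
  }
  void set_output(const Resource* r) {
    set_input(r);       // order this write against any prior write
    outputs.insert(r);  // remember it for later readers
  }
};

int main() {
  Resource a{0}, b{1};
  TrackingEncoder enc;
  enc.set_output(&a);  // kernel 1 writes a
  enc.set_input(&b);   // kernel 2 reads b: no barrier needed
  enc.set_input(&a);   // kernel 2 reads a: barrier emitted here
}

The base_offset arithmetic in set_input_array serves a different purpose: it recovers the byte offset of the array's data within its (possibly shared) underlying MTL::Buffer before binding.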

View File

@@ -49,7 +49,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
kname << "_" << idx_ndim;
}
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
@@ -81,8 +81,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
}
// Set all the buffers
set_array_buffer(compute_encoder, src, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder.set_input_array(src, 0);
compute_encoder.set_output_array(out, 1);
// Set source info
compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2);
@@ -103,7 +103,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {
set_array_buffer(compute_encoder, inputs[i], 20 + i);
compute_encoder.set_input_array(inputs[i], 20 + i);
}
// Launch grid
@@ -183,7 +183,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
}
kname << "_" << nidx;
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto& upd = inputs.back();
@@ -192,8 +192,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setComputePipelineState(kernel);
// Set all the buffers
set_array_buffer(compute_encoder, upd, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder.set_input_array(upd, 1);
compute_encoder.set_output_array(out, 2);
// Set update info
uint upd_ndim = upd.ndim();
@@ -210,7 +210,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {
set_array_buffer(compute_encoder, inputs[i], 20 + i);
compute_encoder.set_input_array(inputs[i], 20 + i);
}
// Launch grid
@@ -280,7 +280,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {
set_array_buffer(compute_encoder, inputs[i], 20 + i);
compute_encoder.set_input_array(inputs[i], 20 + i);
}
// Launch grid

View File

@@ -7,6 +7,7 @@ set(
${CMAKE_CURRENT_SOURCE_DIR}/complex.h
${CMAKE_CURRENT_SOURCE_DIR}/defines.h
${CMAKE_CURRENT_SOURCE_DIR}/erf.h
${CMAKE_CURRENT_SOURCE_DIR}/expm1f.h
${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
${CMAKE_CURRENT_SOURCE_DIR}/unary.h
${CMAKE_CURRENT_SOURCE_DIR}/utils.h
@@ -37,11 +38,17 @@ set(
)
function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS}
-gline-tables-only
-frecord-sources)
endif()
add_custom_command(
COMMAND xcrun -sdk macosx metal -Wall -Wextra
-fno-fast-math
-c ${SRCFILE}
-I${PROJECT_SOURCE_DIR}
COMMAND xcrun -sdk macosx metal
${METAL_FLAGS}
-c ${SRCFILE}
-I${PROJECT_SOURCE_DIR}
-o ${TARGET}.air
DEPENDS ${SRCFILE} ${DEPS}
OUTPUT ${TARGET}.air

View File

@@ -0,0 +1,89 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <metal_math>
// Original license copied below:
// Copyright (c) 2015-2023 Norbert Juffa
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
//
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
/* Compute exponential base e minus 1. Maximum ulp error = 0.997458
i = rint(a/log(2)), f = a-i*log(2). Then expm1(a) = 2**i * (expm1(f)+1) - 1.
Compute r = expm1(f). Then expm1(a)= 2 * (0.5 * 2**i * r + 0.5 * 2**i - 0.5).
With t = 0.5*2**i, expm1(a) = 2*(r * t + t-0.5). However, for best accuracy,
when i == 1, expm1(a)= 2*(r + 0.5), and when i == 0, expm1(a) = r.
NOTE: Scale factor b is only applied if i < 0 or i > 1 (b should be a power of 2)
*/
float expm1f_scaled_unchecked(float a, float b) {
float f, j, r, s, t, u, v, x, y;
int i;
// exp(a) = 2**i * exp(f); i = rintf (a / log(2))
j = fma(1.442695f, a, 12582912.f); // 0x1.715476p0, 0x1.8p23
j = j - 12582912.0f; // 0x1.8p23
i = (int)j;
f = fma(j, -6.93145752e-1f, a);
// approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2]
s = f * f;
if (a == 0.0f)
s = a; // ensure -0 is passed through
// err = 0.997458 ulp1 = 11081805
r = 1.97350979e-4f; // 0x1.9de000p-13
r = fma(r, f, 1.39309070e-3f); // 0x1.6d30bcp-10
r = fma(r, f, 8.33343994e-3f); // 0x1.1111f6p-7
r = fma(r, f, 4.16668020e-2f); // 0x1.55559ep-5
r = fma(r, f, 1.66666716e-1f); // 0x1.55555cp-3
r = fma(r, f, 4.99999970e-1f); // 0x1.fffffep-2
u = (j == 1) ? (f + 0.5f) : f;
v = fma(r, s, u);
s = 0.5f * b;
t = ldexp(s, i);
y = t - s;
x = (t - y) - s; // double-float canonicalization of difference
r = fma(v, t, x) + y;
r = r + r;
if (j == 0)
r = v;
if (j == 1)
r = v + v;
return r;
}
/* Compute exponential base e minus 1. max ulp err = 0.99746 */
float expm1f(float a) {
float r;
r = expm1f_scaled_unchecked(a, 1.0f);
/* handle severe overflow and underflow */
if (abs(a - 1.0f) > 88.0f) {
r = fma(r, r, -1.0f);
}
return r;
}
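
For intuition, the same range reduction can be written in plain C++ and checked against the standard library. This is a sketch only; the Metal header above is the actual implementation, with a fitted polynomial in place of std::expm1 and fused scaling for accuracy:

#include <cmath>
#include <cstdio>

// expm1(a) = 2^i * (expm1(f) + 1) - 1 = 2 * (t*expm1(f) + (t - 0.5)),
// with i = rint(a / log 2), f = a - i*log 2, and t = 0.5 * 2^i.
double expm1_reduced(double a) {
  double i = std::rint(a / std::log(2.0));
  double f = a - i * std::log(2.0);
  double t = 0.5 * std::ldexp(1.0, static_cast<int>(i));
  return 2.0 * (std::expm1(f) * t + (t - 0.5));
}

int main() {
  for (double a : {-3.0, -0.01, 0.25, 5.0}) {
    std::printf("a=%g: %.17g vs std::expm1 %.17g\n",
                a, expm1_reduced(a), std::expm1(a));
  }
}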

View File

@@ -205,39 +205,341 @@ template <typename T, int N_READS = RMS_N_READS>
}
}
template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void vjp_layer_norm_single_row(
const device T* x,
const device T* w,
const device T* g,
device T* gx,
device T* gw,
constant float& eps,
constant uint& axis_size,
constant uint& w_stride,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// Advance the input pointers
x += gid * axis_size + lid * N_READS;
g += gid * axis_size + lid * N_READS;
w += w_stride * lid * N_READS;
// Allocate registers for the computation and accumulators
float thread_x[N_READS];
float thread_w[N_READS];
float thread_g[N_READS];
float sumx = 0;
float sumx2 = 0;
float sumwg = 0;
float sumwgx = 0;
constexpr int SIMD_SIZE = 32;
threadgroup float local_sumx[SIMD_SIZE];
threadgroup float local_sumx2[SIMD_SIZE];
threadgroup float local_sumwg[SIMD_SIZE];
threadgroup float local_sumwgx[SIMD_SIZE];
threadgroup float local_mean[1];
threadgroup float local_normalizer[1];
threadgroup float local_meanwg[1];
threadgroup float local_meanwgx[1];
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
thread_x[i] = x[i];
thread_w[i] = w[i * w_stride];
thread_g[i] = g[i];
float wg = thread_w[i] * thread_g[i];
sumx += thread_x[i];
sumx2 += thread_x[i] * thread_x[i];
sumwg += wg;
sumwgx += wg * thread_x[i];
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
thread_x[i] = x[i];
thread_w[i] = w[i * w_stride];
thread_g[i] = g[i];
float wg = thread_w[i] * thread_g[i];
sumx += thread_x[i];
sumx2 += thread_x[i] * thread_x[i];
sumwg += wg;
sumwgx += wg * thread_x[i];
}
}
}
sumx = simd_sum(sumx);
sumx2 = simd_sum(sumx2);
sumwg = simd_sum(sumwg);
sumwgx = simd_sum(sumwgx);
// Initialize shared memory
if (simd_group_id == 0) {
local_sumx[simd_lane_id] = 0;
local_sumx2[simd_lane_id] = 0;
local_sumwg[simd_lane_id] = 0;
local_sumwgx[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write simd accumulations into shared memory
if (simd_lane_id == 0) {
local_sumx[simd_group_id] = sumx;
local_sumx2[simd_group_id] = sumx2;
local_sumwg[simd_group_id] = sumwg;
local_sumwgx[simd_group_id] = sumwgx;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Accumulate over simd groups
if (simd_group_id == 0) {
sumx = simd_sum(local_sumx[simd_lane_id]);
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
sumwg = simd_sum(local_sumwg[simd_lane_id]);
sumwgx = simd_sum(local_sumwgx[simd_lane_id]);
if (simd_lane_id == 0) {
float mean = sumx / axis_size;
float variance = sumx2 / axis_size - mean * mean;
local_mean[0] = mean;
local_normalizer[0] = metal::precise::rsqrt(variance + eps);
local_meanwg[0] = sumwg / axis_size;
local_meanwgx[0] = sumwgx / axis_size;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float mean = local_mean[0];
float normalizer = local_normalizer[0];
float meanwg = local_meanwg[0];
float meanwgxc = local_meanwgx[0] - meanwg * mean;
float normalizer2 = normalizer * normalizer;
// Write the outputs
gx += gid * axis_size + lid * N_READS;
gw += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
thread_x[i] = (thread_x[i] - mean) * normalizer;
gx[i] = static_cast<T>(normalizer * (thread_w[i] * thread_g[i] - meanwg) -
thread_x[i] * meanwgxc * normalizer2);
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
thread_x[i] = (thread_x[i] - mean) * normalizer;
gx[i] = static_cast<T>(normalizer * (thread_w[i] * thread_g[i] - meanwg) -
thread_x[i] * meanwgxc * normalizer2);
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
}
}
}
}
template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void vjp_layer_norm_looped(
const device T* x,
const device T* w,
const device T* g,
device T* gx,
device T* gw,
constant float& eps,
constant uint& axis_size,
constant uint& w_stride,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// Advance the input pointers
x += gid * axis_size + lid * N_READS;
g += gid * axis_size + lid * N_READS;
w += w_stride * lid * N_READS;
// Allocate registers for the accumulators
float sumx = 0;
float sumx2 = 0;
float sumwg = 0;
float sumwgx = 0;
constexpr int SIMD_SIZE = 32;
threadgroup float local_sumx[SIMD_SIZE];
threadgroup float local_sumx2[SIMD_SIZE];
threadgroup float local_sumwg[SIMD_SIZE];
threadgroup float local_sumwgx[SIMD_SIZE];
threadgroup float local_mean[1];
threadgroup float local_normalizer[1];
threadgroup float local_meanwg[1];
threadgroup float local_meanwgx[1];
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
float xi = x[i + r];
float wi = w[(i + r) * w_stride];
float gi = g[i + r];
float wg = wi * gi;
sumx += xi;
sumx2 += xi * xi;
sumwg += wg;
sumwgx += wg * xi;
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((r + lid * N_READS + i) < axis_size) {
float xi = x[i + r];
float wi = w[(i + r) * w_stride];
float gi = g[i + r];
float wg = wi * gi;
sumx += xi;
sumx2 += xi * xi;
sumwg += wg;
sumwgx += wg * xi;
}
}
}
}
sumx = simd_sum(sumx);
sumx2 = simd_sum(sumx2);
sumwg = simd_sum(sumwg);
sumwgx = simd_sum(sumwgx);
// Initialize shared memory
if (simd_group_id == 0) {
local_sumx[simd_lane_id] = 0;
local_sumx2[simd_lane_id] = 0;
local_sumwg[simd_lane_id] = 0;
local_sumwgx[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write simd accumulations into shared memory
if (simd_lane_id == 0) {
local_sumx[simd_group_id] = sumx;
local_sumx2[simd_group_id] = sumx2;
local_sumwg[simd_group_id] = sumwg;
local_sumwgx[simd_group_id] = sumwgx;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Accumulate over simd groups
if (simd_group_id == 0) {
sumx = simd_sum(local_sumx[simd_lane_id]);
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
sumwg = simd_sum(local_sumwg[simd_lane_id]);
sumwgx = simd_sum(local_sumwgx[simd_lane_id]);
if (simd_lane_id == 0) {
float mean = sumx / axis_size;
float variance = sumx2 / axis_size - mean * mean;
local_mean[0] = mean;
local_normalizer[0] = metal::precise::rsqrt(variance + eps);
local_meanwg[0] = sumwg / axis_size;
local_meanwgx[0] = sumwgx / axis_size;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float mean = local_mean[0];
float normalizer = local_normalizer[0];
float meanwg = local_meanwg[0];
float meanwgxc = local_meanwgx[0] - meanwg * mean;
float normalizer2 = normalizer * normalizer;
// Write the outputs
gx += gid * axis_size + lid * N_READS;
gw += gid * axis_size + lid * N_READS;
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
float xi = (x[i + r] - mean) * normalizer;
float wi = w[(i + r) * w_stride];
float gi = g[i + r];
gx[i + r] = static_cast<T>(normalizer * (wi * gi - meanwg) -
xi * meanwgxc * normalizer2);
gw[i + r] = static_cast<T>(gi * xi);
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((r + lid * N_READS + i) < axis_size) {
float xi = (x[i + r] - mean) * normalizer;
float wi = w[(i + r) * w_stride];
float gi = g[i + r];
gx[i + r] = static_cast<T>(normalizer * (wi * gi - meanwg) -
xi * meanwgxc * normalizer2);
gw[i + r] = static_cast<T>(gi * xi);
}
}
}
}
}
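
For reference, the accumulators in these kernels are the standard layer-norm VJP. With \(\mu\) and \(\sigma^2\) the per-row mean and variance, \(r = (\sigma^2 + \epsilon)^{-1/2}\) (normalizer in the code) and \(\hat{x}_i = (x_i - \mu)\,r\), the writes reduce to

\[
\frac{\partial L}{\partial x_i}
  = r\left(w_i g_i - \overline{wg} - \hat{x}_i\,\overline{wg\hat{x}}\right),
\qquad
gw_i = g_i\,\hat{x}_i,
\]

where bars denote means over the axis: \(\overline{wg}\) is meanwg, and \(\overline{wg\hat{x}} = \texttt{meanwgxc}\cdot r\), which is why the code folds it in through normalizer2. The per-row products \(g_i \hat{x}_i\) written to gw are, presumably, summed over rows outside the kernel to form the weight gradient.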
// clang-format off
#define instantiate_layer_norm_single_row(name, itype) \
template [[host_name("layer_norm" #name)]] [[kernel]] void \
layer_norm_single_row<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* b, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
constant uint& b_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
#define instantiate_layer_norm_single_row(name, itype) \
template [[host_name("layer_norm" #name)]] [[kernel]] void \
layer_norm_single_row<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* b, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
constant uint& b_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("vjp_layer_norm" #name)]] [[kernel]] void \
vjp_layer_norm_single_row<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* g, \
device itype* gx, \
device itype* gw, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_layer_norm_looped(name, itype) \
template [[host_name("layer_norm_looped" #name)]] [[kernel]] void \
layer_norm_looped<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* b, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
constant uint& b_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
#define instantiate_layer_norm_looped(name, itype) \
template [[host_name("layer_norm_looped" #name)]] [[kernel]] void \
layer_norm_looped<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* b, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
constant uint& b_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("vjp_layer_norm_looped" #name)]] [[kernel]] void \
vjp_layer_norm_looped<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* g, \
device itype* gx, \
device itype* gb, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_layer_norm(name, itype) \

View File

@@ -15,14 +15,6 @@ using namespace metal;
MLX_MTL_CONST int SIMD_SIZE = 32;
template <typename T> struct AccT {
typedef T acc_t;
};
template <> struct AccT<bfloat16_t> {
typedef float acc_t;
};
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector(const device T *x, thread U *x_thread) {
@@ -60,6 +52,51 @@ inline U load_vector(const device T *x, thread U *x_thread) {
return sum;
}
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector_safe(const device T *x, thread U *x_thread, int N) {
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
U sum = 0;
if (bits == 2) {
for (int i = 0; i < N; i += 4) {
sum += x[i] + x[i+1] + x[i+2] + x[i+3];
x_thread[i] = x[i];
x_thread[i+1] = x[i+1] / 4.0f;
x_thread[i+2] = x[i+2] / 16.0f;
x_thread[i+3] = x[i+3] / 64.0f;
}
for (int i=N; i<values_per_thread; i++) {
x_thread[i] = 0;
}
}
else if (bits == 4) {
for (int i = 0; i < N; i += 4) {
sum += x[i] + x[i+1] + x[i+2] + x[i+3];
x_thread[i] = x[i];
x_thread[i+1] = x[i+1] / 16.0f;
x_thread[i+2] = x[i+2] / 256.0f;
x_thread[i+3] = x[i+3] / 4096.0f;
}
for (int i=N; i<values_per_thread; i++) {
x_thread[i] = 0;
}
}
else if (bits == 8) {
for (int i = 0; i < N; i++) {
sum += x[i];
x_thread[i] = x[i];
}
for (int i=N; i<values_per_thread; i++) {
x_thread[i] = 0;
}
}
return sum;
}
template <typename U, int values_per_thread, int bits>
inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias, U sum) {
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
@@ -96,6 +133,74 @@ inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias
return scale * accum + sum * bias;
}
template <typename U, int values_per_thread, int bits>
inline U qdot_safe(const device uint8_t* w, const thread U *x_thread, U scale, U bias, U sum, int N) {
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
U accum = 0;
if (bits == 2) {
for (int i = 0; i < (N / 4); i++) {
accum += (
x_thread[4*i] * (w[i] & 0x03)
+ x_thread[4*i+1] * (w[i] & 0x0c)
+ x_thread[4*i+2] * (w[i] & 0x30)
+ x_thread[4*i+3] * (w[i] & 0xc0));
}
}
else if (bits == 4) {
const device uint16_t* ws = (const device uint16_t*)w;
for (int i = 0; i < (N / 4); i++) {
accum += (
x_thread[4*i] * (ws[i] & 0x000f)
+ x_thread[4*i+1] * (ws[i] & 0x00f0)
+ x_thread[4*i+2] * (ws[i] & 0x0f00)
+ x_thread[4*i+3] * (ws[i] & 0xf000));
}
}
else if (bits == 8) {
for (int i = 0; i < N; i++) {
accum += x_thread[i] * w[i];
}
}
return scale * accum + sum * bias;
}
template <typename U, int values_per_thread, int bits>
inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
if (bits == 2) {
U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
for (int i = 0; i < (values_per_thread / 4); i++) {
result[4*i] += x * (s[0] * (w[i] & 0x03) + bias);
result[4*i+1] += x * (s[1] * (w[i] & 0x0c) + bias);
result[4*i+2] += x * (s[2] * (w[i] & 0x30) + bias);
result[4*i+3] += x * (s[3] * (w[i] & 0xc0) + bias);
}
}
else if (bits == 4) {
const thread uint16_t* ws = (const thread uint16_t*)w;
U s[4] = {scale, scale / 16.0f, scale / 256.0f, scale / 4096.0f};
for (int i = 0; i < (values_per_thread / 4); i++) {
result[4*i] += x * (s[0] * (ws[i] & 0x000f) + bias);
result[4*i+1] += x * (s[1] * (ws[i] & 0x00f0) + bias);
result[4*i+2] += x * (s[2] * (ws[i] & 0x0f00) + bias);
result[4*i+3] += x * (s[3] * (ws[i] & 0xf000) + bias);
}
}
else if (bits == 8) {
for (int i = 0; i < values_per_thread; i++) {
result[i] += x * (scale * w[i] + bias);
}
}
}
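
A note on the packing trick shared by load_vector_safe, qdot_safe, and qouter: the 2- and 4-bit fields are never shifted down; instead the other operand (or the scale) is pre-divided by the matching power of two, so a single mask-and-multiply does the work of mask, shift, and multiply. A small self-contained C++ check of the identity for the 4-bit case (illustrative, not MLX code):

#include <cassert>
#include <cstdint>

int main() {
  uint16_t packed = 0x4321;  // four 4-bit values: 1, 2, 3, 4 (low to high)
  float x = 3.0f;
  for (int k = 0; k < 4; k++) {
    int shifted = (packed >> (4 * k)) & 0xf;          // conventional unpack
    float masked = float(packed & (0xf << (4 * k)));  // masked, left in place
    float scale = 1.0f / float(1 << (4 * k));         // 1 / 16^k, exact
    // masked == shifted * 16^k, so pre-scaling by 1/16^k compensates:
    assert(x * masked * scale == x * float(shifted));
  }
  return 0;
}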
template <typename T, int group_size, int bits, int packs_per_thread>
[[kernel]] void qmv_fast(
const device uint32_t* w [[buffer(0)]],
@@ -204,7 +309,8 @@ template <typename T, const int group_size, const int bits>
x += tid.z * in_vec_size + simd_lid * values_per_thread;
y += tid.z * out_vec_size + out_row;
for (int k = 0; k < in_vec_size; k += block_size) {
int k = 0;
for (; k < in_vec_size-block_size; k += block_size) {
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; out_row + row < out_vec_size; row++) {
@@ -222,6 +328,18 @@ template <typename T, const int group_size, const int bits>
biases += block_size / group_size;
x += block_size;
}
const int remaining = clamp(static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread);
U sum = load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
for (int row = 0; out_row + row < out_vec_size; row++) {
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g;
U s = sl[0];
U b = bl[0];
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
}
for (int row = 0; out_row + row < out_vec_size; row++) {
result[row] = simd_sum(result[row]);
@@ -239,7 +357,8 @@ template <typename T, const int group_size, const int bits>
x += tid.z * in_vec_size + simd_lid * values_per_thread;
y += tid.z * out_vec_size + used_out_row;
for (int k = 0; k < in_vec_size; k += block_size) {
int k = 0;
for (; k < in_vec_size-block_size; k += block_size) {
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; row < results_per_simdgroup; row++) {
@@ -257,6 +376,18 @@ template <typename T, const int group_size, const int bits>
biases += block_size / group_size;
x += block_size;
}
const int remaining = clamp(static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread);
U sum = load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
for (int row = 0; row < results_per_simdgroup; row++) {
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g;
U s = sl[0];
U b = bl[0];
result[row] += qdot_safe<U, values_per_thread, bits>(wl, x_thread, s, b, sum, remaining);
}
for (int row = 0; row < results_per_simdgroup; row++) {
result[row] = simd_sum(result[row]);
@@ -268,7 +399,7 @@ template <typename T, const int group_size, const int bits>
}
template <typename T, const int BM, const int BN, const int group_size, const int bits>
template <typename T, const int group_size, const int bits>
[[kernel]] void qvm(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
@@ -278,39 +409,28 @@ template <typename T, const int BM, const int BN, const int group_size, const in
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
static_assert(BM == SIMD_SIZE, "qvm expects BM to be equal to SIMD_SIZE");
static_assert(BN == BM, "qvm expects a block size of 32x32");
constexpr int num_simdgroups = 8;
constexpr int pack_factor = 32 / bits;
constexpr int blocksize = SIMD_SIZE;
(void)lid;
constexpr int bitmask = (1 << bits) - 1;
constexpr int el_per_int = 32 / bits;
constexpr int colgroup = BN * el_per_int;
constexpr int groups_per_block = colgroup / group_size;
typedef typename AccT<T>::acc_t U;
threadgroup U scales_block[BM * groups_per_block];
threadgroup U biases_block[BM * groups_per_block];
threadgroup U x_block[BM];
typedef float U;
thread uint32_t w_local;
thread U result[el_per_int] = {0};
thread U result[pack_factor] = {0};
thread U scale = 1;
thread U bias = 0;
thread U x_local = 0;
// Adjust positions
const int out_vec_size_w = out_vec_size / el_per_int;
const int out_vec_size_w = out_vec_size / pack_factor;
const int out_vec_size_g = out_vec_size / group_size;
int out_col_start = tid.y * (BN * el_per_int);
int out_col = out_col_start + simd_gid * el_per_int;
w += out_col / el_per_int;
scales += out_col_start / group_size;
biases += out_col_start / group_size;
int out_col = tid.y * (num_simdgroups * pack_factor) + simd_gid * pack_factor;
w += out_col / pack_factor;
scales += out_col / group_size;
biases += out_col / group_size;
x += tid.z * in_vec_size;
y += tid.z * out_vec_size + out_col;
@@ -318,53 +438,39 @@ template <typename T, const int BM, const int BN, const int group_size, const in
return;
}
// Loop over in_vec in blocks of colgroup
for (int i=0; i<in_vec_size; i+=BM) {
int offset_lid = simd_lid + i;
int offset_gid = simd_gid + i;
bool thread_in_bounds = offset_lid < in_vec_size;
bool group_in_bounds = offset_gid < in_vec_size;
// Loop over in_vec in blocks of blocksize
int i = 0;
for (; i + blocksize <= in_vec_size; i += blocksize) {
x_local = x[i + simd_lid];
scale = scales[(i + simd_lid) * out_vec_size_g];
bias = biases[(i + simd_lid) * out_vec_size_g];
w_local = w[(i + simd_lid) * out_vec_size_w];
// Load the vec to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_gid == 0) {
x_block[simd_lid] = (thread_in_bounds) ? x[offset_lid] : 0;
}
// Load the scales and biases to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_lid < groups_per_block && group_in_bounds) {
scales_block[simd_gid * groups_per_block + simd_lid] = scales[offset_gid * out_vec_size_g + simd_lid];
biases_block[simd_gid * groups_per_block + simd_lid] = biases[offset_gid * out_vec_size_g + simd_lid];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load in_vec, scale, bias to registers
x_local = x_block[simd_lid];
scale = scales_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];
bias = biases_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];
// Load the matrix elements
w_local = (thread_in_bounds) ? w[offset_lid * out_vec_size_w] : 0;
// Do all the work.
#pragma clang loop unroll(full)
for (int k=0; k<el_per_int; k++) {
result[k] += (scale * static_cast<U>(w_local & bitmask) + bias) * x_local;
w_local >>= bits;
}
qouter<U, pack_factor, bits>((thread uint8_t *)&w_local, x_local, scale, bias, result);
}
if (static_cast<int>(i + simd_lid) < in_vec_size) {
x_local = x[i + simd_lid];
scale = scales[(i + simd_lid) * out_vec_size_g];
bias = biases[(i + simd_lid) * out_vec_size_g];
w_local = w[(i + simd_lid) * out_vec_size_w];
} else {
x_local = 0;
scale = 0;
bias = 0;
w_local = 0;
}
qouter<U, pack_factor, bits>((thread uint8_t *)&w_local, x_local, scale, bias, result);
// Accumulate in the simdgroup
#pragma clang loop unroll(full)
for (int k=0; k<el_per_int; k++) {
for (int k=0; k<pack_factor; k++) {
result[k] = simd_sum(result[k]);
}
// Store the result
if (simd_lid == 0) {
#pragma clang loop unroll(full)
for (int k=0; k<el_per_int; k++) {
for (int k=0; k<pack_factor; k++) {
y[k] = static_cast<T>(result[k]);
}
}
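
The restructuring in qmv and qvm above follows one pattern: a main loop that runs only full blocks with no bounds checks, plus exactly one trailing iteration that clamps `remaining` and uses the safe loads. A generic sketch of the shape of that split, in plain C++:

#include <algorithm>
#include <cstdio>
#include <vector>

float sum_blocked(const float* x, int n, int block) {
  float acc = 0.0f;
  int i = 0;
  // Stop one block early so the tail always runs exactly once, mirroring the
  // kernels' `k < in_vec_size - block_size` condition.
  for (; i < n - block; i += block)
    for (int j = 0; j < block; j++)  // no bounds checks needed here
      acc += x[i + j];
  int remaining = std::clamp(n - i, 0, block);  // 1..block elements remain
  for (int j = 0; j < remaining; j++)
    acc += x[i + j];
  return acc;
}

int main() {
  std::vector<float> x(37, 1.0f);
  std::printf("%g\n", sum_blocked(x.data(), (int)x.size(), 8));  // prints 37
}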
@@ -414,6 +520,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
const int K_g = K / group_size;
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
x += y_row * K;
w += y_col * K_w;
scales += y_col * K_g;
@@ -466,7 +573,10 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
const device uint32_t * w_local = w + offset_row * K_w + offset_col;
threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
if (y_row + offset_row < N) {
// y_col corresponds to the row of the weight matrix; y_col + offset_row
// must be less than the total number of rows, otherwise skip the load.
if (y_col + offset_row < N) {
uint32_t wi = *w_local;
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
@@ -619,7 +729,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
const device uint32_t * w_local = w + offset_row * N_w + offset_col;
threadgroup T * Ws_local = Ws + offset_row * BN + offset_col * el_per_int;
if (y_row + offset_row < K) {
if (k + offset_row < K) {
uint32_t wi = *w_local;
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
@@ -738,7 +848,7 @@ instantiate_qmv_types( 32, 8)
#define instantiate_qvm(name, itype, group_size, bits) \
template [[host_name("qvm_" #name "_gs_" #group_size "_b_" #bits)]] \
[[kernel]] void qvm<itype, 32, 32, group_size, bits>( \
[[kernel]] void qvm<itype, group_size, bits>( \
const device itype* x [[buffer(0)]], \
const device uint32_t* w [[buffer(1)]], \
const device itype* scales [[buffer(2)]], \
@@ -747,7 +857,6 @@ instantiate_qmv_types( 32, 8)
const constant int& in_vec_size [[buffer(5)]], \
const constant int& out_vec_size [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint lid [[thread_index_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);

View File

@@ -150,6 +150,216 @@ template <typename T, int N_READS = RMS_N_READS>
}
}
template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void vjp_rms_single_row(
const device T* x,
const device T* w,
const device T* g,
device T* gx,
device T* gw,
constant float& eps,
constant uint& axis_size,
constant uint& w_stride,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// Advance the input pointers
x += gid * axis_size + lid * N_READS;
g += gid * axis_size + lid * N_READS;
w += w_stride * lid * N_READS;
// Allocate registers for the computation and accumulators
float thread_x[N_READS];
float thread_w[N_READS];
float thread_g[N_READS];
float sumx2 = 0;
float sumgwx = 0;
// Allocate shared memory to implement the reduction
constexpr int SIMD_SIZE = 32;
threadgroup float local_sumx2[SIMD_SIZE];
threadgroup float local_sumgwx[SIMD_SIZE];
threadgroup float local_normalizer[1];
threadgroup float local_meangwx[1];
// Read and accumulate locally
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
thread_x[i] = x[i];
thread_w[i] = w[w_stride * i];
thread_g[i] = g[i];
sumx2 += thread_x[i] * thread_x[i];
sumgwx += thread_x[i] * thread_w[i] * thread_g[i];
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
thread_x[i] = x[i];
thread_w[i] = w[w_stride * i];
thread_g[i] = g[i];
sumx2 += thread_x[i] * thread_x[i];
sumgwx += thread_x[i] * thread_w[i] * thread_g[i];
}
}
}
// Accumulate across threads
sumx2 = simd_sum(sumx2);
sumgwx = simd_sum(sumgwx);
if (simd_group_id == 0) {
local_sumx2[simd_lane_id] = 0;
local_sumgwx[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_lane_id == 0) {
local_sumx2[simd_group_id] = sumx2;
local_sumgwx[simd_group_id] = sumgwx;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
sumgwx = simd_sum(local_sumgwx[simd_lane_id]);
if (simd_lane_id == 0) {
local_meangwx[0] = sumgwx / axis_size;
local_normalizer[0] = metal::precise::rsqrt(sumx2 / axis_size + eps);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float meangwx = local_meangwx[0];
float normalizer = local_normalizer[0];
float normalizer3 = normalizer * normalizer * normalizer;
// Write the outputs
gx += gid * axis_size + lid * N_READS;
gw += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
gx[i] = static_cast<T>(thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3);
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
gx[i] = static_cast<T>(thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3);
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
}
}
}
}
template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void vjp_rms_looped(
const device T* x,
const device T* w,
const device T* g,
device T* gx,
device T* gw,
constant float& eps,
constant uint& axis_size,
constant uint& w_stride,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// Advance the input pointers
x += gid * axis_size + lid * N_READS;
g += gid * axis_size + lid * N_READS;
w += w_stride * lid * N_READS;
// Allocate registers for the accumulators
float sumx2 = 0;
float sumgwx = 0;
// Allocate shared memory to implement the reduction
constexpr int SIMD_SIZE = 32;
threadgroup float local_sumx2[SIMD_SIZE];
threadgroup float local_sumgwx[SIMD_SIZE];
threadgroup float local_normalizer[1];
threadgroup float local_meangwx[1];
// Read and accumulate locally
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
float xi = x[i + r];
float wi = w[w_stride * (i + r)];
float gi = g[i + r];
sumx2 += xi * xi;
sumgwx += xi * wi * gi;
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((r + lid * N_READS + i) < axis_size) {
float xi = x[i + r];
float wi = w[w_stride * (i + r)];
float gi = g[i + r];
sumx2 += xi * xi;
sumgwx += xi * wi * gi;
}
}
}
}
// Accumulate across threads
sumx2 = simd_sum(sumx2);
sumgwx = simd_sum(sumgwx);
if (simd_group_id == 0) {
local_sumx2[simd_lane_id] = 0;
local_sumgwx[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_lane_id == 0) {
local_sumx2[simd_group_id] = sumx2;
local_sumgwx[simd_group_id] = sumgwx;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
sumgwx = simd_sum(local_sumgwx[simd_lane_id]);
if (simd_lane_id == 0) {
local_meangwx[0] = sumgwx / axis_size;
local_normalizer[0] = metal::precise::rsqrt(sumx2 / axis_size + eps);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float meangwx = local_meangwx[0];
float normalizer = local_normalizer[0];
float normalizer3 = normalizer * normalizer * normalizer;
// Write the outputs
gx += gid * axis_size + lid * N_READS;
gw += gid * axis_size + lid * N_READS;
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
float xi = x[i + r];
float wi = w[w_stride * (i + r)];
float gi = g[i + r];
gx[i + r] = static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
gw[i + r] = static_cast<T>(gi * xi * normalizer);
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((r + lid * N_READS + i) < axis_size) {
float xi = x[i + r];
float wi = w[w_stride * (i + r)];
float gi = g[i + r];
gx[i + r] = static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
gw[i + r] = static_cast<T>(gi * xi * normalizer);
}
}
}
}
}
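
For reference, these accumulators implement the standard RMS-norm VJP. With n = axis_size and \(r = \big(\tfrac{1}{n}\sum_j x_j^2 + \epsilon\big)^{-1/2}\) (normalizer in the code):

\[
\frac{\partial L}{\partial x_i} = g_i w_i\, r \;-\; x_i\, r^3\, \overline{gwx},
\qquad
gw_i = g_i x_i\, r,
\]

where \(\overline{gwx} = \tfrac{1}{n}\sum_j g_j w_j x_j\) is meangwx and \(r^3\) is normalizer3. As with layer norm, the per-row gw values are presumably reduced to the weight gradient outside the kernel.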
// clang-format off
#define instantiate_rms_single_row(name, itype) \
template [[host_name("rms" #name)]] [[kernel]] void \
@@ -165,25 +375,56 @@ template <typename T, int N_READS = RMS_N_READS>
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_rms_looped(name, itype) \
template [[host_name("rms_looped" #name)]] [[kernel]] void \
rms_looped<itype>( \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
\
template [[host_name("vjp_rms" #name)]] [[kernel]] void \
vjp_rms_single_row<itype>( \
const device itype* x, \
const device itype* w, \
device itype* out, \
const device itype* g, \
device itype* gx, \
device itype* gw, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
threadgroup float* local_inv_mean [[threadgroup(0)]], \
threadgroup float* local_sums [[threadgroup(1)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_rms_looped(name, itype) \
template [[host_name("rms_looped" #name)]] [[kernel]] void \
rms_looped<itype>( \
const device itype* x, \
const device itype* w, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
threadgroup float* local_inv_mean [[threadgroup(0)]], \
threadgroup float* local_sums [[threadgroup(1)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
\
template [[host_name("vjp_rms_looped" #name)]] [[kernel]] void \
vjp_rms_looped<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* g, \
device itype* gx, \
device itype* gw, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_rms(name, itype) \
instantiate_rms_single_row(name, itype) \
instantiate_rms_looped(name, itype)

View File

@@ -5,7 +5,7 @@
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
template <typename T, bool traditional>
template <typename T, bool traditional, bool forward>
[[kernel]] void rope(
const device T *in [[buffer(0)]],
device T * out [[buffer(1)]],
@@ -43,15 +43,22 @@ template <typename T, bool traditional>
// Read and write the output
float x1 = static_cast<float>(in[in_index_1]);
float x2 = static_cast<float>(in[in_index_2]);
float rx1 = x1 * costheta - x2 * sintheta;
float rx2 = x1 * sintheta + x2 * costheta;
float rx1;
float rx2;
if (forward) {
rx1 = x1 * costheta - x2 * sintheta;
rx2 = x1 * sintheta + x2 * costheta;
} else {
rx1 = x2 * sintheta + x1 * costheta;
rx2 = x2 * costheta - x1 * sintheta;
}
out[out_index_1] = static_cast<T>(rx1);
out[out_index_2] = static_cast<T>(rx2);
}
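
The forward/backward split falls out of RoPE being a plain 2-D rotation. Writing each pair as a vector,

\[
\begin{pmatrix} y_1 \\ y_2 \end{pmatrix}
= R(\theta)\begin{pmatrix} x_1 \\ x_2 \end{pmatrix},
\qquad
R(\theta) = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix},
\]

and since \(R(\theta)\) is orthogonal, the VJP multiplies by \(R(\theta)^\top = R(-\theta)\). Expanding \(R(-\theta)\) gives exactly the else branch: rx1 = x2 sin(θ) + x1 cos(θ), rx2 = x2 cos(θ) − x1 sin(θ).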
#define instantiate_rope(name, type, traditional) \
#define instantiate_rope(name, type, traditional, forward) \
template [[host_name("rope_" #name)]] \
[[kernel]] void rope<type, traditional>( \
[[kernel]] void rope<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const size_t strides[3], \
@@ -62,9 +69,15 @@ template <typename T, bool traditional>
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
instantiate_rope(traditional_float16, half, true)
instantiate_rope(traditional_bfloat16, bfloat16_t, true)
instantiate_rope(traditional_float32, float, true)
instantiate_rope(float16, half, false)
instantiate_rope(bfloat16, bfloat16_t, false)
instantiate_rope(float32, float, false)
instantiate_rope(traditional_float16, half, true, true)
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
instantiate_rope(traditional_float32, float, true, true)
instantiate_rope(float16, half, false, true)
instantiate_rope(bfloat16, bfloat16_t, false, true)
instantiate_rope(float32, float, false, true)
instantiate_rope(vjp_traditional_float16, half, true, false)
instantiate_rope(vjp_traditional_bfloat16, bfloat16_t, true, false)
instantiate_rope(vjp_traditional_float32, float, true, false)
instantiate_rope(vjp_float16, half, false, false)
instantiate_rope(vjp_bfloat16, bfloat16_t, false, false)
instantiate_rope(vjp_float32, float, false, false)

View File

@@ -451,7 +451,7 @@ instantiate_scan_helper(sum_int32_int32, int32_t, int32_t, CumSu
//instantiate_scan_helper(sum_int64_int64, int64_t, int64_t, CumSum, 2)
instantiate_scan_helper(sum_float16_float16, half, half, CumSum, 4)
instantiate_scan_helper(sum_float32_float32, float, float, CumSum, 4)
//instantiate_scan_helper(sum_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumSum, 4)
instantiate_scan_helper(sum_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumSum, 4)
//instantiate_scan_helper(sum_complex64_complex64, complex64_t, complex64_t, CumSum)
//instantiate_scan_helper(prod_bool__bool_, bool, bool, CumProd, 4)
instantiate_scan_helper(prod_uint8_uint8, uint8_t, uint8_t, CumProd, 4)
@@ -464,7 +464,7 @@ instantiate_scan_helper(prod_int32_int32, int32_t, int32_t, CumP
//instantiate_scan_helper(prod_int64_int64, int64_t, int64_t, CumProd, 2)
instantiate_scan_helper(prod_float16_float16, half, half, CumProd, 4)
instantiate_scan_helper(prod_float32_float32, float, float, CumProd, 4)
//instantiate_scan_helper(prod_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumProd, 4)
instantiate_scan_helper(prod_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumProd, 4)
//instantiate_scan_helper(prod_complex64_complex64, complex64_t, complex64_t, CumProd)
//instantiate_scan_helper(max_bool__bool_, bool, bool, CumMax, 4)
instantiate_scan_helper(max_uint8_uint8, uint8_t, uint8_t, CumMax, 4)
@@ -477,7 +477,7 @@ instantiate_scan_helper(max_int32_int32, int32_t, int32_t, CumMa
//instantiate_scan_helper(max_int64_int64, int64_t, int64_t, CumMax, 2)
instantiate_scan_helper(max_float16_float16, half, half, CumMax, 4)
instantiate_scan_helper(max_float32_float32, float, float, CumMax, 4)
//instantiate_scan_helper(max_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMax, 4)
instantiate_scan_helper(max_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMax, 4)
//instantiate_scan_helper(max_complex64_complex64, complex64_t, complex64_t, CumMax)
//instantiate_scan_helper(min_bool__bool_, bool, bool, CumMin, 4)
instantiate_scan_helper(min_uint8_uint8, uint8_t, uint8_t, CumMin, 4)
@@ -490,5 +490,5 @@ instantiate_scan_helper(min_int32_int32, int32_t, int32_t, CumMi
//instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMin, 2)
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
//instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin)

View File

@@ -11,46 +11,48 @@ using namespace metal;
template <typename T>
inline T softmax_exp(T x) {
// Softmax doesn't need high precision exponential cause it is gonna be x
// will be in (-oo, 0] anyway and subsequently it will be divided by
// sum(exp(x_i)).
// Softmax doesn't need high precision exponential cause x is gonna be in
// (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
return fast::exp(x);
}
template <typename T, int N_READS = SOFTMAX_N_READS>
template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_single_row(
const device T* in,
device T* out,
constant int& axis_size,
threadgroup T* local_max [[threadgroup(0)]],
threadgroup T* local_normalizer [[threadgroup(1)]],
uint gid [[threadgroup_position_in_grid]],
uint _lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
int lid = _lid;
T ld[N_READS];
constexpr int SIMD_SIZE = 32;
threadgroup AccT local_max[SIMD_SIZE];
threadgroup AccT local_normalizer[SIMD_SIZE];
AccT ld[N_READS];
in += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i=0; i<N_READS; i++) {
ld[i] = in[i];
for (int i = 0; i < N_READS; i++) {
ld[i] = AccT(in[i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
ld[i] =
((lid * N_READS + i) < axis_size) ? in[i] : T(Limits<T>::finite_min);
}
for (int i = 0; i < N_READS; i++) {
ld[i] = ((lid * N_READS + i) < axis_size) ? AccT(in[i])
: Limits<AccT>::finite_min;
}
}
if (simd_group_id == 0) {
local_max[simd_lane_id] = Limits<T>::finite_min;
local_max[simd_lane_id] = Limits<AccT>::finite_min;
local_normalizer[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Get the max
T maxval = Limits<T>::finite_min;
AccT maxval = Limits<AccT>::finite_min;
for (int i = 0; i < N_READS; i++) {
maxval = (maxval < ld[i]) ? ld[i] : maxval;
}
@@ -69,9 +71,9 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
maxval = local_max[0];
// Compute exp(x_i - maxval) and store the partial sums in local_normalizer
T normalizer = 0;
AccT normalizer = 0;
for (int i = 0; i < N_READS; i++) {
T exp_x = softmax_exp(ld[i] - maxval);
AccT exp_x = softmax_exp(ld[i] - maxval);
ld[i] = exp_x;
normalizer += exp_x;
}
@@ -92,25 +94,23 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
// Normalize and write to the output
out += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i=0; i<N_READS; i++) {
out[i] = ld[i] * normalizer;
for (int i = 0; i < N_READS; i++) {
out[i] = T(ld[i] * normalizer);
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
out[i] = ld[i] * normalizer;
}
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
out[i] = T(ld[i] * normalizer);
}
}
}
}
template <typename T, int N_READS = SOFTMAX_N_READS>
template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_looped(
const device T* in,
device T* out,
constant int& axis_size,
threadgroup T* local_max [[threadgroup(0)]],
threadgroup T* local_normalizer [[threadgroup(1)]],
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
@@ -118,22 +118,27 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
in += gid * axis_size;
constexpr int SIMD_SIZE = 32;
threadgroup AccT local_max[SIMD_SIZE];
threadgroup AccT local_normalizer[SIMD_SIZE];
// Get the max and the normalizer in one go
T prevmax;
T maxval = Limits<T>::finite_min;
T normalizer = 0;
AccT prevmax;
AccT maxval = Limits<AccT>::finite_min;
AccT normalizer = 0;
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
r++) {
int offset = r * lsize * N_READS + lid * N_READS;
T vals[N_READS];
AccT vals[N_READS];
if (offset + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
vals[i] = in[offset + i];
vals[i] = AccT(in[offset + i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
vals[i] =
(offset + i < axis_size) ? in[offset + i] : T(Limits<T>::finite_min);
vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
: Limits<AccT>::finite_min;
}
}
prevmax = maxval;
@@ -179,50 +184,66 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
r++) {
int offset = r * lsize * N_READS + lid * N_READS;
if (offset + N_READS <= axis_size) {
for (int i=0; i<N_READS; i++) {
out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
for (int i = 0; i < N_READS; i++) {
out[offset + i] = T(softmax_exp(in[offset + i] - maxval) * normalizer);
}
} else {
for (int i = 0; i < N_READS; i++) {
if (offset + i < axis_size) {
out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
out[offset + i] =
T(softmax_exp(in[offset + i] - maxval) * normalizer);
}
}
}
}
}
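
The looped variant keeps a running maximum and renormalizes the partial sum whenever the maximum grows, so a single pass over the row suffices; with AccT = float this is what the new precise instantiations use for half and bfloat16 inputs. A compact C++ sketch of that online update (scalar, without the simdgroup reductions):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

void softmax_online(const std::vector<float>& in, std::vector<float>& out) {
  float maxval = -INFINITY, normalizer = 0.0f;
  for (float v : in) {
    float prevmax = maxval;
    maxval = std::max(maxval, v);
    // Rescale the partial sum when the running max grows, then add exp(v).
    normalizer = normalizer * std::exp(prevmax - maxval) + std::exp(v - maxval);
  }
  out.resize(in.size());
  for (size_t i = 0; i < in.size(); i++)
    out[i] = std::exp(in[i] - maxval) / normalizer;
}

int main() {
  std::vector<float> x{1.0f, 2.0f, 3.0f}, y;
  softmax_online(x, y);
  for (float v : y) std::printf("%g ", v);  // ~0.0900 0.2447 0.6652
}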
#define instantiate_softmax_single_row(name, itype) \
// clang-format off
#define instantiate_softmax(name, itype) \
template [[host_name("softmax_" #name)]] [[kernel]] void \
softmax_single_row<itype>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
threadgroup itype* local_max [[threadgroup(0)]], \
threadgroup itype* local_normalizer [[threadgroup(1)]], \
uint gid [[thread_position_in_grid]], \
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_softmax_looped(name, itype) \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("softmax_looped_" #name)]] [[kernel]] void \
softmax_looped<itype>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
threadgroup itype* local_max [[threadgroup(0)]], \
threadgroup itype* local_normalizer [[threadgroup(1)]], \
uint gid [[threadgroup_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_softmax(name, itype) \
instantiate_softmax_single_row(name, itype) \
instantiate_softmax_looped(name, itype)
#define instantiate_softmax_precise(name, itype) \
template [[host_name("softmax_precise_" #name)]] [[kernel]] void \
softmax_single_row<itype, float>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[thread_position_in_grid]], \
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("softmax_looped_precise_" #name)]] [[kernel]] void \
softmax_looped<itype, float>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[threadgroup_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
instantiate_softmax(float32, float)
instantiate_softmax(float16, half)
instantiate_softmax(bfloat16, bfloat16_t)
instantiate_softmax_precise(float16, half)
instantiate_softmax_precise(bfloat16, bfloat16_t)
// clang-format on

View File

@@ -394,7 +394,7 @@ struct Conv2DWeightBlockLoader {
const constant ImplicitGemmConv2DParams* gemm_params_,
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(params_->wt_strides[0]),
: src_ld(params_ -> wt_strides[0]),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),

View File

@@ -244,7 +244,7 @@ struct Conv2DWeightBlockLoaderSmallChannels {
const constant ImplicitGemmConv2DParams* gemm_params_,
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(params_->wt_strides[0]),
: src_ld(params_ -> wt_strides[0]),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),

View File

@@ -220,7 +220,7 @@ struct Conv2DWeightBlockLoaderGeneral {
const short base_ww_,
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(params_->wt_strides[0]),
: src_ld(params_ -> wt_strides[0]),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),

View File

@@ -7,6 +7,7 @@
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/erf.h"
#include "mlx/backend/metal/kernels/expm1f.h"
#include "mlx/backend/metal/kernels/utils.h"
namespace {
@@ -183,6 +184,13 @@ struct Exp {
}
};
struct Expm1 {
template <typename T>
T operator()(T x) {
return static_cast<T>(expm1f(static_cast<float>(x)));
};
};
struct Floor {
template <typename T>
T operator()(T x) {

View File

@@ -71,6 +71,7 @@ instantiate_unary_types(ceil, Ceil)
instantiate_unary_float(cos, Cos)
instantiate_unary_float(cosh, Cosh)
instantiate_unary_float(exp, Exp)
instantiate_unary_float(expm1, Expm1)
instantiate_unary_types(floor, Floor)
instantiate_unary_float(log, Log)
instantiate_unary_float(log2, Log2)

View File

@@ -197,8 +197,8 @@ inline auto collapse_batches(const array& a, const array& b) {
std::vector<int> B_bshape{b.shape().begin(), b.shape().end() - 2};
if (A_bshape != B_bshape) {
std::ostringstream msg;
msg << "[matmul] Got matrices with incorrectly broadcasted shapes: "
<< "A " << a.shape() << ", B " << b.shape() << ".";
msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A "
<< a.shape() << ", B " << b.shape() << ".";
throw std::runtime_error(msg.str());
}
@@ -227,9 +227,8 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
std::vector<int> C_bshape{c.shape().begin(), c.shape().end() - 2};
if (A_bshape != B_bshape || A_bshape != C_bshape) {
std::ostringstream msg;
msg << "[addmm] Got matrices with incorrectly broadcasted shapes: "
<< "A " << a.shape() << ", B " << b.shape() << ", B " << c.shape()
<< ".";
msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A "
<< a.shape() << ", B " << b.shape() << ", B " << c.shape() << ".";
throw std::runtime_error(msg.str());
}
@@ -332,11 +331,11 @@ void steel_matmul(
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_"
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
<< "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
<< ((K % bk == 0) ? "t" : "n") << "aligned";
// Encode and dispatch gemm kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
@@ -360,9 +359,9 @@ void steel_matmul(
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions);
set_array_buffer(compute_encoder, a, 0);
set_array_buffer(compute_encoder, b, 1);
set_array_buffer(compute_encoder, C_split, 2);
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(C_split, 2);
compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
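
This is the split-K path: the K reduction is partitioned across split_k_partitions threadgroups, each writing a partial product into C_split, and the accumulation kernel dispatched just below sums the partitions into out. A host-side sketch of the decomposition (illustrative only; the layout details here are assumptions):

#include <cstdio>
#include <vector>

int main() {
  const int M = 2, N = 2, K = 8, parts = 4, kp = K / parts;
  std::vector<float> A(M * K, 1.0f), B(K * N, 1.0f);
  std::vector<float> C_split(parts * M * N, 0.0f), C(M * N, 0.0f);

  for (int p = 0; p < parts; p++)  // each partition: a partial GEMM over kp
    for (int m = 0; m < M; m++)
      for (int n = 0; n < N; n++)
        for (int k = p * kp; k < (p + 1) * kp; k++)
          C_split[p * M * N + m * N + n] += A[m * K + k] * B[k * N + n];

  for (int p = 0; p < parts; p++)  // the second kernel: reduce partitions
    for (int i = 0; i < M * N; i++)
      C[i] += C_split[p * M * N + i];

  std::printf("C[0,0] = %g\n", C[0]);  // prints 8
}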
@@ -380,8 +379,8 @@ void steel_matmul(
compute_encoder->setComputePipelineState(kernel);
// Set the arguments for the kernel
set_array_buffer(compute_encoder, C_split, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder.set_input_array(C_split, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2);
compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3);
compute_encoder->setBytes(&N, sizeof(int), 4);
@@ -422,11 +421,11 @@ void steel_matmul(
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_"
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
<< "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
<< ((K % bk == 0) ? "t" : "n") << "aligned";
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
@@ -467,9 +466,9 @@ void steel_matmul(
batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
// Launch kernel
set_array_buffer(compute_encoder, a, 0);
set_array_buffer(compute_encoder, b, 1);
set_array_buffer(compute_encoder, out, 3);
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
@@ -488,7 +487,7 @@ void steel_matmul(
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (!is_floating_point(out.dtype())) {
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[matmul] Does not yet support non-floating point types.");
}
@@ -622,7 +621,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
kname << "_nc" << !contiguous_kernel << "_axpby0";
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
@@ -630,9 +629,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size group_dims = MTL::Size(bn, bm, 1);
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
set_array_buffer(compute_encoder, mat, 0);
set_array_buffer(compute_encoder, vec, 1);
set_array_buffer(compute_encoder, out, 3);
compute_encoder.set_input_array(mat, 0);
compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
@@ -696,7 +695,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3);
if (!is_floating_point(out.dtype())) {
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[matmul] Does not yet support non-floating point types.");
}
@@ -834,7 +833,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
kname << "_nc" << !contiguous_kernel << "_axpby1";
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
@@ -842,10 +841,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size group_dims = MTL::Size(bn, bm, 1);
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
set_array_buffer(compute_encoder, mat, 0);
set_array_buffer(compute_encoder, vec, 1);
set_array_buffer(compute_encoder, c, 2);
set_array_buffer(compute_encoder, out, 3);
compute_encoder.set_input_array(mat, 0);
compute_encoder.set_input_array(vec, 1);
compute_encoder.set_input_array(c, 2);
compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
@@ -903,11 +902,11 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_"
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
<< "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
<< ((K % bk == 0) ? "t" : "n") << "aligned";
// Encode and dispatch gemm kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
@@ -931,9 +930,9 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions);
set_array_buffer(compute_encoder, a, 0);
set_array_buffer(compute_encoder, b, 1);
set_array_buffer(compute_encoder, C_split, 2);
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(C_split, 2);
compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
@@ -946,12 +945,12 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setComputePipelineState(kernel);
// Set the arguments for the kernel
set_array_buffer(compute_encoder, C_split, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder.set_input_array(C_split, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2);
compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3);
compute_encoder->setBytes(&N, sizeof(int), 4);
set_array_buffer(compute_encoder, c, 5);
compute_encoder.set_input_array(c, 5);
compute_encoder->setBytes(&ldc, sizeof(int), 6);
compute_encoder->setBytes(&fdc, sizeof(int), 7);
compute_encoder->setBytes(&alpha_, sizeof(float), 8);
@@ -992,12 +991,12 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_"
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
<< "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned"
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
<< ((K % bk == 0) ? "t" : "n") << "aligned"
<< ((alpha_ == 1. && beta_ == 1.) ? "_add" : "_axpby");
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
@@ -1045,10 +1044,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end());
// Launch kernel
set_array_buffer(compute_encoder, a, 0);
set_array_buffer(compute_encoder, b, 1);
set_array_buffer(compute_encoder, c, 2);
set_array_buffer(compute_encoder, out, 3);
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_input_array(c, 2);
compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&gemm_params, sizeof(GEMMParams), 4);
compute_encoder->setBytes(&params, sizeof(GEMMAddMMParams), 5);

View File

@@ -1,10 +1,10 @@
// Copyright © 2023-2024 Apple Inc.
#include <cstdlib>
#include <future>
#include <memory>
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
@@ -74,6 +74,8 @@ std::function<void()> make_task(
if (arr.is_tracer()) {
inputs = arr.inputs();
}
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
arr.primitive().eval_gpu(arr.inputs(), outputs);
}
std::vector<std::shared_ptr<array::Data>> buffers;
@@ -86,7 +88,6 @@ std::function<void()> make_task(
if (!arr.is_tracer()) {
arr.detach();
}
if (p) {
metal::device(s.device).end_encoding(s.index);
scheduler::notify_new_task(s);
@@ -108,4 +109,31 @@ std::function<void()> make_task(
return task;
}
bool start_capture(std::string path, id object) {
auto pool = new_scoped_memory_pool();
auto descriptor = MTL::CaptureDescriptor::alloc()->init();
descriptor->setCaptureObject(object);
if (path.length() > 0) {
auto string = NS::String::string(path.c_str(), NS::UTF8StringEncoding);
auto url = NS::URL::fileURLWithPath(string);
descriptor->setDestination(MTL::CaptureDestinationGPUTraceDocument);
descriptor->setOutputURL(url);
}
auto manager = MTL::CaptureManager::sharedCaptureManager();
return manager->startCapture(descriptor, nullptr);
}
bool start_capture(std::string path) {
auto& device = metal::device(mlx::core::Device::gpu);
return start_capture(path, device.mtl_device());
}
void stop_capture() {
auto manager = MTL::CaptureManager::sharedCaptureManager();
manager->stopCapture();
}
} // namespace mlx::core::metal
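A hypothetical usage sketch of the capture API added here: wrap the region of interest, then open the resulting .gputrace document in Xcode's Metal debugger. Programmatic capture typically requires running with MTL_CAPTURE_ENABLED=1.

#include "mlx/mlx.h"

int main() {
  namespace mx = mlx::core;
  // Everything evaluated between start and stop lands in the trace.
  mx::metal::start_capture("mlx_trace.gputrace");
  auto a = mx::ones({512, 512});
  auto b = mx::matmul(a, a);
  mx::eval(b);
  mx::metal::stop_capture();
  return 0;
}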

View File

@@ -2,15 +2,11 @@
#pragma once
#include <future>
#include <memory>
#include <vector>
#include "mlx/array.h"
#include "mlx/stream.h"
namespace mlx::core::metal {
/* Check if the Metal backend is available. */
bool is_available();
/* Get the actively used memory in bytes.
@@ -58,12 +54,8 @@ size_t set_memory_limit(size_t limit, bool relaxed = true);
* */
size_t set_cache_limit(size_t limit);
void new_stream(Stream stream);
std::shared_ptr<void> new_scoped_memory_pool();
std::function<void()> make_task(
array& arr,
std::vector<std::shared_future<void>> deps,
std::shared_ptr<std::promise<void>> p);
/** Capture a GPU trace, saving it to an absolute file `path` */
bool start_capture(std::string path = "");
void stop_capture();
} // namespace mlx::core::metal

View File

@@ -0,0 +1,22 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <future>
#include <memory>
#include <vector>
#include "mlx/array.h"
#include "mlx/stream.h"
namespace mlx::core::metal {
void new_stream(Stream stream);
std::shared_ptr<void> new_scoped_memory_pool();
std::function<void()> make_task(
array& arr,
std::vector<std::shared_future<void>> deps,
std::shared_ptr<std::promise<void>> p);
} // namespace mlx::core::metal

View File

@@ -4,6 +4,7 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/reduce.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
@@ -18,7 +19,7 @@ void RMSNorm::eval_gpu(
// Make sure that the last dimension is contiguous
std::vector<array> copies;
auto check_input = [&copies, &s](const array& x) {
auto check_input = [&copies, &s](const array& x) -> const array& {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
@@ -27,10 +28,9 @@ void RMSNorm::eval_gpu(
if (no_copy) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
copies.push_back(x_copy);
return x_copy;
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
copy_gpu(x, copies.back(), CopyType::General, s);
return copies.back();
}
};
const array& x = check_input(inputs[0]);
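The reworked lambda is a lifetime fix as much as a cleanup: once check_input returns a const reference, the copy must live in the captured copies vector rather than in a local. A minimal sketch of the pitfall (bad and good are hypothetical names):

// Returning a reference to a local dangles; returning a reference into a
// container that outlives the call is safe.
const array& bad(const array& x, std::vector<array>& copies) {
  array x_copy(x.shape(), x.dtype(), nullptr, {});
  copies.push_back(x_copy);
  return x_copy; // dangling: x_copy is destroyed when the function returns
}

const array& good(const array& x, std::vector<array>& copies) {
  copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
  return copies.back(); // refers into copies, which outlives the call
}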
@@ -57,7 +57,7 @@ void RMSNorm::eval_gpu(
op_name += "_looped";
}
op_name += type_to_name(out);
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
{
auto kernel = d.get_kernel(op_name);
@@ -79,10 +79,10 @@ void RMSNorm::eval_gpu(
uint32_t w_stride = w.strides()[0];
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(
compute_encoder, x.data_shared_ptr() == nullptr ? out : x, 0);
set_array_buffer(compute_encoder, w, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder.set_input_array(
x.data_shared_ptr() == nullptr ? out : x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_output_array(out, 2);
compute_encoder->setBytes(&eps_, sizeof(float), 3);
compute_encoder->setBytes(&axis_size, sizeof(int), 4);
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 5);
@@ -95,6 +95,113 @@ void RMSNorm::eval_gpu(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
void RMSNormVJP::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& s = stream();
auto& d = metal::device(s.device);
// Ensure row contiguity. We could relax this step by checking that the array
// is contiguous (no broadcasts or holes) and that the input strides are the
same as the cotangent strides, but for now this is simpler.
std::vector<array> copies;
auto check_input = [&copies, &s](const array& x) -> const array& {
if (x.flags().row_contiguous) {
return x;
}
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
copy_gpu(x, copies.back(), CopyType::General, s);
return copies.back();
};
const array& x = check_input(inputs[0]);
const array& w = inputs[1];
const array& g = check_input(inputs[2]);
array& gx = outputs[0];
array& gw = outputs[1];
// Allocate space for the outputs
bool x_in_gx = false;
bool g_in_gx = false;
if (x.is_donatable()) {
gx.move_shared_buffer(x);
x_in_gx = true;
} else if (g.is_donatable()) {
gx.move_shared_buffer(g);
g_in_gx = true;
} else {
gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
}
auto axis_size = static_cast<uint32_t>(x.shape().back());
int n_rows = x.data_size() / axis_size;
// Allocate a temporary to store the gradients for w and initialize the
// gradient accumulator to 0.
array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {});
bool g_in_gw = false;
if (!g_in_gx && g.is_donatable()) {
gw_temp.move_shared_buffer(g);
g_in_gw = true;
} else {
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
}
copies.push_back(gw_temp);
{
array zero(0, gw.dtype());
copy_gpu(zero, gw, CopyType::Scalar, s);
copies.push_back(std::move(zero));
}
const int simd_size = 32;
const int n_reads = RMS_N_READS;
const int looped_limit = RMS_LOOPED_LIMIT;
std::string op_name = "vjp_rms";
if (axis_size > looped_limit) {
op_name += "_looped";
}
op_name += type_to_name(gx);
auto& compute_encoder = d.get_command_encoder(s.index);
{
auto kernel = d.get_kernel(op_name);
MTL::Size grid_dims, group_dims;
if (axis_size <= looped_limit) {
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
size_t threadgroup_size = simd_size * simds_needed;
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
} else {
size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
}
uint32_t w_stride = w.strides()[0];
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(x_in_gx ? gx : x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
compute_encoder.set_output_array(gx, 3);
compute_encoder.set_output_array(gw_temp, 4);
compute_encoder->setBytes(&eps_, sizeof(float), 5);
compute_encoder->setBytes(&axis_size, sizeof(int), 6);
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
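The non-looped path rounds the per-row thread count up to whole simdgroups. A worked example with hypothetical values:

// axis_size = 1000, n_reads = 4, simd_size = 32:
//   threadgroup_needed = (1000 + 4 - 1) / 4  = 250
//   simds_needed       = (250 + 32 - 1) / 32 = 8
//   threadgroup_size   = 32 * 8              = 256 threads per row
// The dispatch then launches n_rows * 256 threads in total.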
ReductionPlan plan(
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
strided_reduce_general_dispatch(
gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
void LayerNorm::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
@@ -104,7 +211,7 @@ void LayerNorm::eval_gpu(
// Make sure that the last dimension is contiguous
std::vector<array> copies;
auto check_input = [&copies, &s](const array& x) {
auto check_input = [&copies, &s](const array& x) -> const array& {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
@@ -113,10 +220,9 @@ void LayerNorm::eval_gpu(
if (no_copy) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
copies.push_back(x_copy);
return x_copy;
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
copy_gpu(x, copies.back(), CopyType::General, s);
return copies.back();
}
};
const array& x = check_input(inputs[0]);
@@ -144,7 +250,7 @@ void LayerNorm::eval_gpu(
op_name += "_looped";
}
op_name += type_to_name(out);
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
{
auto kernel = d.get_kernel(op_name);
@@ -167,11 +273,11 @@ void LayerNorm::eval_gpu(
uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
uint32_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(
compute_encoder, x.data_shared_ptr() == nullptr ? out : x, 0);
set_array_buffer(compute_encoder, w, 1);
set_array_buffer(compute_encoder, b, 2);
set_array_buffer(compute_encoder, out, 3);
compute_encoder.set_input_array(
x.data_shared_ptr() == nullptr ? out : x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(b, 2);
compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&eps_, sizeof(float), 4);
compute_encoder->setBytes(&axis_size, sizeof(int), 5);
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 6);
@@ -182,4 +288,131 @@ void LayerNorm::eval_gpu(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
void LayerNormVJP::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& s = stream();
auto& d = metal::device(s.device);
// Ensure row contiguity. We could relax this step by checking that the array
// is contiguous (no broadcasts or holes) and that the input strides are the
same as the cotangent strides, but for now this is simpler.
std::vector<array> copies;
auto check_input = [&copies, &s](const array& x) -> const array& {
if (x.flags().row_contiguous) {
return x;
}
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
copy_gpu(x, copies.back(), CopyType::General, s);
return copies.back();
};
const array& x = check_input(inputs[0]);
const array& w = inputs[1];
const array& b = inputs[2];
const array& g = check_input(inputs[3]);
array& gx = outputs[0];
array& gw = outputs[1];
array& gb = outputs[2];
// Allocate space for the outputs
bool x_in_gx = false;
bool g_in_gx = false;
if (x.is_donatable()) {
gx.move_shared_buffer(x);
x_in_gx = true;
} else if (g.is_donatable()) {
gx.move_shared_buffer(g);
g_in_gx = true;
} else {
gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
}
auto axis_size = static_cast<uint32_t>(x.shape().back());
int n_rows = x.data_size() / axis_size;
// Allocate a temporary to store the gradients for w and initialize the
// gradient accumulator to 0.
array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {});
bool g_in_gw = false;
if (!g_in_gx && g.is_donatable()) {
gw_temp.move_shared_buffer(g);
g_in_gw = true;
} else {
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
}
copies.push_back(gw_temp);
{
array zero(0, gw.dtype());
copy_gpu(zero, gw, CopyType::Scalar, s);
copy_gpu(zero, gb, CopyType::Scalar, s);
copies.push_back(std::move(zero));
}
// Compute the gradient for b first, if we have a b: it must read g before
// the main kernel may overwrite g's donated buffer with gx
auto& compute_encoder = d.get_command_encoder(s.index);
if (gb.ndim() == 1 && gb.size() == axis_size) {
ReductionPlan plan(
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
strided_reduce_general_dispatch(
g_in_gx ? gx : (g_in_gw ? gw_temp : g),
gb,
"sum",
plan,
{0},
compute_encoder,
d,
s);
}
const int simd_size = 32;
const int n_reads = RMS_N_READS;
const int looped_limit = RMS_LOOPED_LIMIT;
std::string op_name = "vjp_layer_norm";
if (axis_size > looped_limit) {
op_name += "_looped";
}
op_name += type_to_name(gx);
{
auto kernel = d.get_kernel(op_name);
MTL::Size grid_dims, group_dims;
if (axis_size <= looped_limit) {
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
size_t threadgroup_size = simd_size * simds_needed;
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
} else {
size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
}
uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(x_in_gx ? gx : x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
compute_encoder.set_output_array(gx, 3);
compute_encoder.set_output_array(gw_temp, 4);
compute_encoder->setBytes(&eps_, sizeof(float), 5);
compute_encoder->setBytes(&axis_size, sizeof(int), 6);
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
if (gw.ndim() == 1 && gw.size() == axis_size) {
ReductionPlan plan(
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
strided_reduce_general_dispatch(
gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
} // namespace mlx::core::fast

View File

@@ -68,18 +68,18 @@ void binary_op(
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
// - If a is donated it goes to the first output
// - If b is donated it goes to the first output if a was not donated
// otherwise it goes to the second output
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
set_array_buffer(compute_encoder, donate_a ? outputs[0] : a, 0);
set_array_buffer(
compute_encoder, donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
set_array_buffer(compute_encoder, outputs[0], 2);
set_array_buffer(compute_encoder, outputs[1], 3);
compute_encoder.set_input_array(donate_a ? outputs[0] : a, 0);
compute_encoder.set_input_array(
donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
compute_encoder.set_output_array(outputs[0], 2);
compute_encoder.set_output_array(outputs[1], 3);
if (bopt == BinaryOpType::General) {
auto ndim = shape.size();
@@ -167,13 +167,13 @@ void binary_op(
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
set_array_buffer(compute_encoder, donate_a ? out : a, 0);
set_array_buffer(compute_encoder, donate_b ? out : b, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder.set_input_array(donate_a ? out : a, 0);
compute_encoder.set_input_array(donate_b ? out : b, 1);
compute_encoder.set_output_array(out, 2);
if (bopt == BinaryOpType::General) {
auto ndim = shape.size();
@@ -253,12 +253,12 @@ void ternary_op(
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, a, 0);
set_array_buffer(compute_encoder, b, 1);
set_array_buffer(compute_encoder, c, 2);
set_array_buffer(compute_encoder, out, 3);
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_input_array(c, 2);
compute_encoder.set_output_array(out, 3);
if (topt == TernaryOpType::General) {
auto ndim = shape.size();
@@ -339,11 +339,11 @@ void unary_op(
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(
compute_encoder, in.data_shared_ptr() == nullptr ? out : in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder.set_input_array(
in.data_shared_ptr() == nullptr ? out : in, 0);
compute_encoder.set_output_array(out, 1);
if (!contig) {
compute_encoder->setBytes(in.shape().data(), in.ndim() * sizeof(int), 2);
compute_encoder->setBytes(
@@ -365,7 +365,7 @@ void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
}
template <typename T>
void arange_set_scalars(T start, T next, MTL::ComputeCommandEncoder* enc) {
void arange_set_scalars(T start, T next, CommandEncoder& enc) {
enc->setBytes(&start, sizeof(T), 0);
T step = next - start;
enc->setBytes(&step, sizeof(T), 1);
@@ -384,7 +384,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
MTL::Size group_dims = MTL::Size(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
switch (out.dtype()) {
@@ -427,7 +427,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error("[Arange::eval_gpu] Does not support complex64");
}
set_array_buffer(compute_encoder, out, 2);
compute_encoder.set_output_array(out, 2);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
@@ -487,7 +487,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
// ArgReduce
int simd_size = 32;
int n_reads = 4;
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
{
auto kernel = d.get_kernel(op_name + type_to_name(in));
NS::UInteger thread_group_size = std::min(
@@ -502,8 +502,8 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size grid_dims = MTL::Size(n_threads, 1, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
if (ndim == 0) {
// Pass placeholders so Metal doesn't complain
int shape_ = 0;
@@ -552,6 +552,9 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
flags.row_contiguous = false;
flags.col_contiguous = false;
flags.contiguous = false;
auto& d = metal::device(stream().device);
auto& compute_encoder = d.get_command_encoder(stream().index);
auto concurrent_ctx = compute_encoder.start_concurrent();
for (int i = 0; i < inputs.size(); i++) {
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
size_t data_offset = strides[axis_] * sizes[i];
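start_concurrent() returns a scoped context: while it is alive, the per-input slice copies are encoded for concurrent execution, since each writes a disjoint region of out, and the usual read-after-write barriers are deferred until the context ends. A hedged sketch of the RAII shape such a context could take (ConcurrentContext and the flag are hypothetical, not the real implementation):

// Hypothetical RAII sketch: dispatches encoded while the context is alive
// skip hazard barriers; destruction restores normal tracking.
struct ConcurrentContext {
  explicit ConcurrentContext(CommandEncoder& e) : enc(e) {
    enc.concurrent = true; // caller promises the dispatches are independent
  }
  ~ConcurrentContext() {
    enc.concurrent = false; // barriers apply again from here on
  }
  CommandEncoder& enc;
};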
@@ -615,6 +618,10 @@ void Exp::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "exp");
}
void Expm1::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "expm1");
}
void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
auto in = inputs[0];
CopyType ctype;
@@ -787,10 +794,10 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, keys, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder.set_input_array(keys, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&odd, sizeof(bool), 2);
compute_encoder->setBytes(&bytes_per_key, sizeof(size_t), 3);
@@ -822,7 +829,7 @@ void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (not is_integral(in.dtype())) {
if (issubdtype(in.dtype(), inexact)) {
unary_op(inputs, out, "round");
} else {
// No-op integer types
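issubdtype mirrors NumPy's dtype-category test: inexact covers the floating and complex dtypes, so the kernel is only launched where rounding can change a value. Hedged examples of the check, assuming the mlx::core dtype categories:

// issubdtype(float16, inexact)   -> true   (launch the round kernel)
// issubdtype(bfloat16, inexact)  -> true
// issubdtype(int32, inexact)     -> false  (round is a no-op copy)
// issubdtype(complex64, inexact) -> true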

View File

@@ -48,7 +48,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
<< bits_ << "_fast";
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
@@ -57,11 +57,11 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size group_dims = MTL::Size(bd, 2, 1);
MTL::Size grid_dims = MTL::Size(1, O / bo, B);
set_array_buffer(compute_encoder, w, 0);
set_array_buffer(compute_encoder, scales, 1);
set_array_buffer(compute_encoder, biases, 2);
set_array_buffer(compute_encoder, x, 3);
set_array_buffer(compute_encoder, out, 4);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder->setBytes(&D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
@@ -75,7 +75,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
<< bits_;
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
@@ -84,11 +84,11 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size group_dims = MTL::Size(bd, 2, 1);
MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
set_array_buffer(compute_encoder, w, 0);
set_array_buffer(compute_encoder, scales, 1);
set_array_buffer(compute_encoder, biases, 2);
set_array_buffer(compute_encoder, x, 3);
set_array_buffer(compute_encoder, out, 4);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder->setBytes(&D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
@@ -102,7 +102,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
<< bits_ << "_alN_" << std::boolalpha << ((O % 32) == 0);
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
@@ -114,11 +114,11 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, 1);
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, w, 1);
set_array_buffer(compute_encoder, scales, 2);
set_array_buffer(compute_encoder, biases, 3);
set_array_buffer(compute_encoder, out, 4);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(scales, 2);
compute_encoder.set_input_array(biases, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder->setBytes(&B, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->setBytes(&D, sizeof(int), 7);
@@ -133,20 +133,20 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
<< bits_;
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int bo = std::min(32, O);
int bo = 8;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, bo, 1);
MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, w, 1);
set_array_buffer(compute_encoder, scales, 2);
set_array_buffer(compute_encoder, biases, 3);
set_array_buffer(compute_encoder, out, 4);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(scales, 2);
compute_encoder.set_input_array(biases, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder->setBytes(&D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
@@ -160,7 +160,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
<< bits_;
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
@@ -179,11 +179,11 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error(msg.str());
}
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, w, 1);
set_array_buffer(compute_encoder, scales, 2);
set_array_buffer(compute_encoder, biases, 3);
set_array_buffer(compute_encoder, out, 4);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(scales, 2);
compute_encoder.set_input_array(biases, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder->setBytes(&B, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->setBytes(&D, sizeof(int), 7);

View File

@@ -4,10 +4,10 @@
#include <cassert>
#include <sstream>
#include "mlx/backend/common/reduce.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/reduce.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -18,8 +18,6 @@ namespace mlx::core {
// Case-wise reduce dispatch
//////////////////////////////////////////////////////////////////////
namespace {
inline auto safe_div(size_t n, size_t m) {
return m == 0 ? 0 : (n + m - 1) / m;
}
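safe_div is ceiling division with a guard for a zero divisor, worked through below:

// safe_div(10, 3) == (10 + 3 - 1) / 3 == 12 / 3 == 4
// safe_div(12, 4) == (12 + 4 - 1) / 4 == 15 / 4 == 3
// safe_div(7, 0)  == 0  // the guard avoids a division by zero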
@@ -37,7 +35,7 @@ void all_reduce_dispatch(
const array& in,
array& out,
const std::string& op_name,
MTL::ComputeCommandEncoder* compute_encoder,
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s) {
Dtype out_dtype = out.dtype();
@@ -73,8 +71,8 @@ void all_reduce_dispatch(
// Encode buffers and dispatch
if (is_out_64b_int == false || n_thread_groups == 1) {
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
compute_encoder->dispatchThreads(grid_dims, group_dims);
@@ -87,14 +85,14 @@ void all_reduce_dispatch(
std::vector<array> intermediates = {intermediate};
// First dispatch
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, intermediate, 1);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(intermediate, 1);
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
compute_encoder->dispatchThreads(grid_dims, group_dims);
// Second pass to reduce intermediate reduction results written to DRAM
set_array_buffer(compute_encoder, intermediate, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder.set_input_array(intermediate, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&intermediate_size, sizeof(size_t), 2);
mod_in_size = (intermediate_size + n_reads - 1) / n_reads;
@@ -125,7 +123,7 @@ void row_reduce_general_dispatch(
const std::string& op_name,
const ReductionPlan& plan,
const std::vector<int>& axes,
MTL::ComputeCommandEncoder* compute_encoder,
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s) {
Dtype out_dtype = out.dtype();
@@ -210,8 +208,8 @@ void row_reduce_general_dispatch(
// Dispatch kernel
if (!is_out_64b_int || non_row_reductions == 1) {
// Set the arguments for the kernel
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 4);
@@ -232,8 +230,8 @@ void row_reduce_general_dispatch(
std::vector<array> intermediates = {intermediate};
// Set the arguments for the kernel
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, intermediate, 1);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(intermediate, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 4);
@@ -260,8 +258,8 @@ void row_reduce_general_dispatch(
ndim = new_shape.size();
// Set the arguments for the kernel
set_array_buffer(compute_encoder, intermediate, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder.set_input_array(intermediate, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 4);
@@ -303,7 +301,7 @@ void strided_reduce_general_dispatch(
const std::string& op_name,
const ReductionPlan& plan,
const std::vector<int>& axes,
MTL::ComputeCommandEncoder* compute_encoder,
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s) {
Dtype out_dtype = out.dtype();
@@ -351,8 +349,8 @@ void strided_reduce_general_dispatch(
}
// Encode arrays
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
@@ -417,8 +415,8 @@ void strided_reduce_general_dispatch(
if (is_out_64b_int == false) {
// Set the arguments for the kernel
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
@@ -452,8 +450,8 @@ void strided_reduce_general_dispatch(
std::vector<array> intermediates = {intermediate};
// Set the arguments for the kernel
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, intermediate, 1);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(intermediate, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
@@ -496,8 +494,8 @@ void strided_reduce_general_dispatch(
"row_reduce_general_no_atomics_" + op_name +
type_to_name(intermediate));
compute_encoder->setComputePipelineState(row_reduce_kernel);
set_array_buffer(compute_encoder, intermediate, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder.set_input_array(intermediate, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 4);
@@ -534,8 +532,6 @@ void strided_reduce_general_dispatch(
}
}
} // namespace
//////////////////////////////////////////////////////////////////////
// Main reduce dispatch
//////////////////////////////////////////////////////////////////////
@@ -577,7 +573,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
// Initialize output
auto& s = stream();
auto& d = metal::device(s.device);
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
{
auto kernel = d.get_kernel("i" + op_name + type_to_name(out));
size_t nthreads = out.size();
@@ -588,7 +584,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, out, 0);
compute_encoder.set_output_array(out, 0);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}

View File

@@ -0,0 +1,41 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include "mlx/backend/common/reduce.h"
#include "mlx/backend/metal/device.h"
#include "mlx/stream.h"
namespace mlx::core {
using metal::CommandEncoder;
void all_reduce_dispatch(
const array& in,
array& out,
const std::string& op_name,
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s);
void row_reduce_general_dispatch(
const array& in,
array& out,
const std::string& op_name,
const ReductionPlan& plan,
const std::vector<int>& axes,
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s);
void strided_reduce_general_dispatch(
const array& in,
array& out,
const std::string& op_name,
const ReductionPlan& plan,
const std::vector<int>& axes,
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s);
} // namespace mlx::core

View File

@@ -63,14 +63,15 @@ void RoPE::eval_gpu(
out_strides[2] = out.strides()[ndim - 1];
std::ostringstream kname;
kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in);
kname << "rope_" << (forward_ ? "" : "vjp_")
<< (traditional_ ? "traditional_" : "") << type_to_name(in);
auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
float base = std::log2(base_);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, donated ? out : in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder.set_input_array(donated ? out : in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 2);
compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 3);
compute_encoder->setBytes(&offset_, sizeof(int), 4);

View File

@@ -71,7 +71,7 @@ void sdpa_metal(
std::string kname_suffix = kname_suffix_tile_size + kname_suffix_nsimdgroups;
kname_partials << kname_suffix;
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname_partials.str());
compute_encoder->setComputePipelineState(kernel);
@@ -87,15 +87,15 @@ void sdpa_metal(
MLXScaledDotProductAttentionParams params{
query_sequence_length, n_q_heads, n_kv_heads, n_tiles, alpha};
set_array_buffer(compute_encoder, q, 0);
set_array_buffer(compute_encoder, k, 1);
set_array_buffer(compute_encoder, v, 2);
compute_encoder.set_input_array(q, 0);
compute_encoder.set_input_array(k, 1);
compute_encoder.set_input_array(v, 2);
compute_encoder->setBytes(&KV_sequence_length, sizeof(KV_sequence_length), 3);
compute_encoder->setBytes(
&params, sizeof(MLXScaledDotProductAttentionParams), 4);
set_array_buffer(compute_encoder, o_partial, 5);
set_array_buffer(compute_encoder, p_lse, 6);
set_array_buffer(compute_encoder, p_rowmaxes, 7);
compute_encoder.set_output_array(o_partial, 5);
compute_encoder.set_output_array(p_lse, 6);
compute_encoder.set_output_array(p_rowmaxes, 7);
constexpr const uint tgroupMemorySize = 32768;
compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0);
@@ -104,12 +104,12 @@ void sdpa_metal(
{
auto kernel_accum = d.get_kernel(kname_reduce.str());
compute_encoder->setComputePipelineState(kernel_accum);
set_array_buffer(compute_encoder, o_partial, 0);
set_array_buffer(compute_encoder, p_lse, 1);
set_array_buffer(compute_encoder, p_rowmaxes, 2);
compute_encoder.set_input_array(o_partial, 0);
compute_encoder.set_input_array(p_lse, 1);
compute_encoder.set_input_array(p_rowmaxes, 2);
compute_encoder->setBytes(
&params, sizeof(MLXScaledDotProductAttentionParams), 3);
set_array_buffer(compute_encoder, out, 4);
compute_encoder.set_output_array(out, 4);
MTL::Size grid_dims_reduce = MTL::Size(heads, 1, batch);
MTL::Size group_dims_reduce = MTL::Size(128, 1, 1);
@@ -127,7 +127,7 @@ void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
array& out) {
assert(inputs.size() >= 3);
if (!is_floating_point(out.dtype())) {
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[ScaledDotProductAttention] Does not yet support non-floating point types.");
}

View File

@@ -52,10 +52,10 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
kname << type_to_name(in) << "_" << type_to_name(out);
auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
size_t size = in.shape(axis_);
compute_encoder->setBytes(&size, sizeof(size_t), 2);
@@ -101,10 +101,10 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
kname << type_to_name(in) << "_" << type_to_name(out);
auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
size_t size = in.shape(axis_);
size_t stride = in.strides()[axis_];
compute_encoder->setBytes(&size, sizeof(size_t), 2);

View File

@@ -12,7 +12,7 @@ namespace mlx::core {
void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (!is_floating_point(out.dtype())) {
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[softmax] Does not support non-floating point types.");
}
@@ -21,7 +21,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
// Make sure that the last dimension is contiguous
std::vector<array> copies;
auto check_input = [&copies, &s](const array& x) {
auto check_input = [&copies, &s](const array& x) -> const array& {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
@@ -30,10 +30,9 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
if (no_copy) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
copies.push_back(x_copy);
return x_copy;
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
copy_gpu(x, copies.back(), CopyType::General, s);
return copies.back();
}
};
const array& in = check_input(inputs[0]);
@@ -57,8 +56,11 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
if (axis_size > looped_limit) {
op_name += "looped_";
}
if (in.dtype() != float32 && precise_) {
op_name += "precise_";
}
op_name += type_to_name(out);
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
{
auto kernel = d.get_kernel(op_name);
@@ -79,13 +81,10 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
}
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(
compute_encoder, in.data_shared_ptr() == nullptr ? out : in, 0);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder.set_input_array(
in.data_shared_ptr() == nullptr ? out : in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&axis_size, sizeof(int), 2);
compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 0);
compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
@@ -57,13 +57,13 @@ void single_block_sort(
}
// Prepare command encoder
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
// Set inputs
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2);
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 3);
@@ -102,6 +102,11 @@ void multi_block_sort(
int nc_dim = nc_shape.size();
if (nc_dim == 0) {
nc_shape = {0};
nc_str = {1};
}
int size_sorted_axis = in.shape(axis);
int stride_sorted_axis = in.strides()[axis];
@@ -126,7 +131,7 @@ void multi_block_sort(
dev_vals_0, dev_vals_1, dev_idxs_0, dev_idxs_1, block_partitions};
// Prepare command encoder
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
// Do blockwise sort
{
@@ -137,14 +142,15 @@ void multi_block_sort(
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, dev_vals_0, 1);
set_array_buffer(compute_encoder, dev_idxs_0, 2);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(dev_vals_0, 1);
compute_encoder.set_output_array(dev_idxs_0, 2);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 4);
compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 6);
compute_encoder->setBytes(nc_str.data(), nc_dim * sizeof(size_t), 7);
compute_encoder->setBytes(
nc_shape.data(), nc_shape.size() * sizeof(int), 6);
compute_encoder->setBytes(nc_str.data(), nc_str.size() * sizeof(size_t), 7);
MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
@@ -158,7 +164,8 @@ void multi_block_sort(
array dev_idxs_in = dev_idxs_0;
array dev_vals_out = dev_vals_1;
array dev_idxs_out = dev_idxs_1;
for (int merge_tiles = 2; merge_tiles <= n_blocks; merge_tiles *= 2) {
for (int merge_tiles = 2; (merge_tiles / 2) < n_blocks; merge_tiles *= 2) {
dev_vals_in = ping ? dev_vals_1 : dev_vals_0;
dev_idxs_in = ping ? dev_idxs_1 : dev_idxs_0;
dev_vals_out = ping ? dev_vals_0 : dev_vals_1;
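The new loop bound fixes a missed final merge pass when n_blocks is not a power of two; each pass doubles merge_tiles, the number of sorted tiles merged into one run. Worked example with n_blocks = 3:

// old: merge_tiles <= n_blocks
//   merge_tiles = 2 (2 <= 3, runs), 4 (4 <= 3, stops) -> two runs remain
// new: (merge_tiles / 2) < n_blocks
//   merge_tiles = 2 (1 < 3, runs), 4 (2 < 3, runs), 8 (4 < 3, stops)
//   -> the merge_tiles = 4 pass merges the remaining runs into one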
@@ -174,9 +181,9 @@ void multi_block_sort(
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, block_partitions, 0);
set_array_buffer(compute_encoder, dev_vals_in, 1);
set_array_buffer(compute_encoder, dev_idxs_in, 2);
compute_encoder.set_output_array(block_partitions, 0);
compute_encoder.set_input_array(dev_vals_in, 1);
compute_encoder.set_input_array(dev_idxs_in, 2);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
compute_encoder->setBytes(&merge_tiles, sizeof(int), 4);
@@ -195,11 +202,11 @@ void multi_block_sort(
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, block_partitions, 0);
set_array_buffer(compute_encoder, dev_vals_in, 1);
set_array_buffer(compute_encoder, dev_idxs_in, 2);
set_array_buffer(compute_encoder, dev_vals_out, 3);
set_array_buffer(compute_encoder, dev_idxs_out, 4);
compute_encoder.set_input_array(block_partitions, 0);
compute_encoder.set_input_array(dev_vals_in, 1);
compute_encoder.set_input_array(dev_idxs_in, 2);
compute_encoder.set_output_array(dev_vals_out, 3);
compute_encoder.set_output_array(dev_idxs_out, 4);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 5);
compute_encoder->setBytes(&merge_tiles, sizeof(int), 6);
compute_encoder->setBytes(&n_blocks, sizeof(int), 7);

View File

@@ -1,37 +1,20 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include "mlx/array.h"
#include "mlx/backend/metal/device.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
inline void
set_array_buffer(MTL::ComputeCommandEncoder* enc, const array& a, int idx) {
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
enc->setBuffer(a_buf, offset, idx);
}
inline void set_array_buffer(
MTL::ComputeCommandEncoder* enc,
const array& a,
int64_t offset,
int idx) {
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto base_offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
base_offset += offset;
enc->setBuffer(a_buf, base_offset, idx);
}
using metal::CommandEncoder;
template <typename T>
inline void set_vector_bytes(
MTL::ComputeCommandEncoder* enc,
CommandEncoder& enc,
const std::vector<T>& vec,
size_t nelems,
int idx) {
@@ -39,10 +22,8 @@ inline void set_vector_bytes(
}
template <typename T>
inline void set_vector_bytes(
MTL::ComputeCommandEncoder* enc,
const std::vector<T>& vec,
int idx) {
inline void
set_vector_bytes(CommandEncoder& enc, const std::vector<T>& vec, int idx) {
return set_vector_bytes(enc, vec, vec.size(), idx);
}
@@ -123,6 +104,32 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
}
inline NS::String* make_string(std::ostringstream& os) {
std::string string = os.str();
return NS::String::string(string.c_str(), NS::UTF8StringEncoding);
}
inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) {
#ifdef MLX_METAL_DEBUG
std::ostringstream label;
label << "Stream " << index;
queue->setLabel(make_string(label));
#endif
}
inline void debug_set_primitive_buffer_label(
MTL::CommandBuffer* command_buffer,
Primitive& primitive) {
#ifdef MLX_METAL_DEBUG
std::ostringstream label;
if (auto cbuf_label = command_buffer->label(); cbuf_label) {
label << cbuf_label->utf8String();
}
primitive.print(label);
command_buffer->setLabel(make_string(label));
#endif
}
} // namespace
} // namespace mlx::core

View File

@@ -3,6 +3,7 @@
#include <stdexcept>
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
namespace mlx::core::metal {
@@ -39,5 +40,9 @@ size_t set_memory_limit(size_t, bool) {
size_t set_cache_limit(size_t) {
return 0;
}
bool start_capture(std::string path) {
return false;
}
void stop_capture() {}
} // namespace mlx::core::metal

View File

@@ -49,6 +49,7 @@ NO_GPU(Equal)
NO_GPU(Erf)
NO_GPU(ErfInv)
NO_GPU(Exp)
NO_GPU(Expm1)
NO_GPU(FFT)
NO_GPU(Floor)
NO_GPU(Full)
@@ -103,7 +104,9 @@ NO_GPU(Inverse)
namespace fast {
NO_GPU_MULTI(LayerNorm)
NO_GPU_MULTI(LayerNormVJP)
NO_GPU_MULTI(RMSNorm)
NO_GPU_MULTI(RMSNormVJP)
NO_GPU_MULTI(RoPE)
NO_GPU(ScaledDotProductAttention)
} // namespace fast

View File

@@ -32,7 +32,7 @@ bool is_unary(const Primitive& p) {
typeid(p) == typeid(Sign) || typeid(p) == typeid(Sin) ||
typeid(p) == typeid(Sinh) || typeid(p) == typeid(Square) ||
typeid(p) == typeid(Sqrt) || typeid(p) == typeid(Tan) ||
typeid(p) == typeid(Tanh));
typeid(p) == typeid(Tanh) || typeid(p) == typeid(Expm1));
}
bool is_binary(const Primitive& p) {
@@ -162,44 +162,51 @@ CompileMode& compile_mode() {
return compile_mode_;
}
using CompileFn = std::function<std::vector<array>(const std::vector<array>&)>;
using ParentsMap =
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
// Helper like the one below, but it only merges the two provided arrays. If
// the src has siblings, these won't be merged into the dst.
void merge_one(array& dst, array& src, ParentsMap& parents_map) {
auto src_parents = parents_map.find(src.id());
if (src_parents == parents_map.end()) {
return;
}
auto& pairs = parents_map[dst.id()];
for (auto& parent : src_parents->second) {
parent.first.inputs()[parent.second] = dst;
pairs.push_back(parent);
}
// Remove the source from the map to avoid fusing with it again
parents_map.erase(src_parents);
};
// Helper that merges two arrays in the graph by setting the parents of the
// source to point to the destination
// source to point to the destination. The arrays are assumed to be coming from
// equivalent primitives so their siblings are merged as well.
void merge(array& dst, array& src, ParentsMap& parents_map) {
// Canonicalize the order of the primitives outputs
auto sources = src.outputs();
auto dests = dst.outputs();
// For each src parent, point it to the corresponding dst
for (int i = 0; i < sources.size(); ++i) {
auto src_parents = parents_map.find(sources[i].id());
if (src_parents == parents_map.end()) {
continue;
}
auto& pairs = parents_map[dests[i].id()];
for (auto& parent : src_parents->second) {
parent.first.inputs()[parent.second] = dests[i];
pairs.push_back(parent);
}
// Remove the source from the map to avoid fusing with it again
parents_map.erase(src_parents);
merge_one(dests[i], sources[i], parents_map);
}
};
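A worked example of the rewiring, with hypothetical arrays:

// Suppose d1 = add(a, b) and d2 = add(a, b) were built separately, so d2
// duplicates d1, and some later array p uses d2 as its input 1.
//   merge(d1, d2, parents_map);
// p's input 1 now points at d1, p is recorded as a parent of d1, and d2's
// entry is erased so nothing tries to fuse into it again.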
template <typename T, typename... U>
size_t getAddress(std::function<T(U...)> f) {
typedef T(fnType)(U...);
fnType** fnPointer = f.template target<fnType*>();
if (fnPointer == nullptr) {
std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
using FunType = T (*)(U...);
const FunType* fun_ptr = fun.template target<FunType>();
if (fun_ptr == nullptr) {
throw std::invalid_argument(
"[compile] Cannot compile a non-addressable function.");
}
return (size_t)*fnPointer;
return reinterpret_cast<std::uintptr_t>(*fun_ptr);
}
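Only callables whose stored target is a plain function pointer are addressable this way; a capturing lambda held in a std::function has its closure type as the target, so target<FunType>() returns nullptr and the function throws. A hypothetical sketch:

std::vector<array> f(const std::vector<array>& xs); // free function

std::function<std::vector<array>(const std::vector<array>&)> fn = f;
auto id = get_function_address(fn); // ok: the address of f keys the cache

int k = 3;
std::function<std::vector<array>(const std::vector<array>&)> lam =
    [k](const std::vector<array>& xs) { return xs; };
// get_function_address(lam); // throws: no addressable function target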
struct CompilerCache {
class CompilerCache {
public:
struct CacheEntry {
std::vector<array> inputs;
std::vector<array> outputs;
@@ -211,20 +218,20 @@ struct CompilerCache {
// Returns a reference to a CacheEntry which can be updated
// by the caller to avoid copying large tapes / inputs / outputs
CacheEntry& find(
size_t fun_id,
std::uintptr_t fun_id,
const std::vector<array>& inputs,
bool shapeless,
const std::vector<uint64_t>& constants) {
// Try to find the entry
auto [entry_it, inserted] = cache_.insert({fun_id, {}});
auto& entries = entry_it->second;
auto is_match = [shapeless](
const std::vector<array>& in1,
const std::vector<array>& in2) {
// Find the cache entries for |fun_id|.
std::vector<CacheEntry>& entries = cache_[fun_id];
// Compare if 2 arrays have same shape and dtype.
auto has_same_shape_and_dtype = [shapeless](
const std::vector<array>& in1,
const std::vector<array>& in2) {
if (in1.size() != in2.size()) {
return false;
}
for (int i = 0; i < in1.size(); ++i) {
for (size_t i = 0; i < in1.size(); ++i) {
if (in1[i].ndim() != in2[i].ndim()) {
return false;
}
@@ -237,14 +244,14 @@ struct CompilerCache {
}
return true;
};
// Loop over entries and check that the inputs match, i.e. shapes and
// types must be equal. Note this could get really slow if one compiles
// the same function with many different shapes. May want to store
// entries in a more easily searchable structure.
for (auto& entry : entries) {
for (CacheEntry& entry : entries) {
// Check the inputs match and return if so
if (is_match(inputs, entry.inputs) && constants == entry.constants) {
if (has_same_shape_and_dtype(inputs, entry.inputs) &&
constants == entry.constants) {
return entry;
}
}
@@ -253,7 +260,7 @@ struct CompilerCache {
return entries.back();
};
void erase(size_t fun_id) {
void erase(std::uintptr_t fun_id) {
cache_.erase(fun_id);
}
@@ -263,8 +270,9 @@ struct CompilerCache {
// initialized before the compiler cache
allocator::allocator();
}
friend CompilerCache& compiler_cache();
std::unordered_map<size_t, std::vector<CacheEntry>> cache_;
std::unordered_map<std::uintptr_t, std::vector<CacheEntry>> cache_;
};
CompilerCache& compiler_cache() {
@@ -523,9 +531,14 @@ void compile_fuse(
// - Collect inputs to the new compiled primitive
// - Add fusable primitives to a tape in the correct order
std::function<void(const array&, int, const Stream&)> recurse;
std::function<void(
const array&, int, const Stream&, const std::vector<int>&)>
recurse;
std::unordered_set<uintptr_t> cache;
recurse = [&](const array& a, int depth, const Stream& s) {
recurse = [&](const array& a,
int depth,
const Stream& s,
const std::vector<int>& shape) {
if (cache.find(a.id()) != cache.end()) {
return;
}
@@ -535,8 +548,10 @@ void compile_fuse(
// - Constant input
// - Stream mismatch
// - Non fusable primitive
// - Is global output but has a different shape
if (depth >= max_compile_depth || !a.has_primitive() ||
a.primitive().stream() != s || !is_fusable(a.primitive())) {
a.primitive().stream() != s || !is_fusable(a.primitive()) ||
(output_map.find(a.id()) != output_map.end() && a.shape() != shape)) {
return;
}
@@ -563,13 +578,13 @@ void compile_fuse(
cache.insert({a.id()});
for (auto& in : a.inputs()) {
recurse(in, depth + 1, s);
recurse(in, depth + 1, s, shape);
}
};
if (arr.has_primitive()) {
Stream s = arr.primitive().stream();
recurse(arr, 0, s);
recurse(arr, 0, s, arr.shape());
}
// Not worth fusing a single primitive
@@ -633,6 +648,10 @@ void compile_fuse(
std::vector<std::vector<int>> shapes;
std::vector<Dtype> types;
for (auto& o : old_outputs) {
if (o.shape() != old_outputs.back().shape()) {
throw std::runtime_error(
"[compile] Compilation failed. Tried to fuse operations with different output shapes");
}
shapes.push_back(o.shape());
types.push_back(o.dtype());
}
@@ -675,7 +694,7 @@ void compile_fuse(
// - Update outputs parents to point to compiled outputs
// - Update any overall graph outputs to be compiled outputs
for (int o = 0; o < old_outputs.size(); ++o) {
merge(compiled_outputs[o], old_outputs[o], parents_map);
merge_one(compiled_outputs[o], old_outputs[o], parents_map);
if (auto it = output_map.find(old_outputs[o].id());
it != output_map.end()) {
it->second = compiled_outputs[o];
@@ -774,7 +793,7 @@ void compile_validate_shapeless(const std::vector<array>& tape) {
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
size_t fun_id,
std::uintptr_t fun_id,
bool shapeless /* = false */,
std::vector<uint64_t> constants /* = {} */) {
if (compile_mode() == CompileMode::disabled ||
@@ -833,7 +852,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
};
}
void compile_erase(size_t fun_id) {
void compile_erase(std::uintptr_t fun_id) {
detail::compiler_cache().erase(fun_id);
}
@@ -845,7 +864,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
if (detail::compile_mode() == CompileMode::disabled) {
return fun;
}
auto fun_id = detail::getAddress(fun);
auto fun_id = detail::get_function_address(fun);
return detail::compile(fun, fun_id, shapeless);
}

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <cstdint>
#include <sstream>
@@ -12,6 +12,7 @@ namespace mlx::core {
namespace {
constexpr int num_types = 13;
constexpr int num_cats = 8;
constexpr Dtype::Kind type_kinds[num_types] = {
Dtype::Kind::b, // bool_,
@@ -49,18 +50,37 @@ constexpr Dtype type_rules[num_types][num_types] = {
{complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64}, // complex64
};
constexpr bool subcategory_to_category[num_cats][num_cats] = {
// complexfloating floating inexact signedinteger unsignedinteger integer number generic
{true, false, true, false, false, false, true, true}, // complexfloating
{false, true, true, false, false, false, true, true}, // floating
{false, false, true, false, false, false, true, true}, // inexact
{false, false, false, true, false, true, true, true}, // signedinteger
{false, false, false, false, true, true, true, true}, // unsignedinteger
{false, false, false, false, false, true, true, true}, // integer
{false, false, false, false, false, false, true, true}, // number
{false, false, false, false, false, false, false, true}, // generic
};
constexpr Dtype::Category type_to_category[num_types] = {
Dtype::Category::generic, // bool_,
Dtype::Category::unsignedinteger, // uint8,
Dtype::Category::unsignedinteger, // uint16,
Dtype::Category::unsignedinteger, // uint32,
Dtype::Category::unsignedinteger, // uint64,
Dtype::Category::signedinteger, // int8,
Dtype::Category::signedinteger, // int16,
Dtype::Category::signedinteger, // int32,
Dtype::Category::signedinteger, // int64,
Dtype::Category::floating, // float16,
Dtype::Category::floating, // float32,
Dtype::Category::floating, // bfloat16,
Dtype::Category::complexfloating, // complex64,
};
// clang-format on
inline bool is_big_endian() {
union ByteOrder {
int32_t i;
uint8_t c[4];
};
ByteOrder b = {0x01234567};
return b.c[0] == 0x01;
}
} // namespace
Dtype promote_types(const Dtype& t1, const Dtype& t2) {
@@ -141,6 +161,23 @@ TypeToDtype<complex64_t>::operator Dtype() {
return complex64;
}
bool issubdtype(const Dtype& a, const Dtype& b) {
return a == b;
}
bool issubdtype(const Dtype::Category& cat, const Dtype& type) {
return false;
}
bool issubdtype(const Dtype& type, const Dtype::Category& cat) {
return issubdtype(type_to_category[static_cast<uint32_t>(type.val)], cat);
}
bool issubdtype(const Dtype::Category& a, const Dtype::Category& b) {
return subcategory_to_category[static_cast<uint32_t>(a)]
[static_cast<uint32_t>(b)];
}
// Array protocol typestring for Dtype
std::string dtype_to_array_protocol(const Dtype& t) {
std::ostringstream r;
@@ -153,9 +190,9 @@ std::string dtype_to_array_protocol(const Dtype& t) {
}
// Dtype from array protocol type string
Dtype dtype_from_array_protocol(const std::string& t) {
Dtype dtype_from_array_protocol(std::string_view t) {
if (t.length() == 2 || t.length() == 3) {
std::string r = t.length() == 3 ? t.substr(1, 2) : t;
std::string_view r = t.length() == 3 ? t.substr(1, 2) : t;
if (r == "V2") {
return bfloat16;
@@ -201,7 +238,7 @@ Dtype dtype_from_array_protocol(const std::string& t) {
}
throw std::invalid_argument(
"[from_str] Invalid array protocol type-string: " + t);
"[from_str] Invalid array protocol type-string: " + std::string(t));
}
} // namespace mlx::core

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once
@@ -38,6 +38,17 @@ struct Dtype {
V, /* void - used for brain float */
};
enum class Category {
complexfloating,
floating,
inexact,
signedinteger,
unsignedinteger,
integer,
number,
generic
};
Val val;
const uint8_t size;
constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size){};
@@ -63,6 +74,22 @@ inline constexpr Dtype float32{Dtype::Val::float32, sizeof(float)};
inline constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)};
inline constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)};
inline constexpr Dtype::Category complexfloating =
Dtype::Category::complexfloating;
inline constexpr Dtype::Category floating = Dtype::Category::floating;
inline constexpr Dtype::Category inexact = Dtype::Category::inexact;
inline constexpr Dtype::Category signedinteger = Dtype::Category::signedinteger;
inline constexpr Dtype::Category unsignedinteger =
Dtype::Category::unsignedinteger;
inline constexpr Dtype::Category integer = Dtype::Category::integer;
inline constexpr Dtype::Category number = Dtype::Category::number;
inline constexpr Dtype::Category generic = Dtype::Category::generic;
bool issubdtype(const Dtype& a, const Dtype& b);
bool issubdtype(const Dtype::Category& a, const Dtype& b);
bool issubdtype(const Dtype& a, const Dtype::Category& b);
bool issubdtype(const Dtype::Category& a, const Dtype::Category& b);
Dtype promote_types(const Dtype& t1, const Dtype& t2);
inline uint8_t size_of(const Dtype& t) {
@@ -71,23 +98,6 @@ inline uint8_t size_of(const Dtype& t) {
Dtype::Kind kindof(const Dtype& t);
inline bool is_unsigned(const Dtype& t) {
return kindof(t) == Dtype::Kind::u || kindof(t) == Dtype::Kind::b;
}
inline bool is_floating_point(const Dtype& t) {
return kindof(t) == Dtype::Kind::f || kindof(t) == Dtype::Kind::V ||
kindof(t) == Dtype::Kind::c;
}
inline bool is_complex(const Dtype& t) {
return kindof(t) == Dtype::Kind::c;
}
inline bool is_integral(const Dtype& t) {
return !(is_floating_point(t));
}
template <typename T>
struct TypeToDtype {
operator Dtype();
@@ -96,6 +106,6 @@ struct TypeToDtype {
// Array protocol typestring for Dtype
std::string dtype_to_array_protocol(const Dtype& t);
// Dtype from array protocol type string
Dtype dtype_from_array_protocol(const std::string& t);
Dtype dtype_from_array_protocol(std::string_view t);
} // namespace mlx::core

View File

@@ -1,5 +1,8 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <numeric>
#include "mlx/fast.h"
#include "mlx/fast_primitives.h"
#include "mlx/ops.h"
@@ -64,7 +67,7 @@ array rms_norm(
throw std::invalid_argument(msg.str());
}
auto out_type = result_type(x, weight);
if (!is_floating_point(out_type) || is_complex(out_type)) {
if (!issubdtype(out_type, floating)) {
std::ostringstream msg;
msg << "[rms_norm] Received unsupported type " << out_type << ".";
throw std::invalid_argument(msg.str());
@@ -94,11 +97,69 @@ array rms_norm(
return fallback({x, weight})[0];
}
std::vector<array> RMSNorm::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
assert(primals.size() == 2);
assert(outputs.size() == 1);
assert(cotangents.size() == 1);
auto s = stream();
auto fallback = [eps = eps_, s](const std::vector<array>& inputs) {
auto& x = inputs[0];
auto& w = inputs[1];
auto& g = inputs[2];
std::vector<array> vjps;
auto n = rsqrt(
add(mean(square(x, s), /* axis= */ -1, /* keepdims= */ true, s),
array(eps, x.dtype()),
s),
s);
auto n3 = power(n, array(3, x.dtype()), s);
// df/dx
auto gw = multiply(g, w, s);
auto t = mean(multiply(gw, x, s), /* axis= */ -1, /* keepdims= */ true, s);
t = multiply(multiply(x, t, s), n3, s);
vjps.push_back(subtract(multiply(gw, n, s), t, s));
// df/dw
std::vector<int> axes(g.ndim() - 1);
std::iota(axes.begin(), axes.end(), 0);
vjps.push_back(
sum(multiply(g, multiply(x, n, s), s), axes, /* keepdims= */ false, s));
return vjps;
};
auto vjps = array::make_arrays(
{primals[0].shape(), primals[1].shape()},
{primals[0].dtype(), primals[1].dtype()},
std::make_shared<RMSNormVJP>(s, fallback, eps_),
{primals[0], primals[1], cotangents[0]});
std::vector<array> returned_vjps;
for (auto& arg : argnums) {
returned_vjps.push_back(std::move(vjps[arg]));
}
return returned_vjps;
}
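The fallback above implements the closed-form gradient rather than differentiating through the forward graph. Writing $n = (\operatorname{mean}(x^2) + \epsilon)^{-1/2}$, so the forward pass is $y = w \odot x\,n$, the two terms match the lambda exactly:

$$
\frac{\partial L}{\partial x} = (g \odot w)\,n - x \cdot \operatorname{mean}(g \odot w \odot x) \cdot n^3,
\qquad
\frac{\partial L}{\partial w} = \sum_{\text{leading axes}} g \odot x\,n,
$$

where the mean runs over the last axis and the weight-gradient sum over all leading axes, as in the code.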
bool RMSNorm::is_equivalent(const Primitive& other) const {
const RMSNorm& a_other = static_cast<const RMSNorm&>(other);
return eps_ == a_other.eps_;
}
bool RMSNormVJP::is_equivalent(const Primitive& other) const {
const RMSNormVJP& a_other = static_cast<const RMSNormVJP&>(other);
return eps_ == a_other.eps_;
}
array layer_norm(
const array& x,
const std::optional<array>& weight,
@@ -128,7 +189,7 @@ array layer_norm(
? ((bias.has_value()) ? result_type(x, *weight, *bias)
: result_type(x, *weight))
: x.dtype();
if (!is_floating_point(out_type) || is_complex(out_type)) {
if (!issubdtype(out_type, floating)) {
std::ostringstream msg;
msg << "[layer_norm] Received unsupported type " << out_type << ".";
throw std::invalid_argument(msg.str());
@@ -176,11 +237,90 @@ array layer_norm(
return fallback({x, passed_weight, passed_bias})[0];
}
std::vector<array> LayerNorm::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
assert(primals.size() == 3);
assert(outputs.size() == 1);
assert(cotangents.size() == 1);
auto s = stream();
auto fallback = [eps = eps_, s](const std::vector<array>& inputs) {
auto& x = inputs[0];
auto& w = inputs[1];
auto& b = inputs[2];
auto& g = inputs[3];
std::vector<array> vjps;
auto norm = number_of_elements(x, {-1}, true, x.dtype(), s);
auto sumx = sum(x, /* axis= */ -1, /* keepdims= */ true, s);
auto sumx2 = sum(square(x, s), /* axis= */ -1, /* keepdims= */ true, s);
auto mu = multiply(sumx, norm, s);
auto mu2 = multiply(sumx2, norm, s);
auto var = subtract(mu2, square(mu, s), s);
auto n = rsqrt(add(var, array(eps, x.dtype()), s));
auto n3 = power(n, array(3, x.dtype()), s);
auto x_c = subtract(x, mu, s);
// df/dx
auto wg = multiply(w, g, s);
auto sumwg =
multiply(sum(wg, /* axis= */ -1, /* keepdims= */ true, s), norm, s);
auto sumwgxc = multiply(
sum(multiply(wg, x_c, s), /* axis= */ -1, /* keepdims= */ true, s),
norm,
s);
auto t1 = multiply(multiply(x_c, sumwgxc, s), n3, s);
auto t2 = multiply(subtract(wg, sumwg, s), n, s);
vjps.push_back(subtract(t2, t1, s));
// df/dw
std::vector<int> axes(g.ndim() - 1);
std::iota(axes.begin(), axes.end(), 0);
if (w.ndim() == 0) {
vjps.push_back(zeros_like(w, s));
} else {
vjps.push_back(sum(
multiply(g, multiply(x_c, n, s), s), axes, /* keepdims= */ false, s));
}
// df/db
if (b.ndim() == 0) {
vjps.push_back(zeros_like(w, s));
} else {
vjps.push_back(sum(g, axes, /* keepdims= */ false, s));
}
return vjps;
};
auto vjps = array::make_arrays(
{primals[0].shape(), primals[1].shape(), primals[2].shape()},
{primals[0].dtype(), primals[1].dtype(), primals[2].dtype()},
std::make_shared<LayerNormVJP>(s, fallback, eps_),
{primals[0], primals[1], primals[2], cotangents[0]});
std::vector<array> returned_vjps;
for (auto& arg : argnums) {
returned_vjps.push_back(std::move(vjps[arg]));
}
return returned_vjps;
}
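As with RMSNorm, the vjp is written in closed form. With $\mu = \operatorname{mean}(x)$, $\sigma^2 = \operatorname{mean}(x^2) - \mu^2$, $n = (\sigma^2 + \epsilon)^{-1/2}$, and $x_c = x - \mu$, the three gradients computed by the t1/t2 terms above are:

$$
\frac{\partial L}{\partial x} = \bigl(wg - \operatorname{mean}(wg)\bigr)\,n - x_c \cdot \operatorname{mean}(wg \odot x_c) \cdot n^3,
\qquad
\frac{\partial L}{\partial w} = \sum g \odot x_c\,n,
\qquad
\frac{\partial L}{\partial b} = \sum g,
$$

with means over the last axis, sums over the leading axes, and zeros returned when the weight or bias was a scalar placeholder.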
bool LayerNorm::is_equivalent(const Primitive& other) const {
const LayerNorm& a_other = static_cast<const LayerNorm&>(other);
return eps_ == a_other.eps_;
}
bool LayerNormVJP::is_equivalent(const Primitive& other) const {
const LayerNormVJP& a_other = static_cast<const LayerNormVJP&>(other);
return eps_ == a_other.eps_;
}
array rope(
const array& x,
int dims,
@@ -188,19 +328,16 @@ array rope(
float base,
float scale,
int offset,
StreamOrDevice s /* = {} */) {
bool forward,
StreamOrDevice s) {
if (x.ndim() < 3) {
std::ostringstream msg;
msg << "[rope] Input must have at least 3 dimensions but got input with "
<< x.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
if (traditional && x.shape(-1) != dims) {
throw std::invalid_argument(
"[rope] Does not support partial traditional application.");
}
auto fallback = [dims, traditional, base, scale, offset, s](
auto fallback = [dims, traditional, base, scale, offset, forward, s](
const std::vector<array>& inputs) {
auto& shape = inputs[0].shape();
int ndim = shape.size();
@@ -217,16 +354,39 @@ array rope(
auto coss = cos(theta, s);
auto sins = sin(theta, s);
if (traditional) {
auto x1 = slice(x, {0, 0, 0}, x.shape(), {1, 1, 2}, s);
auto x2 = slice(x, {0, 0, 1}, x.shape(), {1, 1, 2}, s);
auto apply_rope = [forward, s](
const array& x1,
const array& x2,
const array& coss,
const array& sins) {
std::vector<array> outs;
outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
if (forward) {
outs.push_back(
subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
} else {
outs.push_back(add(multiply(x2, sins, s), multiply(x1, coss, s), s));
outs.push_back(
subtract(multiply(x2, coss, s), multiply(x1, sins, s), s));
}
return outs;
};
if (traditional) {
auto x1 =
slice(x, {0, 0, 0}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
auto x2 =
slice(x, {0, 0, 1}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
auto outs = apply_rope(x1, x2, coss, sins);
for (auto& o : outs) {
o = expand_dims(o, 3, s);
}
return std::vector<array>{reshape(concatenate(outs, 3, s), shape, s)};
auto out = concatenate(outs, 3, s);
if (dims < x.shape(-1)) {
out = reshape(out, {x.shape(0), x.shape(1), dims});
out = concatenate({out, slice(x, {0, 0, dims}, x.shape(), s)}, 2, s);
}
return std::vector<array>{reshape(out, shape, s)};
} else {
auto out_s = x.shape();
out_s.back() = half_dims;
@@ -234,9 +394,7 @@ array rope(
out_s.back() = dims;
auto x2 = slice(x, {0, 0, half_dims}, out_s, s);
std::vector<array> outs;
outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
auto outs = apply_rope(x1, x2, coss, sins);
if (dims < x.shape(-1)) {
outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
}
@@ -249,18 +407,54 @@ array rope(
x.shape(),
x.dtype(),
std::make_shared<RoPE>(
stream, fallback, dims, traditional, base, scale, offset),
stream, fallback, dims, traditional, base, scale, offset, forward),
{x});
}
return fallback({x})[0];
}
array rope(
const array& x,
int dims,
bool traditional,
float base,
float scale,
int offset,
StreamOrDevice s /* = {} */) {
return rope(x, dims, traditional, base, scale, offset, true, s);
}
std::vector<array> RoPE::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
auto s = stream();
auto fallback = [dims = dims_,
traditional = traditional_,
base = base_,
scale = scale_,
offset = offset_,
forward = forward_,
s](std::vector<array> inputs) {
return std::vector<array>{
rope(inputs[0], dims, traditional, base, scale, offset, !forward, s)};
};
return {array(
cotangents[0].shape(),
cotangents[0].dtype(),
std::make_shared<RoPE>(
s, fallback, dims_, traditional_, base_, scale_, offset_, !forward_),
cotangents)};
}
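The forward flag exists so the vjp can reuse the same primitive. RoPE rotates each pair $(x_1, x_2)$ by

$$
\begin{pmatrix} y_1 \\ y_2 \end{pmatrix} =
\begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}
\begin{pmatrix} x_1 \\ x_2 \end{pmatrix},
$$

and since the rotation is orthogonal, the cotangent is pulled back with the transpose $R(\theta)^{\top} = R(-\theta)$: that is the non-forward branch of apply_rope, which computes $x_1\cos\theta + x_2\sin\theta$ and $x_2\cos\theta - x_1\sin\theta$.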
bool RoPE::is_equivalent(const Primitive& other) const {
const RoPE& a_other = static_cast<const RoPE&>(other);
return (
dims_ == a_other.dims_ && base_ == a_other.base_ &&
scale_ == a_other.scale_ && traditional_ == a_other.traditional_ &&
offset_ == a_other.offset_);
offset_ == a_other.offset_ && forward_ == a_other.forward_);
}
/** Computes: O = softmax(Q @ K.T) @ V **/
@@ -319,7 +513,7 @@ array scaled_dot_product_attention(
}
auto final_type = result_type(queries, keys, values);
if (!is_floating_point(final_type) || is_complex(final_type)) {
if (!issubdtype(final_type, floating)) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] Received unsupported type "
<< final_type << ".";
@@ -356,10 +550,7 @@ array scaled_dot_product_attention(
if (needs_mask) {
scores = add(scores, inputs[3], s);
}
scores = astype(
softmax(astype(scores, float32, s), std::vector<int>{-1}, s),
final_type,
s);
scores = softmax(scores, std::vector<int>{-1}, true, s);
auto out = matmul(scores, v, s);
if (n_repeats > 1) {
out = reshape(out, {B, n_q_heads, L, -1}, s);

View File

@@ -48,6 +48,12 @@ class RMSNorm : public Custom {
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_PRINT(RMSNorm)
bool is_equivalent(const Primitive& other) const override;
@@ -56,6 +62,29 @@ class RMSNorm : public Custom {
float eps_;
};
class RMSNormVJP : public Custom {
public:
RMSNormVJP(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
float eps)
: Custom(stream, fallback), eps_(eps){};
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
throw std::runtime_error("NYI");
};
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(RMSNormVJP)
bool is_equivalent(const Primitive& other) const override;
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;
float eps_;
};
class LayerNorm : public Custom {
public:
LayerNorm(
@@ -71,6 +100,12 @@ class LayerNorm : public Custom {
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_PRINT(LayerNorm)
bool is_equivalent(const Primitive& other) const override;
@@ -79,6 +114,29 @@ class LayerNorm : public Custom {
float eps_;
};
class LayerNormVJP : public Custom {
public:
LayerNormVJP(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
float eps)
: Custom(stream, fallback), eps_(eps){};
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
throw std::runtime_error("NYI");
};
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(LayerNormVJP)
bool is_equivalent(const Primitive& other) const override;
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;
float eps_;
};
class RoPE : public Custom {
public:
RoPE(
@@ -88,13 +146,15 @@ class RoPE : public Custom {
bool traditional,
float base,
float scale,
int offset)
int offset,
bool forward)
: Custom(stream, fallback),
dims_(dims),
traditional_(traditional),
base_(base),
scale_(scale),
offset_(offset){};
offset_(offset),
forward_(forward){};
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
@@ -103,6 +163,12 @@ class RoPE : public Custom {
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_PRINT(RoPE)
bool is_equivalent(const Primitive& other) const override;
@@ -113,6 +179,7 @@ class RoPE : public Custom {
float base_;
float scale_;
int offset_;
bool forward_;
};
class ScaledDotProductAttention : public Custom {
@@ -126,7 +193,7 @@ class ScaledDotProductAttention : public Custom {
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
outputs[0] = fallback_(inputs)[0];
throw std::runtime_error("NYI");
};
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)

View File

@@ -95,7 +95,7 @@ array fft_impl(
return array(
out_shape,
out_type,
std::make_unique<FFT>(to_stream(s), valid_axes, inverse, real),
std::make_shared<FFT>(to_stream(s), valid_axes, inverse, real),
{astype(in, in_type, s)});
}

View File

@@ -6,12 +6,10 @@
#include "array.h"
#include "device.h"
#include "stream.h"
#include "utils.h"
namespace mlx::core::fft {
using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
/** Compute the n-dimensional Fourier Transform. */
array fftn(
const array& a,

View File

@@ -23,8 +23,7 @@ const std::string& NodeNamer::get_name(const array& x) {
letters.push_back('A' + (var_num - 1) % 26);
var_num = (var_num - 1) / 26;
}
std::string name(letters.rbegin(), letters.rend());
names.insert({x.id(), name});
names.emplace(x.id(), std::string(letters.rbegin(), letters.rend()));
return get_name(x);
}

View File

@@ -14,15 +14,15 @@ struct NodeNamer {
void print_graph(std::ostream& os, const std::vector<array>& outputs);
template <typename... Arrays>
void print_graph(std::ostream& os, Arrays... outputs) {
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
void print_graph(std::ostream& os, Arrays&&... outputs) {
print_graph(os, std::vector<array>{std::forward<Arrays>(outputs)...});
}
void export_to_dot(std::ostream& os, const std::vector<array>& outputs);
template <typename... Arrays>
void export_to_dot(std::ostream& os, Arrays... outputs) {
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
void export_to_dot(std::ostream& os, Arrays&&... outputs) {
export_to_dot(os, std::vector<array>{std::forward<Arrays>(outputs)...});
}

View File

@@ -23,13 +23,13 @@ using SafetensorsLoad = std::pair<
void save(std::shared_ptr<io::Writer> out_stream, array a);
/** Save array to file in .npy format */
void save(const std::string& file, array a);
void save(std::string file, array a);
/** Load array from reader in .npy format */
array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
/** Load array from file in .npy format */
array load(const std::string& file, StreamOrDevice s = {});
array load(std::string file, StreamOrDevice s = {});
/** Load array map from .safetensors file format */
SafetensorsLoad load_safetensors(
@@ -44,7 +44,7 @@ void save_safetensors(
std::unordered_map<std::string, array>,
std::unordered_map<std::string, std::string> metadata = {});
void save_safetensors(
const std::string& file,
std::string file,
std::unordered_map<std::string, array>,
std::unordered_map<std::string, std::string> metadata = {});

View File

@@ -206,7 +206,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
std::unordered_map<std::string, array> array_map;
gguf_tensor tensor;
auto check_insert = [](auto inserted) {
auto check_insert = [](const auto& inserted) {
if (!inserted.second) {
std::ostringstream msg;
msg << "[load_gguf] Duplicate parameter name " << inserted.first->second
@@ -216,6 +216,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
};
while (gguf_get_tensor(ctx, &tensor)) {
std::string name(tensor.name, tensor.namelen);
if (tensor.type == GGUF_TYPE_Q4_0 || tensor.type == GGUF_TYPE_Q4_1 ||
tensor.type == GGUF_TYPE_Q8_0) {
gguf_load_quantized(array_map, tensor);
@@ -224,14 +225,14 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
const auto& [data, dtype] = extract_tensor_data(&tensor);
array loaded_array = array(data, get_shape(tensor), dtype);
array_map.insert({name, loaded_array});
check_insert(array_map.insert({name, loaded_array}));
}
}
return array_map;
}
GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) {
gguf_ctx* ctx = gguf_open(file.c_str());
gguf_ctx* ctx = gguf_open(file.data());
if (!ctx) {
throw std::runtime_error("[load_gguf] gguf_init failed");
}

View File

@@ -105,7 +105,8 @@ void gguf_load_quantized(
weights_per_byte = 1;
}
std::string name = std::string(tensor.name, tensor.namelen);
std::string name(tensor.name, tensor.namelen);
std::vector<int> shape = get_shape(tensor);
const uint64_t weights_per_block = 32;
if (shape[shape.size() - 1] % weights_per_block != 0) {
@@ -136,9 +137,9 @@ void gguf_load_quantized(
extract_q8_0_data(tensor, weights, scales, biases);
}
a.insert({name, weights});
a.emplace(name, std::move(weights));
auto check_insert = [](auto inserted) {
auto check_insert = [](const auto& inserted) {
if (!inserted.second) {
std::ostringstream msg;
msg << "[load_gguf] Duplicate parameter name " << inserted.first->second
@@ -147,11 +148,11 @@ void gguf_load_quantized(
}
};
const std::string weight_suffix = ".weight";
constexpr std::string_view weight_suffix = ".weight";
const std::string name_prefix =
name.substr(0, name.length() - weight_suffix.length());
check_insert(a.insert({name_prefix + ".scales", scales}));
check_insert(a.insert({name_prefix + ".biases", biases}));
check_insert(a.emplace(name_prefix + ".scales", std::move(scales)));
check_insert(a.emplace(name_prefix + ".biases", std::move(biases)));
}
} // namespace mlx::core

View File

@@ -26,16 +26,6 @@ constexpr uint8_t MAGIC[] = {
0x59,
};
inline bool is_big_endian_() {
union ByteOrder {
int32_t i;
uint8_t c[4];
};
ByteOrder b = {0x01234567};
return b.c[0] == 0x01;
}
} // namespace
/** Save array to out stream in .npy format */
@@ -73,8 +63,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
std::string fortran_order = a.flags().col_contiguous ? "True" : "False";
std::ostringstream header;
header << "{'descr': '" << dtype_to_array_protocol(a.dtype()) << "',"
<< " 'fortran_order': " << fortran_order << ","
<< " 'shape': (";
<< " 'fortran_order': " << fortran_order << "," << " 'shape': (";
for (auto i : a.shape()) {
header << i << ", ";
}
@@ -94,7 +83,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
uint16_t v1_header_len = header.tellp();
const char* len_bytes = reinterpret_cast<const char*>(&v1_header_len);
if (!is_big_endian_()) {
if (!is_big_endian()) {
magic_ver_len.write(len_bytes, 2);
} else {
magic_ver_len.write(len_bytes + 1, 1);
@@ -106,7 +95,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
uint32_t v2_header_len = header.tellp();
const char* len_bytes = reinterpret_cast<const char*>(&v2_header_len);
if (!is_big_endian_()) {
if (!is_big_endian()) {
magic_ver_len.write(len_bytes, 4);
} else {
magic_ver_len.write(len_bytes + 3, 1);
@@ -124,16 +113,13 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
}
/** Save array to file in .npy format */
void save(const std::string& file_, array a) {
// Open and check file
std::string file = file_;
void save(std::string file, array a) {
// Add .npy to file name if it is not there
if (file.length() < 4 || file.substr(file.length() - 4, 4) != ".npy")
file += ".npy";
// Serialize array
save(std::make_shared<io::FileWriter>(file), a);
save(std::make_shared<io::FileWriter>(std::move(file)), a);
}
/** Load array from reader in .npy format */
@@ -219,7 +205,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
// Build primitive
size_t offset = 8 + header_len_size + header.length();
bool swap_endianness = read_is_big_endian != is_big_endian_();
bool swap_endianness = read_is_big_endian != is_big_endian();
if (col_contiguous) {
std::reverse(shape.begin(), shape.end());
@@ -227,7 +213,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
auto loaded_array = array(
shape,
dtype,
std::make_unique<Load>(to_stream(s), in_stream, offset, swap_endianness),
std::make_shared<Load>(to_stream(s), in_stream, offset, swap_endianness),
std::vector<array>{});
if (col_contiguous) {
loaded_array = transpose(loaded_array, s);
@@ -237,8 +223,8 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
}
/** Load array from file in .npy format */
array load(const std::string& file, StreamOrDevice s) {
return load(std::make_shared<io::FileReader>(file), s);
array load(std::string file, StreamOrDevice s) {
return load(std::make_shared<io::FileReader>(std::move(file)), s);
}
} // namespace mlx::core
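With the by-value signatures the suffix check can append in place and the string is then moved into the FileWriter or FileReader. A usage sketch (ones is assumed from ops.h):

namespace mx = mlx::core;

auto a = mx::ones({4, 4});
mx::save("weights", a);           // ".npy" is appended automatically
auto b = mx::load("weights.npy"); // lazy Load primitive, evaluated on use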

View File

@@ -60,7 +60,7 @@ std::string dtype_to_safetensor_str(Dtype t) {
}
}
Dtype dtype_from_safetensor_str(std::string str) {
Dtype dtype_from_safetensor_str(std::string_view str) {
if (str == ST_F32) {
return float32;
} else if (str == ST_F16) {
@@ -88,7 +88,8 @@ Dtype dtype_from_safetensor_str(std::string str) {
} else if (str == ST_C64) {
return complex64;
} else {
throw std::runtime_error("[safetensor] unsupported dtype " + str);
throw std::runtime_error(
"[safetensor] unsupported dtype " + std::string(str));
}
}
@@ -129,14 +130,14 @@ SafetensorsLoad load_safetensors(
}
continue;
}
std::string dtype = item.value().at("dtype");
std::vector<int> shape = item.value().at("shape");
std::vector<size_t> data_offsets = item.value().at("data_offsets");
const std::string& dtype = item.value().at("dtype");
const std::vector<int>& shape = item.value().at("shape");
const std::vector<size_t>& data_offsets = item.value().at("data_offsets");
Dtype type = dtype_from_safetensor_str(dtype);
auto loaded_array = array(
shape,
type,
std::make_unique<Load>(
std::make_shared<Load>(
to_stream(s), in_stream, offset + data_offsets.at(0), false),
std::vector<array>{});
res.insert({item.key(), loaded_array});
@@ -207,19 +208,17 @@ void save_safetensors(
}
void save_safetensors(
const std::string& file_,
std::string file,
std::unordered_map<std::string, array> a,
std::unordered_map<std::string, std::string> metadata /* = {} */) {
// Open and check file
std::string file = file_;
// Add .safetensors to file name if it is not there
if (file.length() < 12 ||
file.substr(file.length() - 12, 12) != ".safetensors")
file += ".safetensors";
// Serialize array
save_safetensors(std::make_shared<io::FileWriter>(file), a, metadata);
save_safetensors(
std::make_shared<io::FileWriter>(std::move(file)), a, metadata);
}
} // namespace mlx::core

View File

@@ -11,7 +11,7 @@
namespace mlx::core::linalg {
Dtype at_least_float(const Dtype& d) {
return is_floating_point(d) ? d : promote_types(d, float32);
return issubdtype(d, inexact) ? d : promote_types(d, float32);
}
inline array l2_norm(
@@ -19,7 +19,7 @@ inline array l2_norm(
const std::vector<int>& axis,
bool keepdims,
StreamOrDevice s) {
if (is_complex(a.dtype())) {
if (issubdtype(a.dtype(), complexfloating)) {
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s);
} else {
return sqrt(sum(square(a, s), axis, keepdims, s), s);

View File

@@ -47,7 +47,7 @@ std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
}
Dtype at_least_float(const Dtype& d) {
return is_floating_point(d) ? d : promote_types(d, float32);
return issubdtype(d, inexact) ? d : promote_types(d, float32);
}
} // namespace
@@ -158,50 +158,64 @@ array linspace(
to_stream(s));
}
array astype(const array& a, Dtype dtype, StreamOrDevice s /* = {} */) {
array astype(array a, Dtype dtype, StreamOrDevice s /* = {} */) {
if (dtype == a.dtype()) {
return a;
return std::move(a);
}
auto copied_shape = a.shape(); // |a| will be moved
return array(
a.shape(), dtype, std::make_shared<AsType>(to_stream(s), dtype), {a});
std::move(copied_shape),
dtype,
std::make_shared<AsType>(to_stream(s), dtype),
{std::move(a)});
}
array as_strided(
const array& a,
array a,
std::vector<int> shape,
std::vector<size_t> strides,
size_t offset,
StreamOrDevice s /* = {} */) {
// Force the input array to be contiguous
auto x = reshape(a, {-1}, s);
auto copied_shape = shape; // |shape| will be moved
auto dtype = a.dtype(); // |a| will be moved
return array(
shape,
a.dtype(),
std::make_shared<AsStrided>(to_stream(s), shape, strides, offset),
{x});
std::move(copied_shape),
dtype,
std::make_shared<AsStrided>(
to_stream(s), std::move(shape), std::move(strides), offset),
// Force the input array to be contiguous.
{reshape(std::move(a), {-1}, s)});
}
array copy(const array& a, StreamOrDevice s /* = {} */) {
return array(a.shape(), a.dtype(), std::make_shared<Copy>(to_stream(s)), {a});
array copy(array a, StreamOrDevice s /* = {} */) {
auto copied_shape = a.shape(); // |a| will be moved
auto dtype = a.dtype();
return array(
std::move(copied_shape),
dtype,
std::make_shared<Copy>(to_stream(s)),
{std::move(a)});
}
array full(
const std::vector<int>& shape,
const array& vals,
std::vector<int> shape,
array vals,
Dtype dtype,
StreamOrDevice s /* = {} */) {
if (std::any_of(shape.begin(), shape.end(), [](auto i) { return i < 0; })) {
if (std::any_of(shape.begin(), shape.end(), [](int i) { return i < 0; })) {
throw std::invalid_argument("[full] Negative dimensions not allowed.");
}
auto in = broadcast_to(astype(vals, dtype, s), shape, s);
return array(shape, dtype, std::make_shared<Full>(to_stream(s)), {in});
auto copied_shape = shape; // |shape| will be moved
return array(
std::move(copied_shape),
dtype,
std::make_shared<Full>(to_stream(s)),
{broadcast_to(astype(std::move(vals), dtype, s), std::move(shape), s)});
}
array full(
const std::vector<int>& shape,
const array& vals,
StreamOrDevice s /* = {} */) {
return full(shape, vals, vals.dtype(), to_stream(s));
array full(std::vector<int> shape, array vals, StreamOrDevice s /* = {} */) {
auto dtype = vals.dtype(); // |vals| will be moved
return full(std::move(shape), std::move(vals), dtype, to_stream(s));
}
array zeros(
@@ -682,6 +696,41 @@ split(const array& a, int num_splits, StreamOrDevice s /* = {} */) {
return split(a, num_splits, 0, to_stream(s));
}
std::vector<array> meshgrid(
const std::vector<array>& arrays,
bool sparse /* = false */,
std::string indexing /* = "xy" */,
StreamOrDevice s /* = {} */) {
if (indexing != "xy" && indexing != "ij") {
throw std::invalid_argument(
"[meshgrid] Invalid indexing value. Valid values are 'xy' and 'ij'.");
}
auto ndim = arrays.size();
std::vector<array> outputs;
for (int i = 0; i < ndim; ++i) {
std::vector<int> shape(ndim, 1);
shape[i] = -1;
outputs.push_back(reshape(arrays[i], std::move(shape), s));
}
if (indexing == "xy" and ndim > 1) {
std::vector<int> shape(ndim, 1);
shape[1] = arrays[0].size();
outputs[0] = reshape(arrays[0], shape, s);
shape[1] = 1;
shape[0] = arrays[1].size();
outputs[1] = reshape(arrays[1], std::move(shape), s);
}
if (!sparse) {
outputs = broadcast_arrays(outputs, s);
}
return outputs;
}
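A short sketch of the two indexing modes (arange is assumed from ops.h). With "xy" the first two axes are swapped relative to "ij", matching NumPy:

namespace mx = mlx::core;

auto x = mx::arange(3); // shape {3}
auto y = mx::arange(2); // shape {2}

auto xy = mx::meshgrid({x, y});                     // both outputs: {2, 3}
auto ij = mx::meshgrid({x, y}, false, "ij");        // both outputs: {3, 2}
auto sp = mx::meshgrid({x, y}, /* sparse= */ true); // {1, 3} and {2, 1}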
array clip(
const array& a,
const std::optional<array>& a_min,
@@ -883,15 +932,15 @@ array pad(
if (low_pad_size[i] < 0) {
std::ostringstream msg;
msg << "Invalid low padding size (" << low_pad_size[i]
<< ") passed to pad"
<< " for axis " << i << ". Padding sizes must be non-negative";
<< ") passed to pad" << " for axis " << i
<< ". Padding sizes must be non-negative";
throw std::invalid_argument(msg.str());
}
if (high_pad_size[i] < 0) {
std::ostringstream msg;
msg << "Invalid high padding size (" << high_pad_size[i]
<< ") passed to pad"
<< " for axis " << i << ". Padding sizes must be non-negative";
<< ") passed to pad" << " for axis " << i
<< ". Padding sizes must be non-negative";
throw std::invalid_argument(msg.str());
}
@@ -1001,19 +1050,30 @@ array transpose(
for (auto& ax : axes) {
ax = ax < 0 ? ax + a.ndim() : ax;
}
std::set dims(axes.begin(), axes.end());
if (dims.size() != axes.size()) {
throw std::invalid_argument("Repeat axes not allowed in transpose.");
if (axes.size() != a.ndim()) {
std::ostringstream msg;
msg << "[transpose] Recived " << axes.size() << " axes for array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
if (dims.size() != a.ndim() ||
a.ndim() > 0 &&
(*dims.begin() != 0 || *dims.rbegin() != (a.ndim() - 1))) {
throw std::invalid_argument("Transpose axes don't match array dimensions.");
// Check in bounds and for duplicates
std::vector<int> shape(axes.size(), 0);
for (auto& ax : axes) {
if (ax < 0 || ax >= a.ndim()) {
std::ostringstream msg;
msg << "[transpose] Invalid axis (" << ax << ") for array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
if (shape[ax] != 0) {
throw std::invalid_argument("[transpose] Repeat axes not allowed.");
}
shape[ax] = 1;
}
std::vector<int> shape;
shape.reserve(axes.size());
for (auto ax : axes) {
shape.push_back(a.shape()[ax]);
for (int i = 0; i < axes.size(); ++i) {
shape[i] = a.shape()[axes[i]];
}
return array(
std::move(shape),
@@ -1140,7 +1200,7 @@ array array_equal(
return array(false);
} else {
auto dtype = promote_types(a.dtype(), b.dtype());
equal_nan &= is_floating_point(dtype);
equal_nan &= issubdtype(dtype, inexact);
return all(
array(
a.shape(),
@@ -1153,7 +1213,7 @@ array array_equal(
}
array isnan(const array& a, StreamOrDevice s /* = {} */) {
if (is_integral(a.dtype())) {
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
return full(a.shape(), false, bool_, s);
}
return not_equal(a, a, s);
@@ -1164,14 +1224,14 @@ array isinf(const array& a, StreamOrDevice s /* = {} */) {
}
array isposinf(const array& a, StreamOrDevice s /* = {} */) {
if (is_integral(a.dtype())) {
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
return full(a.shape(), false, bool_, s);
}
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
}
array isneginf(const array& a, StreamOrDevice s /* = {} */) {
if (is_integral(a.dtype())) {
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
return full(a.shape(), false, bool_, s);
}
return equal(a, array(-std::numeric_limits<float>::infinity(), a.dtype()), s);
@@ -1416,6 +1476,34 @@ array var(
return var(a, std::vector<int>{axis}, keepdims, ddof, to_stream(s));
}
array std(
const array& a,
bool keepdims,
int ddof /* = 0*/,
StreamOrDevice s /* = {}*/) {
std::vector<int> axes(a.ndim());
std::iota(axes.begin(), axes.end(), 0);
return std(a, axes, keepdims, ddof, to_stream(s));
}
array std(
const array& a,
const std::vector<int>& axes,
bool keepdims /* = false */,
int ddof /* = 0*/,
StreamOrDevice s /* = {}*/) {
return sqrt(var(a, axes, keepdims, ddof, s), s);
}
array std(
const array& a,
int axis,
bool keepdims /* = false */,
int ddof /* = 0*/,
StreamOrDevice s /* = {} */) {
return std(a, std::vector<int>{axis}, keepdims, ddof, to_stream(s));
}
array prod(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
std::vector<int> axes(a.ndim());
std::iota(axes.begin(), axes.end(), 0);
@@ -1929,7 +2017,7 @@ array floor_divide(
const array& b,
StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
if (is_floating_point(dtype)) {
if (issubdtype(dtype, inexact)) {
return floor(divide(a, b, s), s);
}
@@ -1957,7 +2045,7 @@ array operator%(const array& a, const array& b) {
std::vector<array>
divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
if (is_complex(dtype)) {
if (issubdtype(dtype, complexfloating)) {
throw std::invalid_argument("[divmod] Complex type not supported.");
}
auto inputs =
@@ -2019,6 +2107,13 @@ array exp(const array& a, StreamOrDevice s /* = {} */) {
return array(a.shape(), dtype, std::make_shared<Exp>(to_stream(s)), {input});
}
array expm1(const array& a, StreamOrDevice s /* = {} */) {
auto dtype = at_least_float(a.dtype());
auto input = astype(a, dtype, s);
return array(
a.shape(), dtype, std::make_shared<Expm1>(to_stream(s)), {input});
}
array sin(const array& a, StreamOrDevice s /* = {} */) {
auto dtype = at_least_float(a.dtype());
auto input = astype(a, dtype, s);
@@ -2220,7 +2315,7 @@ array matmul(
}
// Type promotion
auto out_type = promote_types(a.dtype(), b.dtype());
if (!is_floating_point(out_type) || is_complex(out_type)) {
if (!issubdtype(out_type, floating)) {
std::ostringstream msg;
msg << "[matmul] Only real floating point types are supported but "
<< a.dtype() << " and " << b.dtype() << " were provided which results"
@@ -2330,7 +2425,7 @@ array gather(
// Promote indices to the same type
auto dtype = result_type(indices);
if (!is_integral(dtype)) {
if (issubdtype(dtype, inexact)) {
throw std::invalid_argument(
"[gather] Got indices with invalid dtype. Indices must be integral.");
}
@@ -2413,8 +2508,8 @@ array take_along_axis(
StreamOrDevice s /* = {} */) {
if (axis + a.ndim() < 0 || axis >= static_cast<int>(a.ndim())) {
std::ostringstream msg;
msg << "[take_along_axis] Received invalid axis "
<< " for array with " << a.ndim() << " dimensions.";
msg << "[take_along_axis] Received invalid axis " << " for array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
@@ -2521,7 +2616,7 @@ array scatter(
// Promote indices to the same type
auto dtype = result_type(indices);
if (!is_integral(dtype)) {
if (issubdtype(dtype, inexact)) {
throw std::invalid_argument(
"[scatter] Got indices with invalid dtype. Indices must be integral.");
}
@@ -2605,25 +2700,34 @@ array rsqrt(const array& a, StreamOrDevice s /* = {} */) {
array softmax(
const array& a,
const std::vector<int>& axes,
bool precise /* = false */,
StreamOrDevice s /* = {}*/) {
if (axes.size() == 1 && (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
auto dtype = at_least_float(a.dtype());
return array(
a.shape(),
dtype,
std::make_shared<Softmax>(to_stream(s)),
std::make_shared<Softmax>(to_stream(s), precise),
{astype(a, dtype, s)});
} else {
auto a_max = stop_gradient(max(a, axes, /*keepdims = */ true, s), s);
auto ex = exp(subtract(a, a_max, s), s);
return divide(ex, sum(ex, axes, /*keepdims = */ true, s), s);
auto in = a;
if (precise) {
in = astype(a, float32, s);
}
auto a_max = stop_gradient(max(in, axes, /*keepdims = */ true, s), s);
auto ex = exp(subtract(in, a_max, s), s);
return astype(
divide(ex, sum(ex, axes, /*keepdims = */ true, s), s), a.dtype(), s);
}
}
array softmax(const array& a, StreamOrDevice s /* = {}*/) {
array softmax(
const array& a,
bool precise /* = false */,
StreamOrDevice s /* = {}*/) {
std::vector<int> axes(a.ndim());
std::iota(axes.begin(), axes.end(), 0);
return softmax(a, axes, s);
return softmax(a, axes, precise, s);
}
array power(const array& a, const array& b, StreamOrDevice s /* = {} */) {
@@ -2800,15 +2904,15 @@ inline std::vector<int> conv_out_shape(
if (pads_lo[i - 1] < 0 || pads_hi[i - 1] < 0) {
std::ostringstream msg;
msg << "[conv] Padding sizes must be non-negative."
<< " Got padding " << pads_lo << " | " << pads_hi << ".";
msg << "[conv] Padding sizes must be non-negative." << " Got padding "
<< pads_lo << " | " << pads_hi << ".";
throw std::invalid_argument(msg.str());
}
if (strides[i - 1] <= 0) {
std::ostringstream msg;
msg << "[conv] Stride sizes must be positive."
<< " Got strides " << strides << ".";
msg << "[conv] Stride sizes must be positive." << " Got strides "
<< strides << ".";
throw std::invalid_argument(msg.str());
}
@@ -2834,7 +2938,7 @@ inline std::vector<int> conv_out_shape(
}
inline void run_conv_checks(const array& in, const array& wt, int n_dim) {
if (!is_floating_point(in.dtype()) && kindof(in.dtype()) != Dtype::Kind::c) {
if (!issubdtype(in.dtype(), floating)) {
std::ostringstream msg;
msg << "[conv] Invalid input array with type " << in.dtype() << "."
<< " Convolution currently only supports floating point types";
@@ -2844,8 +2948,7 @@ inline void run_conv_checks(const array& in, const array& wt, int n_dim) {
if (in.ndim() != n_dim + 2) {
std::ostringstream msg;
msg << "[conv] Invalid input array with " << in.ndim() << " dimensions for "
<< n_dim << "D convolution."
<< " Expected an array with " << n_dim + 2
<< n_dim << "D convolution." << " Expected an array with " << n_dim + 2
<< " dimensions following the format [N, ..., C_in].";
throw std::invalid_argument(msg.str());
}
@@ -2971,6 +3074,35 @@ array conv_general(
input_dilation = std::vector<int>(spatial_dims, input_dilation_int);
}
// Check for negative padding
bool has_neg_padding = false;
for (auto& pd : padding_lo) {
has_neg_padding |= (pd < 0);
}
for (auto& pd : padding_hi) {
has_neg_padding |= (pd < 0);
}
// Handle negative padding
if (has_neg_padding) {
std::vector<int> starts(in.ndim(), 0);
std::vector<int> stops = in.shape();
for (int i = 0; i < spatial_dims; i++) {
if (padding_lo[i] < 0) {
starts[i + 1] -= padding_lo[i];
padding_lo[i] = 0;
}
if (padding_hi[i] < 0) {
stops[i + 1] += padding_hi[i];
padding_hi[i] = 0;
}
}
in = slice(in, std::move(starts), std::move(stops), s);
}
// Get output shapes
std::vector<int> out_shape = conv_out_shape(
in.shape(),
@@ -3005,7 +3137,6 @@ array quantized_matmul(
int bits /* = 4 */,
StreamOrDevice s /* = {} */) {
array x = in_x;
if (w.dtype() != uint32) {
std::ostringstream msg;
msg << "[quantized_matmul] The weight matrix should be uint32 "
@@ -3022,12 +3153,6 @@ array quantized_matmul(
// Keep x's batch dimensions to reshape it back after the matmul
auto original_shape = x.shape();
int x_inner_dims = original_shape.back();
original_shape.pop_back();
// Reshape x into a matrix if it isn't already one
if (x.ndim() != 2) {
x = reshape(x, {-1, x_inner_dims}, s);
}
if (scales.ndim() != 2 || scales.shape() != biases.shape()) {
std::ostringstream msg;
@@ -3062,7 +3187,7 @@ array quantized_matmul(
}
auto dtype = result_type(x, scales, biases);
if (!is_floating_point(dtype) || is_complex(dtype)) {
if (!issubdtype(dtype, floating)) {
std::ostringstream msg;
msg << "[quantized_matmul] Only real floating types are supported but "
<< "the passed types where x.dtype() == " << x.dtype()
@@ -3070,9 +3195,10 @@ array quantized_matmul(
<< " and biases.dtype() == " << biases.dtype();
throw std::invalid_argument(msg.str());
}
auto out = array(
{x.shape(0), w_outer_dims},
std::vector<array> inputs;
original_shape.back() = w_outer_dims;
return array(
std::move(original_shape),
dtype,
std::make_shared<QuantizedMatmul>(
to_stream(s), group_size, bits, transpose),
@@ -3080,14 +3206,6 @@ array quantized_matmul(
w,
astype(scales, dtype, s),
astype(biases, dtype, s)});
// If needed reshape x to the original batch shape
if (original_shape.size() != 1) {
original_shape.push_back(w_outer_dims);
out = reshape(out, std::move(original_shape), s);
}
return out;
}
std::tuple<array, array, array> quantize(
@@ -3117,8 +3235,7 @@ std::tuple<array, array, array> quantize(
std::ostringstream msg;
msg << "[quantize] The last dimension of the matrix needs to be divisible by "
<< "the quantization group size " << group_size
<< ". However the provided "
<< " matrix has shape " << w.shape();
<< ". However the provided " << " matrix has shape " << w.shape();
throw std::invalid_argument(msg.str());
}
@@ -3364,7 +3481,7 @@ array addmm(
// Type promotion
auto out_type = result_type(a, b, c);
if (!is_floating_point(out_type) || is_complex(out_type)) {
if (!issubdtype(out_type, floating)) {
std::ostringstream msg;
msg << "[addmm] Only real floating point types are supported but "
<< c.dtype() << ", " << a.dtype() << " and " << b.dtype()

View File

@@ -41,40 +41,33 @@ array linspace(
StreamOrDevice s = {});
/** Convert an array to the given data type. */
array astype(const array& a, Dtype dtype, StreamOrDevice s = {});
array astype(array a, Dtype dtype, StreamOrDevice s = {});
/** Create a view of an array with the given shape and strides. */
array as_strided(
const array& a,
array a,
std::vector<int> shape,
std::vector<size_t> strides,
size_t offset,
StreamOrDevice s = {});
/** Copy another array. */
array copy(const array& a, StreamOrDevice s = {});
array copy(array a, StreamOrDevice s = {});
/** Fill an array of the given shape with the given value(s). */
array full(
const std::vector<int>& shape,
const array& vals,
std::vector<int> shape,
array vals,
Dtype dtype,
StreamOrDevice s = {});
array full(
const std::vector<int>& shape,
const array& vals,
StreamOrDevice s = {});
array full(std::vector<int> shape, array vals, StreamOrDevice s = {});
template <typename T>
array full(
const std::vector<int>& shape,
T val,
Dtype dtype,
StreamOrDevice s = {}) {
return full(shape, array(val, dtype), to_stream(s));
array full(std::vector<int> shape, T val, Dtype dtype, StreamOrDevice s = {}) {
return full(std::move(shape), array(val, dtype), to_stream(s));
}
template <typename T>
array full(const std::vector<int>& shape, T val, StreamOrDevice s = {}) {
return full(shape, array(val), to_stream(s));
array full(std::vector<int> shape, T val, StreamOrDevice s = {}) {
return full(std::move(shape), array(val), to_stream(s));
}
/** Fill an array of the given shape with zeros. */
@@ -204,6 +197,13 @@ std::vector<array> split(
std::vector<array>
split(const array& a, const std::vector<int>& indices, StreamOrDevice s = {});
/** A vector of coordinate arrays from coordinate vectors. */
std::vector<array> meshgrid(
const std::vector<array>& arrays,
bool sparse = false,
std::string indexing = "xy",
StreamOrDevice s = {});
/**
* Clip (limit) the values in an array.
*/
@@ -514,13 +514,14 @@ array mean(
bool keepdims = false,
StreamOrDevice s = {});
/** Computes the mean of the elements of an array. */
/** Computes the variance of the elements of an array. */
array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
inline array var(const array& a, StreamOrDevice s = {}) {
return var(a, false, 0, to_stream(s));
}
/** Computes the var of the elements of an array along the given axes */
/** Computes the variance of the elements of an array along the given
* axes */
array var(
const array& a,
const std::vector<int>& axes,
@@ -528,7 +529,8 @@ array var(
int ddof = 0,
StreamOrDevice s = {});
/** Computes the var of the elements of an array along the given axis */
/** Computes the variance of the elements of an array along the given
* axis */
array var(
const array& a,
int axis,
@@ -536,6 +538,30 @@ array var(
int ddof = 0,
StreamOrDevice s = {});
/** Computes the standard deviation of the elements of an array. */
array std(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
inline array std(const array& a, StreamOrDevice s = {}) {
return std(a, false, 0, to_stream(s));
}
/** Computes the standard deviation of the elements of an array along the given
* axes */
array std(
const array& a,
const std::vector<int>& axes,
bool keepdims = false,
int ddof = 0,
StreamOrDevice s = {});
/** Computes the standard deviation of the elements of an array along the given
* axis */
array std(
const array& a,
int axis,
bool keepdims = false,
int ddof = 0,
StreamOrDevice s = {});
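std is sqrt(var(...)) with the same ddof convention, so the overload set mirrors var. A usage sketch (ones is assumed from ops.h):

namespace mx = mlx::core;

auto x = mx::ones({3, 4});
auto s_all  = mx::std(x, /* keepdims= */ false);     // reduce over all axes
auto s_last = mx::std(x, /* axis= */ -1);            // shape {3}
auto s_ddof = mx::std(x, {0}, false, /* ddof= */ 1); // sample std along axis 0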
/** The product of all elements of the array. */
array prod(const array& a, bool keepdims, StreamOrDevice s = {});
inline array prod(const array& a, StreamOrDevice s = {}) {
@@ -849,6 +875,9 @@ array erf(const array& a, StreamOrDevice s = {});
/** Computes the inverse error function of the elements of an array. */
array erfinv(const array& a, StreamOrDevice s = {});
/** Computes the expm1 function of the elements of an array. */
array expm1(const array& a, StreamOrDevice s = {});
/** Stop the flow of gradients. */
array stop_gradient(const array& a, StreamOrDevice s = {});
@@ -983,14 +1012,16 @@ array rsqrt(const array& a, StreamOrDevice s = {});
array softmax(
const array& a,
const std::vector<int>& axes,
bool precise = false,
StreamOrDevice s = {});
/** Softmax of an array. */
array softmax(const array& a, StreamOrDevice s = {});
array softmax(const array& a, bool precise = false, StreamOrDevice s = {});
/** Softmax of an array. */
inline array softmax(const array& a, int axis, StreamOrDevice s = {}) {
return softmax(a, std::vector<int>{axis}, s);
inline array
softmax(const array& a, int axis, bool precise = false, StreamOrDevice s = {}) {
return softmax(a, std::vector<int>{axis}, precise, s);
}
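With precise = true the softmax is computed in float32 internally and the result is cast back to the input dtype, which guards the exp and sum against overflow for half-precision inputs. A sketch:

namespace mx = mlx::core;

auto logits = mx::ones({2, 8}, mx::float16);

auto p_fast = mx::softmax(logits, /* axis= */ -1);          // accumulates in fp16
auto p_safe = mx::softmax(logits, -1, /* precise= */ true); // fp32 internally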
/** Raise elements of a to the power of b element-wise */

View File

@@ -1239,6 +1239,34 @@ std::pair<std::vector<array>, std::vector<int>> Exp::vmap(
return {{exp(inputs[0], stream())}, axes};
}
std::vector<array> Expm1::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
return {multiply(
cotangents[0],
add(outputs[0], array(1.0f, outputs[0].dtype()), stream()),
stream())};
}
std::vector<array> Expm1::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
assert(primals.size() == 1);
assert(argnums.size() == 1);
return {multiply(tangents[0], exp(primals[0], stream()), stream())};
}
std::pair<std::vector<array>, std::vector<int>> Expm1::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
assert(inputs.size() == 1);
assert(axes.size() == 1);
return {{expm1(inputs[0], stream())}, axes};
}
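The vjp avoids recomputing the exponential by using the identity

$$
\frac{d}{dx}\operatorname{expm1}(x) = e^x = \operatorname{expm1}(x) + 1,
$$

so it multiplies the cotangent by outputs[0] + 1, while the jvp, which receives primals rather than outputs, falls back to exp(primals[0]).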
bool FFT::is_equivalent(const Primitive& other) const {
const FFT& r_other = static_cast<const FFT&>(other);
return axes_ == r_other.axes_ && inverse_ == r_other.inverse_ &&
@@ -1267,7 +1295,7 @@ std::pair<std::vector<array>, std::vector<int>> FFT::vmap(
{array(
out_shape,
real_ && inverse_ ? float32 : complex64,
std::make_unique<FFT>(stream(), fft_axes, inverse_, real_),
std::make_shared<FFT>(stream(), fft_axes, inverse_, real_),
{in})},
{ax}};
}
@@ -1377,7 +1405,7 @@ std::pair<std::vector<array>, std::vector<int>> Full::vmap(
assert(axes.size() == 1);
auto& in = inputs[0];
auto out =
array(in.shape(), in.dtype(), std::make_unique<Full>(stream()), {in});
array(in.shape(), in.dtype(), std::make_shared<Full>(stream()), {in});
return {{out}, axes};
}
@@ -1604,7 +1632,7 @@ std::pair<std::vector<array>, std::vector<int>> Log::vmap(
{array(
in.shape(),
in.dtype(),
std::make_unique<Log>(stream(), base_),
std::make_shared<Log>(stream(), base_),
{in})},
axes};
}
@@ -2259,7 +2287,7 @@ std::pair<std::vector<array>, std::vector<int>> RandomBits::vmap(
auto out = array(
shape,
get_dtype(),
std::make_unique<RandomBits>(stream(), shape, width_),
std::make_shared<RandomBits>(stream(), shape, width_),
{key});
return {{out}, {kax}};
}
@@ -2493,7 +2521,7 @@ std::pair<std::vector<array>, std::vector<int>> Scan::vmap(
{array(
in.shape(),
out_dtype,
std::make_unique<Scan>(
std::make_shared<Scan>(
stream(), reduce_type_, axis_ + axis_left, reverse_, inclusive_),
{in})},
axes};
@@ -2975,7 +3003,7 @@ std::pair<std::vector<array>, std::vector<int>> Softmax::vmap(
} else {
softmax_axes.push_back(-2);
}
return {{softmax(inputs[0], softmax_axes, stream())}, axes};
return {{softmax(inputs[0], softmax_axes, precise_, stream())}, axes};
}
std::vector<array> Softmax::vjp(
@@ -2998,13 +3026,18 @@ std::vector<array> Softmax::jvp(
const std::vector<int>& argnums) {
assert(primals.size() == 1);
assert(tangents.size() == 1);
auto s = softmax(primals[0], std::vector<int>{-1}, stream());
auto s = softmax(primals[0], std::vector<int>{-1}, precise_, stream());
auto sv = multiply(s, tangents[0], stream());
return {subtract(
sv,
multiply(s, sum(sv, std::vector<int>{-1}, true, stream()), stream()))};
}
bool Softmax::is_equivalent(const Primitive& other) const {
const Softmax& s_other = static_cast<const Softmax&>(other);
return precise_ == s_other.precise_;
}
std::pair<std::vector<array>, std::vector<int>> Sort::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
@@ -3303,7 +3336,7 @@ std::pair<std::vector<array>, std::vector<int>> NumberOfElements::vmap(
array out = array(
std::vector<int>{},
dtype_,
std::make_unique<NumberOfElements>(stream(), new_axes, inverted_, dtype_),
std::make_shared<NumberOfElements>(stream(), new_axes, inverted_, dtype_),
inputs);
return {{out}, {-1}};

Some files were not shown because too many files have changed in this diff.