mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-10 21:37:50 +08:00
Compare commits
54 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
d07e295c62 | ||
![]() |
dce4bd74a4 | ||
![]() |
ffff671273 | ||
![]() |
12d4507ee3 | ||
![]() |
8580d997ff | ||
![]() |
061cf9a4ce | ||
![]() |
99abb9eff4 | ||
![]() |
fffe072028 | ||
![]() |
a1a31eed27 | ||
![]() |
ae812350f9 | ||
![]() |
b63ef10a7f | ||
![]() |
42afe27e12 | ||
![]() |
76e63212ff | ||
![]() |
aac2f9fb61 | ||
![]() |
bddf23f175 | ||
![]() |
039da779d1 | ||
![]() |
d88d2124b5 | ||
![]() |
e142aaf8a1 | ||
![]() |
0caf35f4b8 | ||
![]() |
3fc993f82d | ||
![]() |
741eb28443 | ||
![]() |
1a87dc5ea8 | ||
![]() |
2427fa171e | ||
![]() |
639e06e1f3 | ||
![]() |
02fedbf1da | ||
![]() |
110d9b149d | ||
![]() |
9cbff5ec1d | ||
![]() |
433c0206b0 | ||
![]() |
8915901966 | ||
![]() |
f48bc496c7 | ||
![]() |
913b19329c | ||
![]() |
d8cb3128f6 | ||
![]() |
5f9ba3019f | ||
![]() |
46caf0bef0 | ||
![]() |
45f636e759 | ||
![]() |
a7b404ff53 | ||
![]() |
c4fd0e5ede | ||
![]() |
bab5386306 | ||
![]() |
aca7584635 | ||
![]() |
d611251502 | ||
![]() |
f30b659291 | ||
![]() |
90dfa43ff1 | ||
![]() |
dc175f08d3 | ||
![]() |
29221fa238 | ||
![]() |
a789685c63 | ||
![]() |
240d10699c | ||
![]() |
925014b661 | ||
![]() |
5611e1a95e | ||
![]() |
570f2bf29e | ||
![]() |
9948eddf11 | ||
![]() |
a3ee03da01 | ||
![]() |
28fcd2b519 | ||
![]() |
8e686764ac | ||
![]() |
479051ce1c |
@@ -1,11 +1,11 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||
rev: v17.0.6
|
||||
rev: v18.1.3
|
||||
hooks:
|
||||
- id: clang-format
|
||||
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
|
||||
- repo: https://github.com/psf/black-pre-commit-mirror
|
||||
rev: 24.2.0
|
||||
rev: 24.3.0
|
||||
hooks:
|
||||
- id: black
|
||||
- repo: https://github.com/pycqa/isort
|
||||
|
@@ -15,6 +15,8 @@ MLX was developed with contributions from the following individuals:
|
||||
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
|
||||
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
|
||||
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
|
||||
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
|
||||
|
||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||
</a>
|
||||
|
@@ -15,31 +15,33 @@ option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
|
||||
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
|
||||
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
||||
option(MLX_BUILD_METAL "Build metal backend" ON)
|
||||
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
||||
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
||||
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||
|
||||
if(NOT MLX_VERSION)
|
||||
set(MLX_VERSION 0.8.1)
|
||||
set(MLX_VERSION 0.10.0)
|
||||
endif()
|
||||
|
||||
# --------------------- Processor tests -------------------------
|
||||
|
||||
message(STATUS "Building MLX for ${CMAKE_HOST_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
|
||||
message(STATUS "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
|
||||
|
||||
set(MLX_BUILD_ARM OFF)
|
||||
|
||||
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
if (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64" AND ${CMAKE_HOST_APPLE})
|
||||
message(FATAL_ERROR
|
||||
"Building for x86_64 on macOS is not supported."
|
||||
" If you are on an Apple silicon system, check the build"
|
||||
" documentation for possible fixes: "
|
||||
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
|
||||
elseif (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64")
|
||||
message(WARNING
|
||||
"Building for x86_64 on macOS is not supported."
|
||||
" If you are on an Apple silicon system, "
|
||||
" make sure you are building for arm64.")
|
||||
elseif(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "arm64")
|
||||
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
|
||||
if(NOT MLX_ENABLE_X64_MAC)
|
||||
message(FATAL_ERROR
|
||||
"Building for x86_64 on macOS is not supported."
|
||||
" If you are on an Apple silicon system, check the build"
|
||||
" documentation for possible fixes: "
|
||||
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
|
||||
else()
|
||||
message(WARNING "Building for x86_64 arch is not officially supported.")
|
||||
endif()
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
|
||||
set(MLX_BUILD_ARM ON)
|
||||
endif()
|
||||
|
||||
@@ -64,8 +66,14 @@ endif()
|
||||
if (MLX_BUILD_METAL AND NOT METAL_LIB)
|
||||
message(STATUS "Metal not found. Unable to build GPU")
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
set(MLX_METAL_DEBUG OFF)
|
||||
elseif (MLX_BUILD_METAL)
|
||||
message(STATUS "Building METAL sources")
|
||||
|
||||
if (MLX_METAL_DEBUG)
|
||||
add_compile_definitions(MLX_METAL_DEBUG)
|
||||
endif()
|
||||
|
||||
# Throw an error if xcrun not found
|
||||
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||
OUTPUT_VARIABLE MACOS_VERSION
|
||||
@@ -108,7 +116,27 @@ if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
|
||||
else()
|
||||
message(STATUS "Accelerate or arm neon not found, using default backend.")
|
||||
set(MLX_BUILD_ACCELERATE OFF)
|
||||
#set(BLA_VENDOR Generic)
|
||||
if(${CMAKE_HOST_APPLE})
|
||||
# The blas shipped in macOS SDK is not supported, search homebrew for
|
||||
# openblas instead.
|
||||
set(BLA_VENDOR OpenBLAS)
|
||||
set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
|
||||
endif()
|
||||
# Search and link with lapack.
|
||||
find_package(LAPACK REQUIRED)
|
||||
if (NOT LAPACK_FOUND)
|
||||
message(FATAL_ERROR "Must have LAPACK installed")
|
||||
endif()
|
||||
find_path(LAPACK_INCLUDE_DIRS lapacke.h
|
||||
/usr/include
|
||||
/usr/local/include
|
||||
/usr/local/opt/openblas/include)
|
||||
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
|
||||
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
|
||||
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
|
||||
target_link_libraries(mlx ${LAPACK_LIBRARIES})
|
||||
# List blas after lapack otherwise we may accidentally incldue an old version
|
||||
# of lapack.h from the include dirs of blas.
|
||||
find_package(BLAS REQUIRED)
|
||||
if (NOT BLAS_FOUND)
|
||||
message(FATAL_ERROR "Must have BLAS installed")
|
||||
@@ -122,17 +150,6 @@ else()
|
||||
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
|
||||
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
|
||||
target_link_libraries(mlx ${BLAS_LIBRARIES})
|
||||
find_package(LAPACK REQUIRED)
|
||||
if (NOT LAPACK_FOUND)
|
||||
message(FATAL_ERROR "Must have LAPACK installed")
|
||||
endif()
|
||||
find_path(LAPACK_INCLUDE_DIRS lapacke.h
|
||||
/usr/include
|
||||
/usr/local/include)
|
||||
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
|
||||
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
|
||||
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
|
||||
target_link_libraries(mlx ${LAPACK_LIBRARIES})
|
||||
endif()
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
|
||||
|
@@ -17,14 +17,13 @@
|
||||
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
|
||||
<< std::endl;
|
||||
|
||||
#define TIMEM(MSG, FUNC, ...) \
|
||||
std::cout << "Timing " \
|
||||
<< "(" << MSG << ") " << #FUNC << " ... " << std::flush \
|
||||
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
|
||||
<< std::endl;
|
||||
#define TIMEM(MSG, FUNC, ...) \
|
||||
std::cout << "Timing " << "(" << MSG << ") " << #FUNC << " ... " \
|
||||
<< std::flush << std::setprecision(5) \
|
||||
<< time_fn(FUNC, ##__VA_ARGS__) << " msec" << std::endl;
|
||||
|
||||
template <typename F, typename... Args>
|
||||
double time_fn(F fn, Args... args) {
|
||||
double time_fn(F fn, Args&&... args) {
|
||||
// warmup
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
eval(fn(std::forward<Args>(args)...));
|
||||
|
41
benchmarks/python/layer_norm_bench.py
Normal file
41
benchmarks/python/layer_norm_bench.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from time_utils import time_fn
|
||||
|
||||
|
||||
def layer_norm(x, w, b, eps):
|
||||
ot = x.dtype
|
||||
x = x.astype(mx.float32)
|
||||
mu = mx.mean(x, -1, keepdims=True)
|
||||
v = mx.var(x, -1, keepdims=True)
|
||||
return (x - mu) * mx.rsqrt(v + eps) * w + b
|
||||
|
||||
|
||||
def time_layer_norm():
|
||||
f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
|
||||
f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
|
||||
g1 = mx.grad(f1, argnums=(0, 1, 2))
|
||||
g2 = mx.grad(f2, argnums=(0, 1, 2))
|
||||
|
||||
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
mx.eval(x, w, b, y)
|
||||
|
||||
def layer_norm_loop(g, x, w, b):
|
||||
gx, gw, gb = x, w, b
|
||||
for _ in range(32):
|
||||
gx, gw, gb = g(gx, gw, gb, y)
|
||||
return gx, gw, gb
|
||||
|
||||
time_fn(layer_norm_loop, g1, x, w, b)
|
||||
time_fn(layer_norm_loop, g2, x, w, b)
|
||||
time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
|
||||
time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_layer_norm()
|
39
benchmarks/python/rms_norm_bench.py
Normal file
39
benchmarks/python/rms_norm_bench.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from time_utils import time_fn
|
||||
|
||||
|
||||
def rms_norm(x, w, eps):
|
||||
ot = x.dtype
|
||||
x = x.astype(mx.float32)
|
||||
n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
|
||||
return (x * n).astype(ot) * w
|
||||
|
||||
|
||||
def time_rms_norm():
|
||||
f1 = lambda x, w, y: (rms_norm(x, w, 1e-5) * y).sum()
|
||||
f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, 1e-5) * y).sum()
|
||||
g1 = mx.grad(f1, argnums=(0, 1))
|
||||
g2 = mx.grad(f2, argnums=(0, 1))
|
||||
|
||||
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
mx.eval(x, w, y)
|
||||
|
||||
def rms_norm_loop(g, x, w):
|
||||
gx, gw = x, w
|
||||
for _ in range(32):
|
||||
gx, gw = g(gx, gw, y)
|
||||
return gx, gw
|
||||
|
||||
time_fn(rms_norm_loop, g1, x, w)
|
||||
time_fn(rms_norm_loop, g2, x, w)
|
||||
time_fn(rms_norm_loop, mx.compile(g1), x, w)
|
||||
time_fn(rms_norm_loop, mx.compile(g2), x, w)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_rms_norm()
|
BIN
docs/src/_static/metal_debugger/capture.png
Normal file
BIN
docs/src/_static/metal_debugger/capture.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.2 MiB |
BIN
docs/src/_static/metal_debugger/schema.png
Normal file
BIN
docs/src/_static/metal_debugger/schema.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 746 KiB |
@@ -1,24 +1,16 @@
|
||||
Developer Documentation
|
||||
=======================
|
||||
|
||||
MLX provides a open and flexible backend to which users may add operations
|
||||
and specialized implementations without much hassle. While the library supplies
|
||||
efficient operations that can be used and composed for any number of
|
||||
applications, there may arise cases where new functionalities or highly
|
||||
optimized implementations are needed. For such cases, you may design and
|
||||
implement your own operations that link to and build on top of :mod:`mlx.core`.
|
||||
We will introduce the inner-workings of MLX and go over a simple example to
|
||||
learn the steps involved in adding new operations to MLX with your own CPU
|
||||
and GPU implementations.
|
||||
You can extend MLX with custom operations on the CPU or GPU. This guide
|
||||
explains how to do that with a simple example.
|
||||
|
||||
Introducing the Example
|
||||
-----------------------
|
||||
|
||||
Let's say that you would like an operation that takes in two arrays,
|
||||
``x`` and ``y``, scales them both by some coefficients ``alpha`` and ``beta``
|
||||
respectively, and then adds them together to get the result
|
||||
``z = alpha * x + beta * y``. Well, you can very easily do that by just
|
||||
writing out a function as follows:
|
||||
Let's say you would like an operation that takes in two arrays, ``x`` and
|
||||
``y``, scales them both by coefficients ``alpha`` and ``beta`` respectively,
|
||||
and then adds them together to get the result ``z = alpha * x + beta * y``.
|
||||
You can do that in MLX directly:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -27,44 +19,35 @@ writing out a function as follows:
|
||||
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
|
||||
return alpha * x + beta * y
|
||||
|
||||
This function performs that operation while leaving the implementations and
|
||||
differentiation to MLX.
|
||||
This function performs that operation while leaving the implementation and
|
||||
function transformations to MLX.
|
||||
|
||||
However, you work with vector math libraries often and realize that the
|
||||
``axpby`` routine defines the same operation ``Y = (alpha * X) + (beta * Y)``.
|
||||
You would really like the part of your applications that does this operation
|
||||
on the CPU to be very fast - so you decide that you want it to rely on the
|
||||
``axpby`` routine provided by the Accelerate_ framework. Continuing to impose
|
||||
our assumptions on to you, let's also assume that you want to learn how to add
|
||||
your own implementation for the gradients of your new operation while going
|
||||
over the ins-and-outs of the MLX framework.
|
||||
However you may need to customize the underlying implementation, perhaps to
|
||||
make it faster or for custom differentiation. In this tutorial we will go
|
||||
through adding custom extensions. It will cover:
|
||||
|
||||
Well, what a coincidence! You are in the right place. Over the course of this
|
||||
example, we will learn:
|
||||
|
||||
* The structure of the MLX library from the frontend API to the backend implementations.
|
||||
* How to implement your own CPU backend that redirects to Accelerate_ when appropriate (and a fallback if needed).
|
||||
* How to implement your own GPU implementation using metal.
|
||||
* How to add your own ``vjp`` and ``jvp``.
|
||||
* How to build your implementations, link them to MLX, and bind them to python.
|
||||
* The structure of the MLX library.
|
||||
* Implementing a CPU operation that redirects to Accelerate_ when appropriate.
|
||||
* Implementing a GPU operation using metal.
|
||||
* Adding the ``vjp`` and ``jvp`` function transformation.
|
||||
* Building a custom extension and binding it to python.
|
||||
|
||||
Operations and Primitives
|
||||
-------------------------
|
||||
|
||||
In one sentence, operations in MLX build the computation graph, and primitives
|
||||
provide the rules for evaluation and transformations of said graph. Let's start
|
||||
by discussing operations in more detail.
|
||||
Operations in MLX build the computation graph. Primitives provide the rules for
|
||||
evaluating and transforming the graph. Let's start by discussing operations in
|
||||
more detail.
|
||||
|
||||
Operations
|
||||
^^^^^^^^^^^
|
||||
|
||||
Operations are the frontend functions that operate on arrays. They are defined
|
||||
in the C++ API (:ref:`cpp_ops`) and then we provide bindings to these
|
||||
operations in the Python API (:ref:`ops`).
|
||||
Operations are the front-end functions that operate on arrays. They are defined
|
||||
in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.
|
||||
|
||||
We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and ``y``,
|
||||
and two scalars, ``alpha`` and ``beta``. This is how we would define it in the
|
||||
C++ API:
|
||||
We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and
|
||||
``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
|
||||
C++:
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
@@ -83,10 +66,7 @@ C++ API:
|
||||
StreamOrDevice s = {} // Stream on which to schedule the operation
|
||||
);
|
||||
|
||||
|
||||
This operation itself can call other operations within it if needed. So, the
|
||||
simplest way to go about implementing this operation would be do so in terms
|
||||
of existing operations.
|
||||
The simplest way to this operation is in terms of existing operations:
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
@@ -100,25 +80,23 @@ of existing operations.
|
||||
// Scale x and y on the provided stream
|
||||
auto ax = multiply(array(alpha), x, s);
|
||||
auto by = multiply(array(beta), y, s);
|
||||
|
||||
|
||||
// Add and return
|
||||
return add(ax, by, s);
|
||||
}
|
||||
|
||||
However, as we discussed earlier, this is not our goal. The operations themselves
|
||||
do not contain the implementations that act on the data, nor do they contain the
|
||||
rules of transformations. Rather, they are an easy to use interface that build
|
||||
on top of the building blocks we call :class:`Primitive`.
|
||||
The operations themselves do not contain the implementations that act on the
|
||||
data, nor do they contain the rules of transformations. Rather, they are an
|
||||
easy to use interface that use :class:`Primitive` building blocks.
|
||||
|
||||
Primitives
|
||||
^^^^^^^^^^^
|
||||
|
||||
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
|
||||
defines how to create an output given a set of input :class:`array` . Further,
|
||||
a :class:`Primitive` is a class that contains rules on how it is evaluated
|
||||
on the CPU or GPU, and how it acts under transformations such as ``vjp`` and
|
||||
``jvp``. These words on their own can be a bit abstract, so lets take a step
|
||||
back and go to our example to give ourselves a more concrete image.
|
||||
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
|
||||
defines how to create outputs arrays given a input arrays. Further, a
|
||||
:class:`Primitive` has methods to run on the CPU or GPU and for function
|
||||
transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be
|
||||
more concrete:
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
@@ -134,11 +112,15 @@ back and go to our example to give ourselves a more concrete image.
|
||||
* To avoid unnecessary allocations, the evaluation function
|
||||
* is responsible for allocating space for the array.
|
||||
*/
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) override;
|
||||
void eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) override;
|
||||
|
||||
/** The Jacobian-vector product. */
|
||||
array jvp(
|
||||
std::vector<array> jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) override;
|
||||
@@ -147,7 +129,8 @@ back and go to our example to give ourselves a more concrete image.
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const array& cotan,
|
||||
const std::vector<int>& argnums) override;
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
/**
|
||||
* The primitive must know how to vectorize itself across
|
||||
@@ -155,7 +138,7 @@ back and go to our example to give ourselves a more concrete image.
|
||||
* representing the vectorized computation and the axis which
|
||||
* corresponds to the output vectorized dimension.
|
||||
*/
|
||||
std::pair<array, int> vmap(
|
||||
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
|
||||
@@ -175,22 +158,22 @@ back and go to our example to give ourselves a more concrete image.
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
The :class:`Axpby` class derives from the base :class:`Primitive` class and
|
||||
follows the above demonstrated interface. :class:`Axpby` treats ``alpha`` and
|
||||
``beta`` as parameters. It then provides implementations of how the array ``out``
|
||||
is produced given ``inputs`` through :meth:`Axpby::eval_cpu` and
|
||||
:meth:`Axpby::eval_gpu`. Further, it provides rules of transformations in
|
||||
:meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and :meth:`Axpby::vmap`.
|
||||
The :class:`Axpby` class derives from the base :class:`Primitive` class. The
|
||||
:class:`Axpby` treats ``alpha`` and ``beta`` as parameters. It then provides
|
||||
implementations of how the output array is produced given the inputs through
|
||||
:meth:`Axpby::eval_cpu` and :meth:`Axpby::eval_gpu`. It also provides rules
|
||||
of transformations in :meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and
|
||||
:meth:`Axpby::vmap`.
|
||||
|
||||
Using the Primitives
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
Using the Primitive
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Operations can use this :class:`Primitive` to add a new :class:`array` to
|
||||
the computation graph. An :class:`array` can be constructed by providing its
|
||||
data type, shape, the :class:`Primitive` that computes it, and the
|
||||
:class:`array` inputs that are passed to the primitive.
|
||||
Operations can use this :class:`Primitive` to add a new :class:`array` to the
|
||||
computation graph. An :class:`array` can be constructed by providing its data
|
||||
type, shape, the :class:`Primitive` that computes it, and the :class:`array`
|
||||
inputs that are passed to the primitive.
|
||||
|
||||
Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
|
||||
Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
@@ -223,7 +206,7 @@ Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
|
||||
/* const std::vector<int>& shape = */ out_shape,
|
||||
/* Dtype dtype = */ out_dtype,
|
||||
/* std::unique_ptr<Primitive> primitive = */
|
||||
std::make_unique<Axpby>(to_stream(s), alpha, beta),
|
||||
std::make_shared<Axpby>(to_stream(s), alpha, beta),
|
||||
/* const std::vector<array>& inputs = */ broadcasted_inputs);
|
||||
}
|
||||
|
||||
@@ -238,27 +221,26 @@ This operation now handles the following:
|
||||
Implementing the Primitive
|
||||
--------------------------
|
||||
|
||||
No computation happens when we call the operation alone. In effect, the
|
||||
operation only builds the computation graph. When we evaluate the output
|
||||
array, MLX schedules the execution of the computation graph, and calls
|
||||
:meth:`Axpby::eval_cpu` or :meth:`Axpby::eval_gpu` depending on the
|
||||
stream/device specified by the user.
|
||||
No computation happens when we call the operation alone. The operation only
|
||||
builds the computation graph. When we evaluate the output array, MLX schedules
|
||||
the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
|
||||
:meth:`Axpby::eval_gpu` depending on the stream/device specified by the user.
|
||||
|
||||
.. warning::
|
||||
When :meth:`Primitive::eval_cpu` or :meth:`Primitive::eval_gpu` are called,
|
||||
no memory has been allocated for the output array. It falls on the implementation
|
||||
of these functions to allocate memory as needed
|
||||
of these functions to allocate memory as needed.
|
||||
|
||||
Implementing the CPU Backend
|
||||
Implementing the CPU Back-end
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Let's start by trying to implement a naive and generic version of
|
||||
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
|
||||
:class:`Axpby` earlier called :meth:`Axpby::eval`.
|
||||
Let's start by implementing a naive and generic version of
|
||||
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
|
||||
:class:`Axpby` earlier called :meth:`Axpby::eval`.
|
||||
|
||||
Our naive method will go over each element of the output array, find the
|
||||
corresponding input elements of ``x`` and ``y`` and perform the operation
|
||||
pointwise. This is captured in the templated function :meth:`axpby_impl`.
|
||||
Our naive method will go over each element of the output array, find the
|
||||
corresponding input elements of ``x`` and ``y`` and perform the operation
|
||||
point-wise. This is captured in the templated function :meth:`axpby_impl`.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
@@ -296,19 +278,19 @@ pointwise. This is captured in the templated function :meth:`axpby_impl`.
|
||||
}
|
||||
}
|
||||
|
||||
Now, we would like our implementation to be able to do this pointwise operation
|
||||
for all incoming floating point arrays. Accordingly, we add dispatches for
|
||||
``float32``, ``float16``, ``bfloat16`` and ``complex64``. We throw an error
|
||||
if we encounter an unexpected type.
|
||||
Our implementation should work for all incoming floating point arrays.
|
||||
Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
|
||||
``complex64``. We throw an error if we encounter an unexpected type.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void Axpby::eval(const std::vector<array>& inputs, array& out) {
|
||||
// Check the inputs (registered in the op while constructing the out array)
|
||||
assert(inputs.size() == 2);
|
||||
void Axpby::eval(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs) {
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Dispatch to the correct dtype
|
||||
if (out.dtype() == float32) {
|
||||
@@ -321,28 +303,26 @@ if we encounter an unexpected type.
|
||||
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"Axpby is only supported for floating point types.");
|
||||
"[Axpby] Only supports floating point types.");
|
||||
}
|
||||
}
|
||||
|
||||
We have a fallback implementation! Now, to do what we are really here to do.
|
||||
Remember we wanted to use the ``axpby`` routine provided by the Accelerate_
|
||||
framework? Well, there are 3 complications to keep in mind:
|
||||
This is good as a fallback implementation. We can use the ``axpby`` routine
|
||||
provided by the Accelerate_ framework for a faster implementation in certain
|
||||
cases:
|
||||
|
||||
#. Accelerate does not provide implementations of ``axpby`` for half precision
|
||||
floats. We can only direct to it for ``float32`` types
|
||||
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all elements
|
||||
have fixed strides between them. Possibly due to broadcasts and transposes,
|
||||
we aren't guaranteed that the inputs fit this requirement. We can
|
||||
only direct to Accelerate if both ``x`` and ``y`` are row contiguous or
|
||||
column contiguous.
|
||||
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` inplace.
|
||||
MLX expects to write out the answer to a new array. We must copy the elements
|
||||
of ``y`` into the output array and use that as an input to ``axpby``
|
||||
floats. We can only use it for ``float32`` types.
|
||||
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all
|
||||
elements have fixed strides between them. We only direct to Accelerate
|
||||
if both ``x`` and ``y`` are row contiguous or column contiguous.
|
||||
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` in-place.
|
||||
MLX expects to write the output to a new array. We must copy the elements
|
||||
of ``y`` into the output and use that as an input to ``axpby``.
|
||||
|
||||
Let's write out an implementation that uses Accelerate in the right conditions.
|
||||
It must simply allocate data for the output, copy elements of ``y`` into it,
|
||||
and then call the :meth:`catlas_saxpby` from accelerate.
|
||||
Let's write an implementation that uses Accelerate in the right conditions.
|
||||
It allocates data for the output, copies ``y`` into it, and then calls the
|
||||
:func:`catlas_saxpby` from accelerate.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
@@ -356,17 +336,7 @@ and then call the :meth:`catlas_saxpby` from accelerate.
|
||||
// Accelerate library provides catlas_saxpby which does
|
||||
// Y = (alpha * X) + (beta * Y) in place
|
||||
// To use it, we first copy the data in y over to the output array
|
||||
|
||||
// This specialization requires both x and y be contiguous in the same mode
|
||||
// i.e: corresponding linear indices in both point to corresponding elements
|
||||
// The data in the output array is allocated to match the strides in y
|
||||
// such that x, y, and out are contiguous in the same mode and
|
||||
// no transposition is needed
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(y.data_size() * out.itemsize()),
|
||||
y.data_size(),
|
||||
y.strides(),
|
||||
y.flags());
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// We then copy over the elements using the contiguous vector specialization
|
||||
copy_inplace(y, out, CopyType::Vector);
|
||||
@@ -389,18 +359,20 @@ and then call the :meth:`catlas_saxpby` from accelerate.
|
||||
/* INCY = */ 1);
|
||||
}
|
||||
|
||||
Great! But what about the inputs that do not fit the criteria for accelerate?
|
||||
Luckily, we can always just direct back to :meth:`Axpby::eval`.
|
||||
|
||||
With this in mind, lets finally implement our :meth:`Axpby::eval_cpu`.
|
||||
For inputs that do not fit the criteria for accelerate, we fall back to
|
||||
:meth:`Axpby::eval`. With this in mind, let's finish our
|
||||
:meth:`Axpby::eval_cpu`.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
/** Evaluate primitive on CPU using accelerate specializations */
|
||||
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
void Axpby::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Accelerate specialization for contiguous single precision float arrays
|
||||
if (out.dtype() == float32 &&
|
||||
@@ -410,35 +382,33 @@ With this in mind, lets finally implement our :meth:`Axpby::eval_cpu`.
|
||||
return;
|
||||
}
|
||||
|
||||
// Fall back to common backend if specializations are not available
|
||||
eval(inputs, out);
|
||||
// Fall back to common back-end if specializations are not available
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
We have now hit a milestone! Just this much is enough to run the operation
|
||||
:meth:`axpby` on a CPU stream!
|
||||
Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
|
||||
you do not plan on running the operation on the GPU or using transforms on
|
||||
computation graphs that contain :class:`Axpby`, you can stop implementing the
|
||||
primitive here and enjoy the speed-ups you get from the Accelerate library.
|
||||
|
||||
If you do not plan on running the operation on the GPU or using transforms on
|
||||
computation graphs that contain :class:`Axpby`, you can stop implementing the
|
||||
primitive here and enjoy the speed-ups you get from the Accelerate library.
|
||||
|
||||
Implementing the GPU Backend
|
||||
Implementing the GPU Back-end
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Apple silicon devices address their GPUs using the Metal_ shading language, and
|
||||
all GPU kernels in MLX are written using metal.
|
||||
Apple silicon devices address their GPUs using the Metal_ shading language, and
|
||||
GPU kernels in MLX are written using Metal.
|
||||
|
||||
.. note::
|
||||
|
||||
Here are some helpful resources if you are new to metal!
|
||||
Here are some helpful resources if you are new to Metal:
|
||||
|
||||
* A walkthrough of the metal compute pipeline: `Metal Example`_
|
||||
* Documentation for metal shading language: `Metal Specification`_
|
||||
* Using metal from C++: `Metal-cpp`_
|
||||
|
||||
Let's keep the GPU algorithm simple. We will launch exactly as many threads
|
||||
as there are elements in the output. Each thread will pick the element it needs
|
||||
from ``x`` and ``y``, do the pointwise operation, and then update its assigned
|
||||
element in the output.
|
||||
Let's keep the GPU kernel simple. We will launch exactly as many threads as
|
||||
there are elements in the output. Each thread will pick the element it needs
|
||||
from ``x`` and ``y``, do the point-wise operation, and update its assigned
|
||||
element in the output.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
@@ -457,15 +427,14 @@ element in the output.
|
||||
// Convert linear indices to offsets in array
|
||||
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
|
||||
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
|
||||
|
||||
|
||||
// Do the operation and update the output
|
||||
out[index] =
|
||||
out[index] =
|
||||
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
|
||||
}
|
||||
|
||||
We then need to instantiate this template for all floating point types and give
|
||||
each instantiation a unique host name so we can identify the right kernel for
|
||||
each data type.
|
||||
each instantiation a unique host name so we can identify it.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
@@ -488,29 +457,21 @@ each data type.
|
||||
instantiate_axpby(bfloat16, bfloat16_t);
|
||||
instantiate_axpby(complex64, complex64_t);
|
||||
|
||||
This kernel will be compiled into a metal library ``mlx_ext.metallib`` as we
|
||||
will see later in :ref:`Building with CMake`. In the following example, we
|
||||
assume that the library ``mlx_ext.metallib`` will always be co-located with
|
||||
the executable/ shared-library calling the :meth:`register_library` function.
|
||||
The :meth:`register_library` function takes the library's name and potential
|
||||
path (or in this case, a function that can produce the path of the metal
|
||||
library) and tries to load that library if it hasn't already been registered
|
||||
by the relevant static :class:`mlx::core::metal::Device` object. This is why,
|
||||
it is important to package your C++ library with the metal library. We will
|
||||
go over this process in more detail later.
|
||||
|
||||
The logic to determine the kernel, set the inputs, resolve the grid dimensions
|
||||
and dispatch it to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
|
||||
The logic to determine the kernel, set the inputs, resolve the grid dimensions,
|
||||
and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
|
||||
below.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
/** Evaluate primitive on GPU */
|
||||
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
void Axpby::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
// Prepare inputs
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Each primitive carries the stream it should execute on
|
||||
// and each stream carries its device identifiers
|
||||
@@ -518,10 +479,10 @@ below.
|
||||
// We get the needed metal device using the stream
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
// Allocate output memory
|
||||
// Allocate output memory
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// Resolve name of kernel (corresponds to axpby.metal)
|
||||
// Resolve name of kernel
|
||||
std::ostringstream kname;
|
||||
kname << "axpby_" << "general_" << type_to_name(out);
|
||||
|
||||
@@ -552,7 +513,7 @@ below.
|
||||
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
|
||||
compute_encoder->setBytes(&beta_, sizeof(float), 4);
|
||||
|
||||
// Encode shape, strides and ndim
|
||||
// Encode shape, strides and ndim
|
||||
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
|
||||
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
|
||||
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
|
||||
@@ -575,28 +536,25 @@ below.
|
||||
|
||||
We can now call the :meth:`axpby` operation on both the CPU and the GPU!
|
||||
|
||||
A few things to note about MLX and metal before moving on. MLX keeps track
|
||||
of the active ``compute_encoder``. We rely on :meth:`d.get_command_encoder`
|
||||
to give us the active metal compute command encoder instead of building a
|
||||
new one and calling :meth:`compute_encoder->end_encoding` at the end.
|
||||
MLX keeps adding kernels (compute pipelines) to the active command encoder
|
||||
until some specified limit is hit or the compute encoder needs to be flushed
|
||||
for synchronization. MLX also handles enqueuing and committing the associated
|
||||
command buffers as needed. We suggest taking a deeper dive into
|
||||
:class:`metal::Device` if you would like to study this routine further.
|
||||
A few things to note about MLX and Metal before moving on. MLX keeps track of
|
||||
the active ``command_buffer`` and the ``MTLCommandBuffer`` to which it is
|
||||
associated. We rely on :meth:`d.get_command_encoder` to give us the active
|
||||
metal compute command encoder instead of building a new one and calling
|
||||
:meth:`compute_encoder->end_encoding` at the end. MLX adds kernels (compute
|
||||
pipelines) to the active command buffer until some specified limit is hit or
|
||||
the command buffer needs to be flushed for synchronization.
|
||||
|
||||
Primitive Transforms
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Now that we have come this far, let's also learn how to add implementations to
|
||||
transformations in a :class:`Primitive`. These transformations can be built on
|
||||
top of our operations, including the one we just defined now. Which then gives
|
||||
us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
|
||||
Next, let's add implementations for transformations in a :class:`Primitive`.
|
||||
These transformations can be built on top of other operations, including the
|
||||
one we just defined:
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
/** The Jacobian-vector product. */
|
||||
array Axpby::jvp(
|
||||
std::vector<array> Axpby::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
@@ -611,12 +569,12 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
|
||||
if (argnums.size() > 1) {
|
||||
auto scale = argnums[0] == 0 ? alpha_ : beta_;
|
||||
auto scale_arr = array(scale, tangents[0].dtype());
|
||||
return multiply(scale_arr, tangents[0], stream());
|
||||
return {multiply(scale_arr, tangents[0], stream())};
|
||||
}
|
||||
// If, argnums = {0, 1}, we take contributions from both
|
||||
// which gives us jvp = tangent_x * alpha + tangent_y * beta
|
||||
else {
|
||||
return axpby(tangents[0], tangents[1], alpha_, beta_, stream());
|
||||
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -625,34 +583,35 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
|
||||
/** The vector-Jacobian product. */
|
||||
std::vector<array> Axpby::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const array& cotan,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<int>& /* unused */) {
|
||||
// Reverse mode diff
|
||||
std::vector<array> vjps;
|
||||
for (auto arg : argnums) {
|
||||
auto scale = arg == 0 ? alpha_ : beta_;
|
||||
auto scale_arr = array(scale, cotan.dtype());
|
||||
vjps.push_back(multiply(scale_arr, cotan, stream()));
|
||||
auto scale_arr = array(scale, cotangents[0].dtype());
|
||||
vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
|
||||
}
|
||||
return vjps;
|
||||
}
|
||||
|
||||
Finally, you need not have a transformation fully defined to start using your
|
||||
own :class:`Primitive`.
|
||||
Note, a transformation does not need to be fully defined to start using
|
||||
the :class:`Primitive`.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
/** Vectorize primitive along given axis */
|
||||
std::pair<array, int> Axpby::vmap(
|
||||
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
throw std::runtime_error("Axpby has no vmap implementation.");
|
||||
throw std::runtime_error("[Axpby] vmap not implemented.");
|
||||
}
|
||||
|
||||
Building and Binding
|
||||
--------------------
|
||||
|
||||
Let's look at the overall directory structure first.
|
||||
Let's look at the overall directory structure first.
|
||||
|
||||
| extensions
|
||||
| ├── axpby
|
||||
@@ -666,40 +625,39 @@ Let's look at the overall directory structure first.
|
||||
| └── setup.py
|
||||
|
||||
* ``extensions/axpby/`` defines the C++ extension library
|
||||
* ``extensions/mlx_sample_extensions`` sets out the structure for the
|
||||
associated python package
|
||||
* ``extensions/bindings.cpp`` provides python bindings for our operation
|
||||
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
|
||||
python bindings
|
||||
* ``extensions/mlx_sample_extensions`` sets out the structure for the
|
||||
associated Python package
|
||||
* ``extensions/bindings.cpp`` provides Python bindings for our operation
|
||||
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
|
||||
Python bindings
|
||||
* ``extensions/setup.py`` holds the ``setuptools`` rules to build and install
|
||||
the python package
|
||||
the Python package
|
||||
|
||||
Binding to Python
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
We use PyBind11_ to build a Python API for the C++ library. Since bindings for
|
||||
We use nanobind_ to build a Python API for the C++ library. Since bindings for
|
||||
components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are
|
||||
already provided, adding our :meth:`axpby` is simple!
|
||||
already provided, adding our :meth:`axpby` is simple.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
PYBIND11_MODULE(mlx_sample_extensions, m) {
|
||||
m.doc() = "Sample C++ and metal extensions for MLX";
|
||||
NB_MODULE(_ext, m) {
|
||||
m.doc() = "Sample extension for MLX";
|
||||
|
||||
m.def(
|
||||
"axpby",
|
||||
&axpby,
|
||||
"x"_a,
|
||||
"y"_a,
|
||||
py::pos_only(),
|
||||
"alpha"_a,
|
||||
"beta"_a,
|
||||
py::kw_only(),
|
||||
"stream"_a = py::none(),
|
||||
R"pbdoc(
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
R"(
|
||||
Scale and sum two vectors element-wise
|
||||
``z = alpha * x + beta * y``
|
||||
|
||||
|
||||
Follows numpy style broadcasting between ``x`` and ``y``
|
||||
Inputs are upcasted to floats if needed
|
||||
|
||||
@@ -711,17 +669,17 @@ already provided, adding our :meth:`axpby` is simple!
|
||||
|
||||
Returns:
|
||||
array: ``alpha * x + beta * y``
|
||||
)pbdoc");
|
||||
)");
|
||||
}
|
||||
|
||||
Most of the complexity in the above example comes from additional bells and
|
||||
Most of the complexity in the above example comes from additional bells and
|
||||
whistles such as the literal names and doc-strings.
|
||||
|
||||
.. warning::
|
||||
|
||||
:mod:`mlx.core` needs to be imported before importing
|
||||
:mod:`mlx_sample_extensions` as defined by the pybind11 module above to
|
||||
ensure that the casters for :mod:`mlx.core` components like
|
||||
:mod:`mlx.core` must be imported before importing
|
||||
:mod:`mlx_sample_extensions` as defined by the nanobind module above to
|
||||
ensure that the casters for :mod:`mlx.core` components like
|
||||
:class:`mlx.core.array` are available.
|
||||
|
||||
.. _Building with CMake:
|
||||
@@ -729,8 +687,8 @@ whistles such as the literal names and doc-strings.
|
||||
Building with CMake
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Building the C++ extension library itself is simple, it only requires that you
|
||||
``find_package(MLX CONFIG)`` and then link it to your library.
|
||||
Building the C++ extension library only requires that you ``find_package(MLX
|
||||
CONFIG)`` and then link it to your library.
|
||||
|
||||
.. code-block:: cmake
|
||||
|
||||
@@ -752,12 +710,12 @@ Building the C++ extension library itself is simple, it only requires that you
|
||||
# Link to mlx
|
||||
target_link_libraries(mlx_ext PUBLIC mlx)
|
||||
|
||||
We also need to build the attached metal library. For convenience, we provide a
|
||||
:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
|
||||
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
|
||||
automatically imported with MLX package).
|
||||
We also need to build the attached Metal library. For convenience, we provide a
|
||||
:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
|
||||
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
|
||||
automatically imported with MLX package).
|
||||
|
||||
Here is what that looks like in practice!
|
||||
Here is what that looks like in practice:
|
||||
|
||||
.. code-block:: cmake
|
||||
|
||||
@@ -779,27 +737,29 @@ Here is what that looks like in practice!
|
||||
|
||||
endif()
|
||||
|
||||
Finally, we build the Pybind11_ bindings
|
||||
Finally, we build the nanobind_ bindings
|
||||
|
||||
.. code-block:: cmake
|
||||
|
||||
pybind11_add_module(
|
||||
mlx_sample_extensions
|
||||
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
|
||||
nanobind_add_module(
|
||||
_ext
|
||||
NB_STATIC STABLE_ABI LTO NOMINSIZE
|
||||
NB_DOMAIN mlx
|
||||
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
|
||||
)
|
||||
target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
|
||||
target_link_libraries(_ext PRIVATE mlx_ext)
|
||||
|
||||
if(BUILD_SHARED_LIBS)
|
||||
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
|
||||
target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
|
||||
endif()
|
||||
|
||||
Building with ``setuptools``
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Once we have set out the CMake build rules as described above, we can use the
|
||||
build utilities defined in :mod:`mlx.extension` for a simple build process.
|
||||
build utilities defined in :mod:`mlx.extension`:
|
||||
|
||||
.. code-block:: python
|
||||
.. code-block:: python
|
||||
|
||||
from mlx import extension
|
||||
from setuptools import setup
|
||||
@@ -809,48 +769,50 @@ build utilities defined in :mod:`mlx.extension` for a simple build process.
|
||||
name="mlx_sample_extensions",
|
||||
version="0.0.0",
|
||||
description="Sample C++ and Metal extensions for MLX primitives.",
|
||||
ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
|
||||
ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
|
||||
cmdclass={"build_ext": extension.CMakeBuild},
|
||||
packages = ["mlx_sample_extensions"],
|
||||
package_dir = {"": "mlx_sample_extensions"},
|
||||
package_data = {"mlx_sample_extensions" : ["*.so", "*.dylib", "*.metallib"]},
|
||||
packages=["mlx_sample_extensions"],
|
||||
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
|
||||
extras_require={"dev":[]},
|
||||
zip_safe=False,
|
||||
python_requires=">=3.7",
|
||||
python_requires=">=3.8",
|
||||
)
|
||||
|
||||
.. note::
|
||||
We treat ``extensions/mlx_sample_extensions`` as the package directory
|
||||
even though it only contains a ``__init__.py`` to ensure the following:
|
||||
|
||||
* :mod:`mlx.core` is always imported before importing :mod:`mlx_sample_extensions`
|
||||
* The C++ extension library and the metal library are co-located with the python
|
||||
bindings and copied together if the package is installed
|
||||
|
||||
You can build inplace for development using
|
||||
* :mod:`mlx.core` must be imported before importing :mod:`_ext`
|
||||
* The C++ extension library and the metal library are co-located with the python
|
||||
bindings and copied together if the package is installed
|
||||
|
||||
To build the package, first install the build dependencies with ``pip install
|
||||
-r requirements.txt``. You can then build inplace for development using
|
||||
``python setup.py build_ext -j8 --inplace`` (in ``extensions/``)
|
||||
|
||||
This will result in a directory structure as follows:
|
||||
This results in the directory structure:
|
||||
|
||||
| extensions
|
||||
| ├── mlx_sample_extensions
|
||||
| │ ├── __init__.py
|
||||
| │ ├── libmlx_ext.dylib # C++ extension library
|
||||
| │ ├── mlx_ext.metallib # Metal library
|
||||
| │ └── mlx_sample_extensions.cpython-3x-darwin.so # Python Binding
|
||||
| │ └── _ext.cpython-3x-darwin.so # Python Binding
|
||||
| ...
|
||||
|
||||
When you try to install using the command ``python -m pip install .``
|
||||
(in ``extensions/``), the package will be installed with the same structure as
|
||||
``extensions/mlx_sample_extensions`` and the C++ and metal library will be
|
||||
copied along with the python binding since they are specified as ``package_data``.
|
||||
When you try to install using the command ``python -m pip install .`` (in
|
||||
``extensions/``), the package will be installed with the same structure as
|
||||
``extensions/mlx_sample_extensions`` and the C++ and Metal library will be
|
||||
copied along with the Python binding since they are specified as
|
||||
``package_data``.
|
||||
|
||||
Usage
|
||||
-----
|
||||
|
||||
After installing the extension as described above, you should be able to simply
|
||||
import the python package and play with it as you would any other MLX operation!
|
||||
After installing the extension as described above, you should be able to simply
|
||||
import the Python package and play with it as you would any other MLX operation.
|
||||
|
||||
Let's looks at a simple script and it's results!
|
||||
Let's look at a simple script and its results:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -874,12 +836,12 @@ Output:
|
||||
c correctness: True
|
||||
|
||||
Results
|
||||
^^^^^^^^^^^^^^^^
|
||||
^^^^^^^
|
||||
|
||||
Let's run a quick benchmark and see how our new ``axpby`` operation compares
|
||||
with the naive :meth:`simple_axpby` we defined at first on the CPU.
|
||||
Let's run a quick benchmark and see how our new ``axpby`` operation compares
|
||||
with the naive :meth:`simple_axpby` we first defined on the CPU.
|
||||
|
||||
.. code-block:: python
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_sample_extensions import axpby
|
||||
@@ -898,7 +860,7 @@ with the naive :meth:`simple_axpby` we defined at first on the CPU.
|
||||
alpha = 4.0
|
||||
beta = 2.0
|
||||
|
||||
mx.eval((x, y))
|
||||
mx.eval(x, y)
|
||||
|
||||
def bench(f):
|
||||
# Warm up
|
||||
@@ -919,30 +881,23 @@ with the naive :meth:`simple_axpby` we defined at first on the CPU.
|
||||
|
||||
print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")
|
||||
|
||||
Results:
|
||||
|
||||
.. code-block::
|
||||
|
||||
Simple axpby: 0.114 s | Custom axpby: 0.109 s
|
||||
|
||||
We see some modest improvements right away!
|
||||
The results are ``Simple axpby: 0.114 s | Custom axpby: 0.109 s``. We see
|
||||
modest improvements right away!
|
||||
|
||||
This operation is now good to be used to build other operations, in
|
||||
:class:`mlx.nn.Module` calls, and also as a part of graph transformations like
|
||||
:meth:`grad`!
|
||||
:meth:`grad`.
|
||||
|
||||
Scripts
|
||||
-------
|
||||
|
||||
.. admonition:: Download the code
|
||||
|
||||
The full example code is available in `mlx <code>`_.
|
||||
|
||||
.. code: `https://github.com/ml-explore/mlx/tree/main/examples/extensions/`_
|
||||
The full example code is available in `mlx <https://github.com/ml-explore/mlx/tree/main/examples/extensions/>`_.
|
||||
|
||||
.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
|
||||
.. _Metal: https://developer.apple.com/documentation/metal?language=objc
|
||||
.. _Metal-cpp: https://developer.apple.com/metal/cpp/
|
||||
.. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
||||
.. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc
|
||||
.. _PyBind11: https://pybind11.readthedocs.io/en/stable/
|
||||
.. _nanobind: https://nanobind.readthedocs.io/en/latest/
|
||||
|
69
docs/src/dev/metal_debugger.rst
Normal file
69
docs/src/dev/metal_debugger.rst
Normal file
@@ -0,0 +1,69 @@
|
||||
Metal Debugger
|
||||
==============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
Profiling is a key step for performance optimization. You can build MLX with
|
||||
the ``MLX_METAL_DEBUG`` option to improve the Metal debugging and
|
||||
optimization workflow. The ``MLX_METAL_DEBUG`` debug option:
|
||||
|
||||
* Records source during Metal compilation, for later inspection while
|
||||
debugging.
|
||||
* Labels Metal objects such as command queues, improving capture readability.
|
||||
|
||||
To build with debugging enabled in Python prepend
|
||||
``CMAKE_ARGS="-DMLX_METAL_DEBUG=ON"`` to the build call.
|
||||
|
||||
The :func:`metal.start_capture` function initiates a capture of all MLX GPU
|
||||
work.
|
||||
|
||||
.. note::
|
||||
|
||||
To capture a GPU trace you must run the application with
|
||||
``MTL_CAPTURE_ENABLED=1``.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
a = mx.random.uniform(shape=(512, 512))
|
||||
b = mx.random.uniform(shape=(512, 512))
|
||||
mx.eval(a, b)
|
||||
|
||||
trace_file = "mlx_trace.gputrace"
|
||||
|
||||
if not mx.metal.start_capture(trace_file):
|
||||
print("Make sure to run with MTL_CAPTURE_ENABLED=1 and "
|
||||
f"that the path {trace_file} does not already exist.")
|
||||
exit(1)
|
||||
|
||||
for _ in range(10):
|
||||
mx.eval(mx.add(a, b))
|
||||
|
||||
mx.metal.stop_capture()
|
||||
|
||||
You can open and replay the GPU trace in Xcode. The ``Dependencies`` view
|
||||
has a great overview of all operations. Checkout the `Metal debugger
|
||||
documentation`_ for more information.
|
||||
|
||||
.. image:: ../_static/metal_debugger/capture.png
|
||||
:class: dark-light
|
||||
|
||||
Xcode Workflow
|
||||
--------------
|
||||
|
||||
You can skip saving to a path by running within Xcode. First, generate an
|
||||
Xcode project using CMake.
|
||||
|
||||
.. code-block::
|
||||
|
||||
mkdir build && cd build
|
||||
cmake .. -DMLX_METAL_DEBUG=ON -G Xcode
|
||||
open mlx.xcodeproj
|
||||
|
||||
Select the ``metal_capture`` example schema and run.
|
||||
|
||||
.. image:: ../_static/metal_debugger/schema.png
|
||||
:class: dark-light
|
||||
|
||||
.. _`Metal debugger documentation`: https://developer.apple.com/documentation/xcode/metal-debugger
|
@@ -58,6 +58,7 @@ are the CPU and GPU.
|
||||
:maxdepth: 1
|
||||
|
||||
python/array
|
||||
python/data_types
|
||||
python/devices_and_streams
|
||||
python/ops
|
||||
python/random
|
||||
@@ -81,3 +82,4 @@ are the CPU and GPU.
|
||||
:maxdepth: 1
|
||||
|
||||
dev/extensions
|
||||
dev/metal_debugger
|
||||
|
@@ -155,6 +155,8 @@ should point to the path to the built metal library.
|
||||
- ON
|
||||
* - MLX_BUILD_PYTHON_BINDINGS
|
||||
- OFF
|
||||
* - MLX_METAL_DEBUG
|
||||
- OFF
|
||||
|
||||
|
||||
.. note::
|
||||
|
@@ -19,7 +19,6 @@ Array
|
||||
array.ndim
|
||||
array.shape
|
||||
array.size
|
||||
Dtype
|
||||
array.abs
|
||||
array.all
|
||||
array.any
|
||||
@@ -32,7 +31,6 @@ Array
|
||||
array.cumsum
|
||||
array.diag
|
||||
array.diagonal
|
||||
array.dtype
|
||||
array.exp
|
||||
array.flatten
|
||||
array.log
|
||||
|
@@ -1,7 +1,5 @@
|
||||
.. _data_types:
|
||||
|
||||
:orphan:
|
||||
|
||||
Data Types
|
||||
==========
|
||||
|
||||
@@ -56,3 +54,15 @@ The default floating point type is ``float32`` and the default integer type is
|
||||
* - ``complex64``
|
||||
- 8
|
||||
- 64-bit complex float
|
||||
|
||||
|
||||
Data type are aranged in a hierarchy. See the :obj:`DtypeCategory` object
|
||||
documentation for more information. Use :func:`issubdtype` to determine if one
|
||||
``dtype`` (or category) is a subtype of another category.
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
Dtype
|
||||
DtypeCategory
|
||||
issubdtype
|
||||
|
@@ -3,7 +3,7 @@ Metal
|
||||
|
||||
.. currentmodule:: mlx.core.metal
|
||||
|
||||
.. autosummary::
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
is_available
|
||||
@@ -12,3 +12,5 @@ Metal
|
||||
get_cache_memory
|
||||
set_memory_limit
|
||||
set_cache_limit
|
||||
start_capture
|
||||
stop_capture
|
||||
|
@@ -30,6 +30,7 @@ Module
|
||||
Module.named_modules
|
||||
Module.parameters
|
||||
Module.save_weights
|
||||
Module.set_dtype
|
||||
Module.train
|
||||
Module.trainable_parameters
|
||||
Module.unfreeze
|
||||
|
@@ -5,13 +5,13 @@ Operations
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autosummary::
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
abs
|
||||
add
|
||||
all
|
||||
allclose
|
||||
allclose
|
||||
any
|
||||
arange
|
||||
arccos
|
||||
@@ -51,6 +51,7 @@ Operations
|
||||
erf
|
||||
erfinv
|
||||
exp
|
||||
expm1
|
||||
expand_dims
|
||||
eye
|
||||
flatten
|
||||
@@ -62,10 +63,10 @@ Operations
|
||||
identity
|
||||
inner
|
||||
isclose
|
||||
isnan
|
||||
isposinf
|
||||
isneginf
|
||||
isinf
|
||||
isnan
|
||||
isneginf
|
||||
isposinf
|
||||
less
|
||||
less_equal
|
||||
linspace
|
||||
@@ -83,6 +84,7 @@ Operations
|
||||
max
|
||||
maximum
|
||||
mean
|
||||
meshgrid
|
||||
min
|
||||
minimum
|
||||
moveaxis
|
||||
@@ -117,6 +119,7 @@ Operations
|
||||
square
|
||||
squeeze
|
||||
stack
|
||||
std
|
||||
stop_gradient
|
||||
subtract
|
||||
sum
|
||||
|
@@ -38,6 +38,7 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
|
||||
gumbel
|
||||
key
|
||||
normal
|
||||
multivariate_normal
|
||||
randint
|
||||
seed
|
||||
split
|
||||
|
@@ -49,7 +49,7 @@ it will be added. You can load the array with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> mx.load("array.npy", a)
|
||||
>>> mx.load("array.npy")
|
||||
array([1], dtype=float32)
|
||||
|
||||
Here's an example of saving several arrays to a single file:
|
||||
|
@@ -8,3 +8,4 @@ endfunction(build_example)
|
||||
build_example(tutorial.cpp)
|
||||
build_example(linear_regression.cpp)
|
||||
build_example(logistic_regression.cpp)
|
||||
build_example(metal_capture.cpp)
|
||||
|
31
examples/cpp/metal_capture.cpp
Normal file
31
examples/cpp/metal_capture.cpp
Normal file
@@ -0,0 +1,31 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
int main() {
|
||||
// To use Metal debugging and profiling:
|
||||
// 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
|
||||
// 2. Run with MTL_CAPTURE_ENABLED=1.
|
||||
assert(metal::start_capture("mlx_trace.gputrace"));
|
||||
|
||||
// Start at index two because the default GPU and CPU streams have indices
|
||||
// zero and one, respectively. This naming matches the label assigned to each
|
||||
// stream's command queue.
|
||||
auto s2 = new_stream(Device::gpu);
|
||||
auto s3 = new_stream(Device::gpu);
|
||||
|
||||
auto a = arange(1.f, 10.f, 1.f, float32, s2);
|
||||
auto b = arange(1.f, 10.f, 1.f, float32, s3);
|
||||
auto x = add(a, a, s2);
|
||||
auto y = add(b, b, s3);
|
||||
|
||||
// The multiply will happen on the default stream.
|
||||
std::cout << multiply(x, y) << std::endl;
|
||||
|
||||
metal::stop_capture();
|
||||
}
|
@@ -1,6 +1,6 @@
|
||||
cmake_minimum_required(VERSION 3.27)
|
||||
|
||||
project(mlx_sample_extensions LANGUAGES CXX)
|
||||
project(_ext LANGUAGES CXX)
|
||||
|
||||
# ----------------------------- Setup -----------------------------
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
@@ -11,8 +11,12 @@ option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
|
||||
|
||||
# ----------------------------- Dependencies -----------------------------
|
||||
find_package(MLX CONFIG REQUIRED)
|
||||
find_package(Python COMPONENTS Interpreter Development)
|
||||
find_package(pybind11 CONFIG REQUIRED)
|
||||
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
|
||||
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
|
||||
find_package(nanobind CONFIG REQUIRED)
|
||||
|
||||
# ----------------------------- Extensions -----------------------------
|
||||
|
||||
@@ -38,7 +42,6 @@ target_link_libraries(mlx_ext PUBLIC mlx)
|
||||
|
||||
# Build metallib
|
||||
if(MLX_BUILD_METAL)
|
||||
|
||||
mlx_build_metallib(
|
||||
TARGET mlx_ext_metallib
|
||||
TITLE mlx_ext
|
||||
@@ -54,13 +57,15 @@ if(MLX_BUILD_METAL)
|
||||
|
||||
endif()
|
||||
|
||||
# ----------------------------- Pybind -----------------------------
|
||||
pybind11_add_module(
|
||||
mlx_sample_extensions
|
||||
# ----------------------------- Python Bindings -----------------------------
|
||||
nanobind_add_module(
|
||||
_ext
|
||||
NB_STATIC STABLE_ABI LTO NOMINSIZE
|
||||
NB_DOMAIN mlx
|
||||
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
|
||||
)
|
||||
target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
|
||||
target_link_libraries(_ext PRIVATE mlx_ext)
|
||||
|
||||
if(BUILD_SHARED_LIBS)
|
||||
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
|
||||
target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
|
||||
endif()
|
||||
|
18
examples/extensions/README.md
Normal file
18
examples/extensions/README.md
Normal file
@@ -0,0 +1,18 @@
|
||||
|
||||
## Build the extensions
|
||||
|
||||
```
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
For faster builds during development, you can also pre-install the requirements:
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
And then run:
|
||||
|
||||
```
|
||||
python setup.py build_ext -j8 --inplace
|
||||
```
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
@@ -43,7 +43,7 @@ array axpby(
|
||||
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
|
||||
|
||||
// Upcast to float32 for non-floating point inputs x and y
|
||||
auto out_dtype = is_floating_point(promoted_dtype)
|
||||
auto out_dtype = issubdtype(promoted_dtype, float32)
|
||||
? promoted_dtype
|
||||
: promote_types(promoted_dtype, float32);
|
||||
|
||||
@@ -61,7 +61,7 @@ array axpby(
|
||||
/* const std::vector<int>& shape = */ out_shape,
|
||||
/* Dtype dtype = */ out_dtype,
|
||||
/* std::unique_ptr<Primitive> primitive = */
|
||||
std::make_unique<Axpby>(to_stream(s), alpha, beta),
|
||||
std::make_shared<Axpby>(to_stream(s), alpha, beta),
|
||||
/* const std::vector<array>& inputs = */ broadcasted_inputs);
|
||||
}
|
||||
|
||||
@@ -106,12 +106,12 @@ void axpby_impl(
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void Axpby::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& out_arr) {
|
||||
auto out = out_arr[0];
|
||||
std::vector<array>& outputs) {
|
||||
// Check the inputs (registered in the op while constructing the out array)
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Dispatch to the correct dtype
|
||||
if (out.dtype() == float32) {
|
||||
@@ -150,11 +150,7 @@ void axpby_impl_accelerate(
|
||||
// The data in the output array is allocated to match the strides in y
|
||||
// such that x, y, and out are contiguous in the same mode and
|
||||
// no transposition is needed
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(y.data_size() * out.itemsize()),
|
||||
y.data_size(),
|
||||
y.strides(),
|
||||
y.flags());
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// We then copy over the elements using the contiguous vector specialization
|
||||
copy_inplace(y, out, CopyType::Vector);
|
||||
@@ -180,11 +176,11 @@ void axpby_impl_accelerate(
|
||||
/** Evaluate primitive on CPU using accelerate specializations */
|
||||
void Axpby::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outarr) {
|
||||
auto out = outarr[0];
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Accelerate specialization for contiguous single precision float arrays
|
||||
if (out.dtype() == float32 &&
|
||||
@@ -195,7 +191,7 @@ void Axpby::eval_cpu(
|
||||
}
|
||||
|
||||
// Fall back to common backend if specializations are not available
|
||||
eval(inputs, outarr);
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
#else // Accelerate not available
|
||||
@@ -203,8 +199,8 @@ void Axpby::eval_cpu(
|
||||
/** Evaluate primitive on CPU falling back to common backend */
|
||||
void Axpby::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& out) {
|
||||
eval(inputs, out);
|
||||
const std::vector<array>& outputs) {
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -218,12 +214,12 @@ void Axpby::eval_cpu(
|
||||
/** Evaluate primitive on GPU */
|
||||
void Axpby::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outarr) {
|
||||
std::vector<array>& outputs) {
|
||||
// Prepare inputs
|
||||
auto out = outarr[0];
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Each primitive carries the stream it should execute on
|
||||
// and each stream carries its device identifiers
|
||||
@@ -372,4 +368,4 @@ bool Axpby::is_equivalent(const Primitive& other) const {
|
||||
return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace mlx::core
|
||||
|
@@ -42,9 +42,9 @@ class Axpby : public Primitive {
|
||||
* To avoid unnecessary allocations, the evaluation function
|
||||
* is responsible for allocating space for the array.
|
||||
*/
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& out)
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& out)
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
/** The Jacobian-vector product. */
|
||||
@@ -83,7 +83,7 @@ class Axpby : public Primitive {
|
||||
float beta_;
|
||||
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& out);
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
||||
};
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,31 +1,31 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
|
||||
#include "axpby/axpby.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
PYBIND11_MODULE(mlx_sample_extensions, m) {
|
||||
m.doc() = "Sample C++ and metal extensions for MLX";
|
||||
NB_MODULE(_ext, m) {
|
||||
m.doc() = "Sample extension for MLX";
|
||||
|
||||
m.def(
|
||||
"axpby",
|
||||
&axpby,
|
||||
"x"_a,
|
||||
"y"_a,
|
||||
py::pos_only(),
|
||||
"alpha"_a,
|
||||
"beta"_a,
|
||||
py::kw_only(),
|
||||
"stream"_a = py::none(),
|
||||
R"pbdoc(
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
R"(
|
||||
Scale and sum two vectors element-wise
|
||||
``z = alpha * x + beta * y``
|
||||
|
||||
|
||||
Follows numpy style broadcasting between ``x`` and ``y``
|
||||
Inputs are upcasted to floats if needed
|
||||
|
||||
@@ -37,5 +37,5 @@ PYBIND11_MODULE(mlx_sample_extensions, m) {
|
||||
|
||||
Returns:
|
||||
array: ``alpha * x + beta * y``
|
||||
)pbdoc");
|
||||
}
|
||||
)");
|
||||
}
|
||||
|
@@ -1,3 +1,8 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=42", "pybind11>=2.10", "cmake>=3.24", "mlx @ git+https://github.com/mlx-explore/mlx@main"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
requires = [
|
||||
"setuptools>=42",
|
||||
"cmake>=3.24",
|
||||
"mlx>=0.9.0",
|
||||
"nanobind@git+https://github.com/wjakob/nanobind.git#egg=4148debcf91f5ccab0c3b8d67b5c3cabd61f407f",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
4
examples/extensions/requirements.txt
Normal file
4
examples/extensions/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
setuptools>=42
|
||||
cmake>=3.24
|
||||
mlx>=0.9.0
|
||||
nanobind@git+https://github.com/wjakob/nanobind.git#egg=4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
|
@@ -1,4 +1,4 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from setuptools import setup
|
||||
|
||||
@@ -9,11 +9,11 @@ if __name__ == "__main__":
|
||||
name="mlx_sample_extensions",
|
||||
version="0.0.0",
|
||||
description="Sample C++ and Metal extensions for MLX primitives.",
|
||||
ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
|
||||
ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
|
||||
cmdclass={"build_ext": extension.CMakeBuild},
|
||||
packages=["mlx_sample_extensions"],
|
||||
package_dir={"": "."},
|
||||
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
|
||||
extras_require={"dev": []},
|
||||
zip_safe=False,
|
||||
python_requires=">=3.8",
|
||||
)
|
||||
|
@@ -93,7 +93,9 @@ void array::detach() {
|
||||
}
|
||||
|
||||
void array::eval() {
|
||||
mlx::core::eval({*this});
|
||||
if (!is_evaled()) {
|
||||
mlx::core::eval({*this});
|
||||
}
|
||||
}
|
||||
|
||||
bool array::is_tracer() const {
|
||||
@@ -190,6 +192,36 @@ array::ArrayDesc::ArrayDesc(
|
||||
init();
|
||||
}
|
||||
|
||||
array::ArrayDesc::~ArrayDesc() {
|
||||
// When an array description is destroyed it will delete a bunch of arrays
|
||||
// that may also destory their corresponding descriptions and so on and so
|
||||
// forth.
|
||||
//
|
||||
// This calls recursively the destructor and can result in stack overflow, we
|
||||
// instead put them in a vector and destroy them one at a time resulting in a
|
||||
// max stack depth of 2.
|
||||
std::vector<std::shared_ptr<ArrayDesc>> for_deletion;
|
||||
|
||||
for (array& a : inputs) {
|
||||
if (a.array_desc_.use_count() == 1) {
|
||||
for_deletion.push_back(std::move(a.array_desc_));
|
||||
}
|
||||
}
|
||||
|
||||
while (!for_deletion.empty()) {
|
||||
// top is going to be deleted at the end of the block *after* the arrays
|
||||
// with inputs have been moved into the vector
|
||||
auto top = std::move(for_deletion.back());
|
||||
for_deletion.pop_back();
|
||||
|
||||
for (array& a : top->inputs) {
|
||||
if (a.array_desc_.use_count() == 1) {
|
||||
for_deletion.push_back(std::move(a.array_desc_));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
array::ArrayIterator::ArrayIterator(const array& arr, int idx)
|
||||
: arr(arr), idx(idx) {
|
||||
if (arr.ndim() == 0) {
|
||||
|
24
mlx/array.h
24
mlx/array.h
@@ -256,6 +256,17 @@ class array {
|
||||
array_desc_->position = position;
|
||||
}
|
||||
|
||||
/** The i-th output of the array's primitive. */
|
||||
const array& output(int i) const {
|
||||
if (i == array_desc_->position) {
|
||||
return *this;
|
||||
} else if (i < array_desc_->position) {
|
||||
return siblings()[i];
|
||||
} else {
|
||||
return siblings()[i + 1];
|
||||
}
|
||||
};
|
||||
|
||||
/** The outputs of the array's primitive (i.e. this array and
|
||||
* its siblings) in the order the primitive expects. */
|
||||
std::vector<array> outputs() const {
|
||||
@@ -393,6 +404,8 @@ class array {
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
std::vector<array> inputs);
|
||||
|
||||
~ArrayDesc();
|
||||
|
||||
private:
|
||||
// Initialize size, strides, and other metadata
|
||||
void init();
|
||||
@@ -510,4 +523,15 @@ void array::init(It src) {
|
||||
}
|
||||
}
|
||||
|
||||
/* Utilities for determining whether a template parameter is array. */
|
||||
template <typename T>
|
||||
inline constexpr bool is_array_v =
|
||||
std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>, array>;
|
||||
|
||||
template <typename... T>
|
||||
inline constexpr bool is_arrays_v = (is_array_v<T> && ...);
|
||||
|
||||
template <typename... T>
|
||||
using enable_for_arrays_t = typename std::enable_if_t<is_arrays_v<T...>>;
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -301,7 +301,7 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else if (is_floating_point(out.dtype())) {
|
||||
} else if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, [](auto x) { return std::exp(x); });
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -310,6 +310,19 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
vvexpm1f(
|
||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
@@ -355,7 +368,7 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto size = in.data_size();
|
||||
vvlog1pf(
|
||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else if (is_floating_point(out.dtype())) {
|
||||
} else if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, [](auto x) { return std::log1p(x); });
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <limits>
|
||||
@@ -201,7 +201,7 @@ struct NeonFp16SimdOps {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename VT, typename Ops, int N>
|
||||
template <typename T, typename AccT, typename VT, typename Ops, int N>
|
||||
void softmax(const array& in, array& out) {
|
||||
Ops ops;
|
||||
|
||||
@@ -218,13 +218,21 @@ void softmax(const array& in, array& out) {
|
||||
VT vmaximum = ops.init(-std::numeric_limits<float>::infinity());
|
||||
size_t s = M;
|
||||
while (s >= N) {
|
||||
vmaximum = ops.max(ops.load(current_in_ptr), vmaximum);
|
||||
VT vals;
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
vals = ops.load(current_in_ptr);
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
vals[i] = static_cast<AccT>(current_in_ptr[i]);
|
||||
}
|
||||
}
|
||||
vmaximum = ops.max(vals, vmaximum);
|
||||
current_in_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
T maximum = ops.reduce_max(vmaximum);
|
||||
AccT maximum = ops.reduce_max(vmaximum);
|
||||
while (s-- > 0) {
|
||||
maximum = std::max(maximum, *current_in_ptr);
|
||||
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
|
||||
current_in_ptr++;
|
||||
}
|
||||
|
||||
@@ -234,18 +242,29 @@ void softmax(const array& in, array& out) {
|
||||
current_in_ptr = in_ptr;
|
||||
s = M;
|
||||
while (s >= N) {
|
||||
VT vexp = ops.exp(ops.sub(*(VT*)current_in_ptr, maximum));
|
||||
ops.store(current_out_ptr, vexp);
|
||||
*(VT*)current_out_ptr = vexp;
|
||||
VT vexp;
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
vexp = ops.load(current_in_ptr);
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
|
||||
}
|
||||
}
|
||||
vexp = ops.exp(ops.sub(vexp, maximum));
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
ops.store(current_out_ptr, vexp);
|
||||
}
|
||||
vnormalizer = ops.add(vnormalizer, vexp);
|
||||
current_in_ptr += N;
|
||||
current_out_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
T normalizer = ops.reduce_add(vnormalizer);
|
||||
AccT normalizer = ops.reduce_add(vnormalizer);
|
||||
while (s-- > 0) {
|
||||
T _exp = std::exp(*current_in_ptr - maximum);
|
||||
*current_out_ptr = _exp;
|
||||
AccT _exp = std::exp(*current_in_ptr - maximum);
|
||||
if (std::is_same<T, AccT>::value) {
|
||||
*current_out_ptr = _exp;
|
||||
}
|
||||
normalizer += _exp;
|
||||
current_in_ptr++;
|
||||
current_out_ptr++;
|
||||
@@ -254,14 +273,33 @@ void softmax(const array& in, array& out) {
|
||||
|
||||
// Normalize
|
||||
current_out_ptr = out_ptr;
|
||||
current_in_ptr = in_ptr;
|
||||
s = M;
|
||||
while (s >= N) {
|
||||
ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
|
||||
} else {
|
||||
VT vexp;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
|
||||
}
|
||||
vexp = ops.mul(ops.exp(ops.sub(vexp, maximum)), normalizer);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
current_out_ptr[i] = vexp[i];
|
||||
}
|
||||
current_in_ptr += N;
|
||||
}
|
||||
current_out_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
while (s-- > 0) {
|
||||
*current_out_ptr *= normalizer;
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
*current_out_ptr *= normalizer;
|
||||
} else {
|
||||
AccT _exp = std::exp(*current_in_ptr - maximum);
|
||||
*current_out_ptr = static_cast<T>(_exp * normalizer);
|
||||
current_in_ptr++;
|
||||
}
|
||||
current_out_ptr++;
|
||||
}
|
||||
}
|
||||
@@ -308,15 +346,29 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
"Softmax is defined only for floating point types");
|
||||
break;
|
||||
case float32:
|
||||
softmax<float, simd_float16, AccelerateSimdOps<float, simd_float16>, 16>(
|
||||
in, out);
|
||||
softmax<
|
||||
float,
|
||||
float,
|
||||
simd_float16,
|
||||
AccelerateSimdOps<float, simd_float16>,
|
||||
16>(in, out);
|
||||
break;
|
||||
case float16:
|
||||
softmax<
|
||||
float16_t,
|
||||
float16x8_t,
|
||||
NeonFp16SimdOps<float16_t, float16x8_t>,
|
||||
8>(in, out);
|
||||
if (precise_) {
|
||||
softmax<
|
||||
float16_t,
|
||||
float,
|
||||
simd_float16,
|
||||
AccelerateSimdOps<float, simd_float16>,
|
||||
16>(in, out);
|
||||
} else {
|
||||
softmax<
|
||||
float16_t,
|
||||
float16_t,
|
||||
float16x8_t,
|
||||
NeonFp16SimdOps<float16_t, float16x8_t>,
|
||||
8>(in, out);
|
||||
}
|
||||
break;
|
||||
case bfloat16:
|
||||
eval(inputs, out);
|
||||
|
@@ -179,18 +179,16 @@ void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (out.dtype() == float32) {
|
||||
binary_op<float>(a, b, out, detail::LogAddExp());
|
||||
} else if (out.dtype() == float16) {
|
||||
binary_op<float16_t>(a, b, out, detail::LogAddExp());
|
||||
} else if (out.dtype() == bfloat16) {
|
||||
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
|
||||
} else {
|
||||
std::ostringstream err;
|
||||
err << "[logaddexp] Does not support " << out.dtype();
|
||||
throw std::invalid_argument(err.str());
|
||||
}
|
||||
if (out.dtype() == float32) {
|
||||
binary_op<float>(a, b, out, detail::LogAddExp());
|
||||
} else if (out.dtype() == float16) {
|
||||
binary_op<float16_t>(a, b, out, detail::LogAddExp());
|
||||
} else if (out.dtype() == bfloat16) {
|
||||
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
|
||||
} else if (issubdtype(out.dtype(), inexact)) {
|
||||
std::ostringstream err;
|
||||
err << "[logaddexp] Does not support " << out.dtype();
|
||||
throw std::invalid_argument(err.str());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[logaddexp] Cannot compute logaddexp for arrays with"
|
||||
|
@@ -126,4 +126,102 @@ std::string build_lib_name(
|
||||
return os.str();
|
||||
}
|
||||
|
||||
bool compiled_check_contiguity(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& shape) {
|
||||
bool contiguous = true;
|
||||
bool all_contig = true;
|
||||
bool all_row_contig = true;
|
||||
bool all_col_contig = true;
|
||||
int non_scalar_inputs = 0;
|
||||
for (const auto& x : inputs) {
|
||||
if (is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
non_scalar_inputs++;
|
||||
bool shape_eq = x.shape() == shape;
|
||||
all_contig &= (x.flags().contiguous && shape_eq);
|
||||
all_row_contig &= (x.flags().row_contiguous && shape_eq);
|
||||
all_col_contig &= (x.flags().col_contiguous && shape_eq);
|
||||
}
|
||||
if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) {
|
||||
contiguous = false;
|
||||
} else if (non_scalar_inputs == 1 && !all_contig) {
|
||||
contiguous = false;
|
||||
} else if (non_scalar_inputs == 0 && !shape.empty()) {
|
||||
contiguous = false;
|
||||
}
|
||||
return contiguous;
|
||||
}
|
||||
|
||||
void compiled_allocate_outputs(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::vector<array>& inputs_,
|
||||
const std::unordered_set<uintptr_t>& constant_ids_,
|
||||
bool contiguous,
|
||||
bool move_buffers /* = false */) {
|
||||
if (contiguous) {
|
||||
int o = 0;
|
||||
std::vector<size_t> strides;
|
||||
size_t data_size;
|
||||
array::Flags flags;
|
||||
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
||||
auto& in = inputs[i];
|
||||
// Conditions for donation
|
||||
// - Correct size
|
||||
// - Not a scalar
|
||||
// - Donatable
|
||||
// - Not a constant
|
||||
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
|
||||
in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
if (move_buffers) {
|
||||
outputs[o++].move_shared_buffer(in);
|
||||
} else {
|
||||
outputs[o++].copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
// Get representative input flags to properly set non-donated outputs
|
||||
if (strides.empty() && in.size() == outputs[0].size()) {
|
||||
strides = in.strides();
|
||||
flags = in.flags();
|
||||
data_size = in.data_size();
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(
|
||||
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
|
||||
data_size,
|
||||
strides,
|
||||
flags);
|
||||
}
|
||||
} else {
|
||||
int o = 0;
|
||||
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
||||
auto& in = inputs[i];
|
||||
// Conditions for donation
|
||||
// - Row contiguous
|
||||
// - Donatable
|
||||
// - Correct size
|
||||
// - Not a constant
|
||||
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
|
||||
in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
if (move_buffers) {
|
||||
outputs[o].move_shared_buffer(
|
||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||
} else {
|
||||
outputs[o].copy_shared_buffer(
|
||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||
}
|
||||
o++;
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -53,4 +53,18 @@ inline bool is_scalar(const array& x) {
|
||||
return x.ndim() == 0;
|
||||
}
|
||||
|
||||
// Check if we can use a contiguous operation given inputs and the output shape
|
||||
bool compiled_check_contiguity(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& shape);
|
||||
|
||||
// Allocate space for the outputs possibly with input donation
|
||||
void compiled_allocate_outputs(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::vector<array>& inputs_,
|
||||
const std::unordered_set<uintptr_t>& constant_ids_,
|
||||
bool contiguous,
|
||||
bool move_buffers = false);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -52,8 +52,25 @@ void* compile(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::string kernel_file_name;
|
||||
|
||||
// Deal with long kernel names. Maximum length for files on macOS is 255
|
||||
// characters. Clip file name with a little extra room and append a 16
|
||||
// character hash.
|
||||
constexpr int max_file_name_length = 245;
|
||||
if (kernel_name.size() > max_file_name_length) {
|
||||
std::ostringstream file_name;
|
||||
file_name
|
||||
<< std::string_view(kernel_name).substr(0, max_file_name_length - 16);
|
||||
auto file_id = std::hash<std::string>{}(kernel_name);
|
||||
file_name << "_" << std::hex << std::setw(16) << file_id << std::dec;
|
||||
kernel_file_name = file_name.str();
|
||||
} else {
|
||||
kernel_file_name = kernel_name;
|
||||
}
|
||||
|
||||
std::ostringstream shared_lib_name;
|
||||
shared_lib_name << "lib" << kernel_name << ".so";
|
||||
shared_lib_name << "lib" << kernel_file_name << ".so";
|
||||
auto shared_lib_path = get_temp_file(shared_lib_name.str());
|
||||
bool lib_exists = false;
|
||||
{
|
||||
@@ -64,7 +81,7 @@ void* compile(
|
||||
if (!lib_exists) {
|
||||
// Open source file and write source code to it
|
||||
std::ostringstream source_file_name;
|
||||
source_file_name << kernel_name << ".cpp";
|
||||
source_file_name << kernel_file_name << ".cpp";
|
||||
auto source_file_path = get_temp_file(source_file_name.str());
|
||||
|
||||
std::ofstream source_file(source_file_path);
|
||||
@@ -248,28 +265,7 @@ void Compiled::eval_cpu(
|
||||
|
||||
// Figure out which kernel we are using
|
||||
auto& shape = outputs[0].shape();
|
||||
bool contiguous = true;
|
||||
{
|
||||
bool all_contig = true;
|
||||
bool all_row_contig = true;
|
||||
bool all_col_contig = true;
|
||||
int non_scalar_inputs = 0;
|
||||
for (auto& x : inputs) {
|
||||
if (is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
non_scalar_inputs++;
|
||||
bool shape_eq = x.shape() == shape;
|
||||
all_contig &= (x.flags().contiguous && shape_eq);
|
||||
all_row_contig &= (x.flags().row_contiguous && shape_eq);
|
||||
all_col_contig &= (x.flags().col_contiguous && shape_eq);
|
||||
}
|
||||
if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) {
|
||||
contiguous = false;
|
||||
} else if (non_scalar_inputs == 1 && !all_contig) {
|
||||
contiguous = false;
|
||||
}
|
||||
}
|
||||
bool contiguous = compiled_check_contiguity(inputs, shape);
|
||||
|
||||
// Handle all broadcasting and collect function input arguments
|
||||
std::vector<void*> args;
|
||||
@@ -342,58 +338,8 @@ void Compiled::eval_cpu(
|
||||
fn_ptr = compile(kernel_name, kernel.str());
|
||||
}
|
||||
|
||||
// Allocate space for the outputs possibly with input donation
|
||||
if (contiguous) {
|
||||
int o = 0;
|
||||
std::vector<size_t> strides;
|
||||
size_t data_size;
|
||||
array::Flags flags;
|
||||
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
||||
auto& in = inputs[i];
|
||||
// Conditions for donation
|
||||
// - Contiguous
|
||||
// - Donatable
|
||||
// - Correct size
|
||||
// - Not a constant
|
||||
if (in.flags().contiguous && !is_scalar(in) && in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
outputs[o++].copy_shared_buffer(in);
|
||||
}
|
||||
// Get representative input flags to properly set non-donated outputs
|
||||
if (strides.empty() && in.size() == outputs[0].size()) {
|
||||
strides = in.strides();
|
||||
flags = in.flags();
|
||||
data_size = in.data_size();
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(
|
||||
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
|
||||
data_size,
|
||||
strides,
|
||||
flags);
|
||||
}
|
||||
} else {
|
||||
int o = 0;
|
||||
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
||||
auto& in = inputs[i];
|
||||
// Conditions for donation
|
||||
// - Row contiguous
|
||||
// - Donatable
|
||||
// - Correct size
|
||||
// - Not a constant
|
||||
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
|
||||
in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
outputs[o].copy_shared_buffer(
|
||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||
o++;
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
|
||||
}
|
||||
}
|
||||
compiled_allocate_outputs(
|
||||
inputs, outputs, inputs_, constant_ids_, contiguous, false);
|
||||
|
||||
for (auto& x : outputs) {
|
||||
args.push_back(x.data<void>());
|
||||
|
@@ -272,7 +272,7 @@ inline void copy_general_general(const array& src, array& dst) {
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename... Args>
|
||||
void copy(const array& src, array& dst, CopyType ctype, Args... args) {
|
||||
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
|
||||
switch (ctype) {
|
||||
case CopyType::Scalar:
|
||||
copy_single<SrcT, DstT>(src, dst);
|
||||
@@ -281,54 +281,54 @@ void copy(const array& src, array& dst, CopyType ctype, Args... args) {
|
||||
copy_vector<SrcT, DstT>(src, dst);
|
||||
return;
|
||||
case CopyType::General:
|
||||
copy_general<SrcT, DstT>(src, dst, args...);
|
||||
copy_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
|
||||
return;
|
||||
case CopyType::GeneralGeneral:
|
||||
copy_general_general<SrcT, DstT>(src, dst, args...);
|
||||
copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename... Args>
|
||||
void copy(const array& src, array& dst, CopyType ctype, Args... args) {
|
||||
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
|
||||
switch (dst.dtype()) {
|
||||
case bool_:
|
||||
copy<SrcT, bool>(src, dst, ctype, args...);
|
||||
copy<SrcT, bool>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint8:
|
||||
copy<SrcT, uint8_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint16:
|
||||
copy<SrcT, uint16_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint32:
|
||||
copy<SrcT, uint32_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint64:
|
||||
copy<SrcT, uint64_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int8:
|
||||
copy<SrcT, int8_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, int8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int16:
|
||||
copy<SrcT, int16_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, int16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int32:
|
||||
copy<SrcT, int32_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, int32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int64:
|
||||
copy<SrcT, int64_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, int64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float16:
|
||||
copy<SrcT, float16_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, float16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float32:
|
||||
copy<SrcT, float>(src, dst, ctype, args...);
|
||||
copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case bfloat16:
|
||||
copy<SrcT, bfloat16_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case complex64:
|
||||
copy<SrcT, complex64_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -338,46 +338,46 @@ inline void copy_inplace_dispatch(
|
||||
const array& src,
|
||||
array& dst,
|
||||
CopyType ctype,
|
||||
Args... args) {
|
||||
Args&&... args) {
|
||||
switch (src.dtype()) {
|
||||
case bool_:
|
||||
copy<bool>(src, dst, ctype, args...);
|
||||
copy<bool>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint8:
|
||||
copy<uint8_t>(src, dst, ctype, args...);
|
||||
copy<uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint16:
|
||||
copy<uint16_t>(src, dst, ctype, args...);
|
||||
copy<uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint32:
|
||||
copy<uint32_t>(src, dst, ctype, args...);
|
||||
copy<uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint64:
|
||||
copy<uint64_t>(src, dst, ctype, args...);
|
||||
copy<uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int8:
|
||||
copy<int8_t>(src, dst, ctype, args...);
|
||||
copy<int8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int16:
|
||||
copy<int16_t>(src, dst, ctype, args...);
|
||||
copy<int16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int32:
|
||||
copy<int32_t>(src, dst, ctype, args...);
|
||||
copy<int32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int64:
|
||||
copy<int64_t>(src, dst, ctype, args...);
|
||||
copy<int64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float16:
|
||||
copy<float16_t>(src, dst, ctype, args...);
|
||||
copy<float16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float32:
|
||||
copy<float>(src, dst, ctype, args...);
|
||||
copy<float>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case bfloat16:
|
||||
copy<bfloat16_t>(src, dst, ctype, args...);
|
||||
copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case complex64:
|
||||
copy<complex64_t>(src, dst, ctype, args...);
|
||||
copy<complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@@ -57,6 +57,7 @@ DEFAULT(Equal)
|
||||
DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
DEFAULT(Exp)
|
||||
DEFAULT(Expm1)
|
||||
DEFAULT(FFT)
|
||||
DEFAULT(Floor)
|
||||
DEFAULT(Full)
|
||||
|
@@ -241,6 +241,13 @@ struct Exp {
|
||||
}
|
||||
};
|
||||
|
||||
struct Expm1 {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return expm1(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct Floor {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
|
@@ -22,7 +22,7 @@ namespace mlx::core {
|
||||
void Abs::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (is_unsigned(in.dtype())) {
|
||||
if (issubdtype(in.dtype(), unsignedinteger)) {
|
||||
// No-op for unsigned types
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
@@ -37,7 +37,7 @@ void Arange::eval(const std::vector<array>& inputs, array& out) {
|
||||
void ArcCos::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcCos());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -49,7 +49,7 @@ void ArcCos::eval(const std::vector<array>& inputs, array& out) {
|
||||
void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcCosh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -61,7 +61,7 @@ void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
|
||||
void ArcSin::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcSin());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -73,7 +73,7 @@ void ArcSin::eval(const std::vector<array>& inputs, array& out) {
|
||||
void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcSinh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -85,7 +85,7 @@ void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
|
||||
void ArcTan::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcTan());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -97,7 +97,7 @@ void ArcTan::eval(const std::vector<array>& inputs, array& out) {
|
||||
void ArcTanh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcTanh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -171,7 +171,7 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Ceil::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (not is_integral(in.dtype())) {
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Ceil());
|
||||
} else {
|
||||
// No-op integer types
|
||||
@@ -211,7 +211,7 @@ void Copy::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Cos::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Cos());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -223,7 +223,7 @@ void Cos::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Cosh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Cosh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -350,7 +350,7 @@ void ErfInv::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Exp::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Exp());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -359,10 +359,22 @@ void Exp::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Expm1::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Expm1());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[expm1] Cannot exponentiate elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Floor::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (not is_integral(in.dtype())) {
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Floor());
|
||||
} else {
|
||||
// No-op integer types
|
||||
@@ -388,7 +400,7 @@ void Full::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Log::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
switch (base_) {
|
||||
case Base::e:
|
||||
unary_fp(in, out, detail::Log());
|
||||
@@ -410,7 +422,7 @@ void Log::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Log1p::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Log1p());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -597,7 +609,7 @@ void Reshape::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Round::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (not is_integral(in.dtype())) {
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Round());
|
||||
} else {
|
||||
// No-op integer types
|
||||
@@ -608,7 +620,7 @@ void Round::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Sigmoid());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -630,7 +642,7 @@ void Sign::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Sin::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Sin());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -642,7 +654,7 @@ void Sin::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Sinh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Sinh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -850,7 +862,7 @@ void StopGradient::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Tan::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Tan());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -862,7 +874,7 @@ void Tan::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Tanh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Tanh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
|
@@ -6,8 +6,6 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
enum ReductionOpType {
|
||||
// Self-explanatory. Read everything and produce 1 output.
|
||||
ContiguousAllReduce,
|
||||
@@ -38,6 +36,21 @@ enum ReductionOpType {
|
||||
GeneralReduce
|
||||
};
|
||||
|
||||
struct ReductionPlan {
|
||||
ReductionOpType type;
|
||||
std::vector<int> shape;
|
||||
std::vector<size_t> strides;
|
||||
|
||||
ReductionPlan(
|
||||
ReductionOpType type_,
|
||||
std::vector<int> shape_,
|
||||
std::vector<size_t> strides_)
|
||||
: type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
|
||||
ReductionPlan(ReductionOpType type_) : type(type_) {}
|
||||
};
|
||||
|
||||
namespace {
|
||||
|
||||
// Helper for the ndimensional strided loop
|
||||
// Should this be in utils?
|
||||
inline void nd_loop(
|
||||
@@ -110,19 +123,6 @@ struct DefaultContiguousReduce {
|
||||
}
|
||||
};
|
||||
|
||||
struct ReductionPlan {
|
||||
ReductionOpType type;
|
||||
std::vector<int> shape;
|
||||
std::vector<size_t> strides;
|
||||
|
||||
ReductionPlan(
|
||||
ReductionOpType type_,
|
||||
std::vector<int> shape_,
|
||||
std::vector<size_t> strides_)
|
||||
: type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
|
||||
ReductionPlan(ReductionOpType type_) : type(type_) {}
|
||||
};
|
||||
|
||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
|
||||
// The data is all there and we are reducing over everything
|
||||
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
|
||||
|
@@ -222,7 +222,7 @@ void scan_dispatch(
|
||||
}
|
||||
case Scan::Min: {
|
||||
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *x : *y; };
|
||||
auto init = (is_floating_point(input.dtype()))
|
||||
auto init = (issubdtype(input.dtype(), floating))
|
||||
? static_cast<U>(std::numeric_limits<float>::infinity())
|
||||
: std::numeric_limits<U>::max();
|
||||
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
|
||||
@@ -232,7 +232,7 @@ void scan_dispatch(
|
||||
}
|
||||
case Scan::Max: {
|
||||
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
|
||||
auto init = (is_floating_point(input.dtype()))
|
||||
auto init = (issubdtype(input.dtype(), floating))
|
||||
? static_cast<U>(-std::numeric_limits<float>::infinity())
|
||||
: std::numeric_limits<U>::max();
|
||||
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
@@ -10,7 +10,7 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
template <typename T, typename AccT>
|
||||
void softmax(const array& in, array& out) {
|
||||
const T* in_ptr = in.data<T>();
|
||||
T* out_ptr = out.data<T>();
|
||||
@@ -22,26 +22,36 @@ void softmax(const array& in, array& out) {
|
||||
for (int i = 0; i < M; i++, in_ptr += N, out_ptr += N) {
|
||||
// Find the maximum
|
||||
current_in_ptr = in_ptr;
|
||||
T maximum = *current_in_ptr;
|
||||
AccT maximum = *current_in_ptr;
|
||||
for (int j = 0; j < N; j++, current_in_ptr++) {
|
||||
maximum = (maximum < *current_in_ptr) ? *current_in_ptr : maximum;
|
||||
maximum = (maximum < *current_in_ptr) ? static_cast<AccT>(*current_in_ptr)
|
||||
: maximum;
|
||||
}
|
||||
|
||||
// Compute the normalizer and the exponentials
|
||||
T normalizer = 0;
|
||||
AccT normalizer = 0;
|
||||
current_out_ptr = out_ptr;
|
||||
current_in_ptr = in_ptr;
|
||||
for (int j = 0; j < N; j++, current_out_ptr++, current_in_ptr++) {
|
||||
T expv = std::exp(*current_in_ptr - maximum);
|
||||
AccT expv = std::exp(*current_in_ptr - maximum);
|
||||
normalizer += expv;
|
||||
*current_out_ptr = expv;
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
*current_out_ptr = expv;
|
||||
}
|
||||
}
|
||||
normalizer = 1 / normalizer;
|
||||
|
||||
// Normalize
|
||||
current_in_ptr = in_ptr;
|
||||
current_out_ptr = out_ptr;
|
||||
for (int j = 0; j < N; j++, current_out_ptr++) {
|
||||
*current_out_ptr *= normalizer;
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
*current_out_ptr *= normalizer;
|
||||
} else {
|
||||
auto v = std::exp(*current_in_ptr - maximum);
|
||||
*current_out_ptr = static_cast<T>(v * normalizer);
|
||||
current_in_ptr++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -91,13 +101,21 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
|
||||
"Softmax is defined only for floating point types");
|
||||
break;
|
||||
case float32:
|
||||
softmax<float>(in, out);
|
||||
softmax<float, float>(in, out);
|
||||
break;
|
||||
case float16:
|
||||
softmax<float16_t>(in, out);
|
||||
if (precise_) {
|
||||
softmax<float16_t, float>(in, out);
|
||||
} else {
|
||||
softmax<float16_t, float16_t>(in, out);
|
||||
}
|
||||
break;
|
||||
case bfloat16:
|
||||
softmax<bfloat16_t>(in, out);
|
||||
if (precise_) {
|
||||
softmax<bfloat16_t, float>(in, out);
|
||||
} else {
|
||||
softmax<bfloat16_t, bfloat16_t>(in, out);
|
||||
}
|
||||
break;
|
||||
case complex64:
|
||||
throw std::invalid_argument(
|
||||
|
@@ -89,9 +89,8 @@ collapse_contiguous_dims(const std::vector<array>& xs) {
|
||||
return collapse_contiguous_dims(xs[0].shape(), strides);
|
||||
}
|
||||
|
||||
template <typename... Arrays>
|
||||
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
||||
collapse_contiguous_dims(Arrays... xs) {
|
||||
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
||||
inline auto collapse_contiguous_dims(Arrays&&... xs) {
|
||||
return collapse_contiguous_dims(
|
||||
std::vector<array>{std::forward<Arrays>(xs)...});
|
||||
}
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include "mlx/backend/metal/allocator.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
|
||||
#include <mach/vm_page_size.h>
|
||||
#include <unistd.h>
|
||||
|
@@ -229,14 +229,7 @@ void Compiled::eval_gpu(
|
||||
|
||||
// Figure out which kernel we are using
|
||||
auto& output_shape = outputs[0].shape();
|
||||
bool contiguous = true;
|
||||
for (auto& x : inputs) {
|
||||
if ((!x.flags().row_contiguous || x.shape() != output_shape) &&
|
||||
!is_scalar(x)) {
|
||||
contiguous = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
bool contiguous = compiled_check_contiguity(inputs, output_shape);
|
||||
|
||||
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
||||
// handle all broadcasting.
|
||||
@@ -296,7 +289,7 @@ void Compiled::eval_gpu(
|
||||
}
|
||||
}
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Put the inputs in
|
||||
@@ -307,7 +300,7 @@ void Compiled::eval_gpu(
|
||||
continue;
|
||||
}
|
||||
auto& x = inputs[i];
|
||||
set_array_buffer(compute_encoder, x, cnt++);
|
||||
compute_encoder.set_input_array(x, cnt++);
|
||||
if (!contiguous && !is_scalar(x)) {
|
||||
compute_encoder->setBytes(
|
||||
strides[stride_idx].data(),
|
||||
@@ -317,32 +310,12 @@ void Compiled::eval_gpu(
|
||||
}
|
||||
}
|
||||
|
||||
// Allocate space for the outputs possibly with input donation
|
||||
{
|
||||
int o = 0;
|
||||
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
||||
auto& in = inputs[i];
|
||||
// Conditions for donation
|
||||
// - Row contiguous
|
||||
// - Donatable
|
||||
// - Correct size
|
||||
// - Not a constant
|
||||
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
|
||||
in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
outputs[o].move_shared_buffer(
|
||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||
o++;
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
|
||||
}
|
||||
}
|
||||
compiled_allocate_outputs(
|
||||
inputs, outputs, inputs_, constant_ids_, contiguous, true);
|
||||
|
||||
// Put the outputs in
|
||||
for (auto& x : outputs) {
|
||||
set_array_buffer(compute_encoder, x, cnt++);
|
||||
compute_encoder.set_output_array(x, cnt++);
|
||||
}
|
||||
|
||||
// Put the output shape and strides in
|
||||
|
@@ -41,12 +41,12 @@ void explicit_gemm_conv_ND_gpu(
|
||||
// Prepare unfolding kernel
|
||||
std::ostringstream kname;
|
||||
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, in_unfolded, 1);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(in_unfolded, 1);
|
||||
|
||||
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
|
||||
|
||||
@@ -140,7 +140,7 @@ void slow_conv_2D_gpu(
|
||||
<< "_tm" << tm << "_tn" << tn;
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
@@ -153,9 +153,9 @@ void slow_conv_2D_gpu(
|
||||
MTL::Size group_dims = MTL::Size(bm, bn, 1);
|
||||
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
|
||||
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, wt, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_input_array(wt, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
@@ -241,7 +241,7 @@ void implicit_gemm_conv_2D_gpu(
|
||||
<< "_filter_" << (small_filter ? 's' : 'l');
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
@@ -254,9 +254,9 @@ void implicit_gemm_conv_2D_gpu(
|
||||
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);
|
||||
|
||||
// Encode arrays
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, wt, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_input_array(wt, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Encode params
|
||||
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
|
||||
@@ -394,7 +394,7 @@ void implicit_gemm_conv_2D_general_gpu(
|
||||
<< "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
@@ -408,9 +408,9 @@ void implicit_gemm_conv_2D_general_gpu(
|
||||
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
|
||||
|
||||
// Encode arrays
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, wt, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_input_array(wt, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Encode params
|
||||
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
|
||||
@@ -511,12 +511,12 @@ void winograd_conv_2D_gpu(
|
||||
std::ostringstream kname;
|
||||
kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
|
||||
<< bc;
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
set_array_buffer(compute_encoder, wt, 0);
|
||||
set_array_buffer(compute_encoder, filt_wg, 1);
|
||||
compute_encoder.set_input_array(wt, 0);
|
||||
compute_encoder.set_output_array(filt_wg, 1);
|
||||
|
||||
compute_encoder->setBytes(&C_c, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&O_c, sizeof(int), 3);
|
||||
@@ -539,12 +539,12 @@ void winograd_conv_2D_gpu(
|
||||
std::ostringstream kname;
|
||||
kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc"
|
||||
<< bc;
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
set_array_buffer(compute_encoder, in_padded, 0);
|
||||
set_array_buffer(compute_encoder, inp_wg, 1);
|
||||
compute_encoder.set_input_array(in_padded, 0);
|
||||
compute_encoder.set_output_array(inp_wg, 1);
|
||||
|
||||
compute_encoder->setBytes(
|
||||
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
|
||||
@@ -587,12 +587,12 @@ void winograd_conv_2D_gpu(
|
||||
std::ostringstream kname;
|
||||
kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo"
|
||||
<< bc;
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
set_array_buffer(compute_encoder, out_wg, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder.set_input_array(out_wg, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
compute_encoder->setBytes(
|
||||
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
|
||||
|
@@ -83,15 +83,15 @@ void copy_gpu_inplace(
|
||||
kname << "_" << shape.size();
|
||||
}
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
bool donate_in = in.data_shared_ptr() == nullptr;
|
||||
|
||||
inp_offset *= size_of(in.dtype());
|
||||
out_offset *= size_of(out.dtype());
|
||||
|
||||
set_array_buffer(compute_encoder, donate_in ? out : in, inp_offset, 0);
|
||||
set_array_buffer(compute_encoder, out, out_offset, 1);
|
||||
compute_encoder.set_input_array(donate_in ? out : in, 0, inp_offset);
|
||||
compute_encoder.set_output_array(out, 1, out_offset);
|
||||
|
||||
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
||||
int ndim = shape.size();
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023-24 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <cstdlib>
|
||||
@@ -11,7 +11,9 @@
|
||||
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/backend/metal/mps/gemm.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
@@ -145,6 +147,7 @@ void Device::new_queue(int index) {
|
||||
// We lock this as a critical section for safety
|
||||
const std::lock_guard<std::mutex> lock(mtx_);
|
||||
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
|
||||
debug_set_stream_queue_label(q, index);
|
||||
if (!q) {
|
||||
throw std::runtime_error(
|
||||
"[metal::Device] Failed to make new command queue.");
|
||||
@@ -203,14 +206,15 @@ void Device::end_encoding(int index) {
|
||||
}
|
||||
}
|
||||
|
||||
MTL::ComputeCommandEncoder* Device::get_command_encoder(int index) {
|
||||
CommandEncoder& Device::get_command_encoder(int index) {
|
||||
auto eit = encoder_map_.find(index);
|
||||
if (eit == encoder_map_.end()) {
|
||||
auto cb = get_command_buffer(index);
|
||||
auto compute_encoder = cb->computeCommandEncoder();
|
||||
auto compute_encoder =
|
||||
cb->computeCommandEncoder(MTL::DispatchTypeConcurrent);
|
||||
// Increment ref count so the buffer is not garbage collected
|
||||
compute_encoder->retain();
|
||||
eit = encoder_map_.insert({index, compute_encoder}).first;
|
||||
eit = encoder_map_.emplace(index, CommandEncoder{compute_encoder}).first;
|
||||
}
|
||||
return eit->second;
|
||||
}
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023-24 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -7,10 +7,12 @@
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <filesystem>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/device.h"
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
@@ -34,6 +36,69 @@ inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
|
||||
using MTLFCList =
|
||||
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
|
||||
|
||||
struct CommandEncoder {
|
||||
CommandEncoder(MTL::ComputeCommandEncoder* enc)
|
||||
: enc(enc), concurrent(false){};
|
||||
CommandEncoder& operator=(const CommandEncoder&) = delete;
|
||||
|
||||
struct ConcurrentContext {
|
||||
ConcurrentContext(CommandEncoder& enc) : enc(enc) {
|
||||
enc.concurrent = true;
|
||||
}
|
||||
~ConcurrentContext() {
|
||||
enc.concurrent = false;
|
||||
enc.outputs.insert(
|
||||
enc.concurrent_outputs.begin(), enc.concurrent_outputs.end());
|
||||
enc.concurrent_outputs.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
CommandEncoder& enc;
|
||||
};
|
||||
|
||||
MTL::ComputeCommandEncoder* operator->() {
|
||||
return enc;
|
||||
}
|
||||
|
||||
void set_input_array(const array& a, int idx, int offset = 0) {
|
||||
auto r_buf =
|
||||
static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
|
||||
if (auto it = outputs.find(r_buf); it != outputs.end()) {
|
||||
// Insert a barrier
|
||||
enc->memoryBarrier(&r_buf, 1);
|
||||
|
||||
// Remove the output
|
||||
outputs.erase(it);
|
||||
}
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto base_offset = a.data<char>() -
|
||||
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||
base_offset += offset;
|
||||
enc->setBuffer(a_buf, base_offset, idx);
|
||||
}
|
||||
|
||||
void set_output_array(array& a, int idx, int offset = 0) {
|
||||
// Add barriers before adding the output to the output set
|
||||
set_input_array(a, idx, offset);
|
||||
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
|
||||
if (concurrent) {
|
||||
concurrent_outputs.insert(buf);
|
||||
} else {
|
||||
outputs.insert(buf);
|
||||
}
|
||||
}
|
||||
|
||||
ConcurrentContext start_concurrent() {
|
||||
return ConcurrentContext(*this);
|
||||
}
|
||||
|
||||
private:
|
||||
MTL::ComputeCommandEncoder* enc;
|
||||
bool concurrent;
|
||||
std::unordered_set<MTL::Resource*> outputs;
|
||||
std::unordered_set<MTL::Resource*> concurrent_outputs;
|
||||
};
|
||||
|
||||
class Device {
|
||||
public:
|
||||
Device();
|
||||
@@ -51,7 +116,7 @@ class Device {
|
||||
int get_command_buffer_ops(int index);
|
||||
void increment_command_buffer_ops(int index);
|
||||
void commit_command_buffer(int index);
|
||||
MTL::ComputeCommandEncoder* get_command_encoder(int index);
|
||||
CommandEncoder& get_command_encoder(int index);
|
||||
void end_encoding(int index);
|
||||
|
||||
void register_library(
|
||||
@@ -132,7 +197,7 @@ class Device {
|
||||
MTL::Device* device_;
|
||||
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
|
||||
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
|
||||
std::unordered_map<int32_t, MTL::ComputeCommandEncoder*> encoder_map_;
|
||||
std::unordered_map<int32_t, CommandEncoder> encoder_map_;
|
||||
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
|
||||
std::unordered_map<std::string, MTL::Library*> library_map_;
|
||||
std::mutex mtx_;
|
||||
|
@@ -49,7 +49,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
kname << "_" << idx_ndim;
|
||||
}
|
||||
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
@@ -81,8 +81,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
// Set all the buffers
|
||||
set_array_buffer(compute_encoder, src, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder.set_input_array(src, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
// Set source info
|
||||
compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2);
|
||||
@@ -103,7 +103,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 1; i < nidx + 1; ++i) {
|
||||
set_array_buffer(compute_encoder, inputs[i], 20 + i);
|
||||
compute_encoder.set_input_array(inputs[i], 20 + i);
|
||||
}
|
||||
|
||||
// Launch grid
|
||||
@@ -183,7 +183,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
kname << "_" << nidx;
|
||||
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
|
||||
auto& upd = inputs.back();
|
||||
@@ -192,8 +192,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Set all the buffers
|
||||
set_array_buffer(compute_encoder, upd, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
compute_encoder.set_input_array(upd, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Set update info
|
||||
uint upd_ndim = upd.ndim();
|
||||
@@ -210,7 +210,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 1; i < nidx + 1; ++i) {
|
||||
set_array_buffer(compute_encoder, inputs[i], 20 + i);
|
||||
compute_encoder.set_input_array(inputs[i], 20 + i);
|
||||
}
|
||||
|
||||
// Launch grid
|
||||
@@ -280,7 +280,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 1; i < nidx + 1; ++i) {
|
||||
set_array_buffer(compute_encoder, inputs[i], 20 + i);
|
||||
compute_encoder.set_input_array(inputs[i], 20 + i);
|
||||
}
|
||||
|
||||
// Launch grid
|
||||
|
@@ -7,6 +7,7 @@ set(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/complex.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/defines.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/erf.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/expm1f.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.h
|
||||
@@ -37,11 +38,17 @@ set(
|
||||
)
|
||||
|
||||
function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
|
||||
if(MLX_METAL_DEBUG)
|
||||
set(METAL_FLAGS ${METAL_FLAGS}
|
||||
-gline-tables-only
|
||||
-frecord-sources)
|
||||
endif()
|
||||
add_custom_command(
|
||||
COMMAND xcrun -sdk macosx metal -Wall -Wextra
|
||||
-fno-fast-math
|
||||
-c ${SRCFILE}
|
||||
-I${PROJECT_SOURCE_DIR}
|
||||
COMMAND xcrun -sdk macosx metal
|
||||
${METAL_FLAGS}
|
||||
-c ${SRCFILE}
|
||||
-I${PROJECT_SOURCE_DIR}
|
||||
-o ${TARGET}.air
|
||||
DEPENDS ${SRCFILE} ${DEPS}
|
||||
OUTPUT ${TARGET}.air
|
||||
|
89
mlx/backend/metal/kernels/expm1f.h
Normal file
89
mlx/backend/metal/kernels/expm1f.h
Normal file
@@ -0,0 +1,89 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <metal_math>
|
||||
|
||||
// Original license copied below:
|
||||
// Copyright (c) 2015-2023 Norbert Juffa
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions
|
||||
// are met:
|
||||
//
|
||||
// 1. Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
//
|
||||
// 2. Redistributions in binary form must reproduce the above copyright
|
||||
// notice, this list of conditions and the following disclaimer in the
|
||||
// documentation and/or other materials provided with the distribution.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
/* Compute exponential base e minus 1. Maximum ulp error = 0.997458
|
||||
|
||||
i = rint(a/log(2)), f = a-i*log(2). Then expm1(a) = 2**i * (expm1(f)+1) - 1.
|
||||
Compute r = expm1(f). Then expm1(a)= 2 * (0.5 * 2**i * r + 0.5 * 2**i - 0.5).
|
||||
With t = 0.5*2**i, expm1(a) = 2*(r * t + t-0.5). However, for best accuracy,
|
||||
when i == 1, expm1(a)= 2*(r + 0.5), and when i == 0, expm1(a) = r.
|
||||
|
||||
NOTE: Scale factor b is only applied if i < 0 or i > 1 (should be power of 2)
|
||||
*/
|
||||
float expm1f_scaled_unchecked(float a, float b) {
|
||||
float f, j, r, s, t, u, v, x, y;
|
||||
int i;
|
||||
|
||||
// exp(a) = 2**i * exp(f); i = rintf (a / log(2))
|
||||
j = fma(1.442695f, a, 12582912.f); // 0x1.715476p0, 0x1.8p23
|
||||
j = j - 12582912.0f; // 0x1.8p23
|
||||
i = (int)j;
|
||||
f = fma(j, -6.93145752e-1f, a);
|
||||
|
||||
// approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2]
|
||||
s = f * f;
|
||||
if (a == 0.0f)
|
||||
s = a; // ensure -0 is passed through
|
||||
// err = 0.997458 ulp1 = 11081805
|
||||
r = 1.97350979e-4f; // 0x1.9de000p-13
|
||||
r = fma(r, f, 1.39309070e-3f); // 0x1.6d30bcp-10
|
||||
r = fma(r, f, 8.33343994e-3f); // 0x1.1111f6p-7
|
||||
r = fma(r, f, 4.16668020e-2f); // 0x1.55559ep-5
|
||||
r = fma(r, f, 1.66666716e-1f); // 0x1.55555cp-3
|
||||
r = fma(r, f, 4.99999970e-1f); // 0x1.fffffep-2
|
||||
u = (j == 1) ? (f + 0.5f) : f;
|
||||
v = fma(r, s, u);
|
||||
s = 0.5f * b;
|
||||
t = ldexp(s, i);
|
||||
y = t - s;
|
||||
x = (t - y) - s; // double-float canonicalization of difference
|
||||
r = fma(v, t, x) + y;
|
||||
r = r + r;
|
||||
if (j == 0)
|
||||
r = v;
|
||||
if (j == 1)
|
||||
r = v + v;
|
||||
return r;
|
||||
}
|
||||
|
||||
/* Compute exponential base e minus 1. max ulp err = 0.99746 */
|
||||
float expm1f(float a) {
|
||||
float r;
|
||||
|
||||
r = expm1f_scaled_unchecked(a, 1.0f);
|
||||
/* handle severe overflow and underflow */
|
||||
if (abs(a - 1.0f) > 88.0f) {
|
||||
r = fma(r, r, -1.0f);
|
||||
}
|
||||
return r;
|
||||
}
|
@@ -205,39 +205,341 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N_READS = RMS_N_READS>
|
||||
[[kernel]] void vjp_layer_norm_single_row(
|
||||
const device T* x,
|
||||
const device T* w,
|
||||
const device T* g,
|
||||
device T* gx,
|
||||
device T* gw,
|
||||
constant float& eps,
|
||||
constant uint& axis_size,
|
||||
constant uint& w_stride,
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
// Advance the input pointers
|
||||
x += gid * axis_size + lid * N_READS;
|
||||
g += gid * axis_size + lid * N_READS;
|
||||
w += w_stride * lid * N_READS;
|
||||
|
||||
// Allocate registers for the computation and accumulators
|
||||
float thread_x[N_READS];
|
||||
float thread_w[N_READS];
|
||||
float thread_g[N_READS];
|
||||
float sumx = 0;
|
||||
float sumx2 = 0;
|
||||
float sumwg = 0;
|
||||
float sumwgx = 0;
|
||||
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
|
||||
threadgroup float local_sumx[SIMD_SIZE];
|
||||
threadgroup float local_sumx2[SIMD_SIZE];
|
||||
threadgroup float local_sumwg[SIMD_SIZE];
|
||||
threadgroup float local_sumwgx[SIMD_SIZE];
|
||||
threadgroup float local_mean[1];
|
||||
threadgroup float local_normalizer[1];
|
||||
threadgroup float local_meanwg[1];
|
||||
threadgroup float local_meanwgx[1];
|
||||
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
thread_x[i] = x[i];
|
||||
thread_w[i] = w[i * w_stride];
|
||||
thread_g[i] = g[i];
|
||||
float wg = thread_w[i] * thread_g[i];
|
||||
sumx += thread_x[i];
|
||||
sumx2 += thread_x[i] * thread_x[i];
|
||||
sumwg += wg;
|
||||
sumwgx += wg * thread_x[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((lid * N_READS + i) < axis_size) {
|
||||
thread_x[i] = x[i];
|
||||
thread_w[i] = w[i * w_stride];
|
||||
thread_g[i] = g[i];
|
||||
float wg = thread_w[i] * thread_g[i];
|
||||
sumx += thread_x[i];
|
||||
sumx2 += thread_x[i] * thread_x[i];
|
||||
sumwg += wg;
|
||||
sumwgx += wg * thread_x[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sumx = simd_sum(sumx);
|
||||
sumx2 = simd_sum(sumx2);
|
||||
sumwg = simd_sum(sumwg);
|
||||
sumwgx = simd_sum(sumwgx);
|
||||
|
||||
// Initialize shared memory
|
||||
if (simd_group_id == 0) {
|
||||
local_sumx[simd_lane_id] = 0;
|
||||
local_sumx2[simd_lane_id] = 0;
|
||||
local_sumwg[simd_lane_id] = 0;
|
||||
local_sumwgx[simd_lane_id] = 0;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Write simd accumulations into shared memory
|
||||
if (simd_lane_id == 0) {
|
||||
local_sumx[simd_group_id] = sumx;
|
||||
local_sumx2[simd_group_id] = sumx2;
|
||||
local_sumwg[simd_group_id] = sumwg;
|
||||
local_sumwgx[simd_group_id] = sumwgx;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Accumulate over simd groups
|
||||
if (simd_group_id == 0) {
|
||||
sumx = simd_sum(local_sumx[simd_lane_id]);
|
||||
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
|
||||
sumwg = simd_sum(local_sumwg[simd_lane_id]);
|
||||
sumwgx = simd_sum(local_sumwgx[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
float mean = sumx / axis_size;
|
||||
float variance = sumx2 / axis_size - mean * mean;
|
||||
|
||||
local_mean[0] = mean;
|
||||
local_normalizer[0] = metal::precise::rsqrt(variance + eps);
|
||||
local_meanwg[0] = sumwg / axis_size;
|
||||
local_meanwgx[0] = sumwgx / axis_size;
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float mean = local_mean[0];
|
||||
float normalizer = local_normalizer[0];
|
||||
float meanwg = local_meanwg[0];
|
||||
float meanwgxc = local_meanwgx[0] - meanwg * mean;
|
||||
float normalizer2 = normalizer * normalizer;
|
||||
|
||||
// Write the outputs
|
||||
gx += gid * axis_size + lid * N_READS;
|
||||
gw += gid * axis_size + lid * N_READS;
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
thread_x[i] = (thread_x[i] - mean) * normalizer;
|
||||
gx[i] = static_cast<T>(normalizer * (thread_w[i] * thread_g[i] - meanwg) -
|
||||
thread_x[i] * meanwgxc * normalizer2);
|
||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((lid * N_READS + i) < axis_size) {
|
||||
thread_x[i] = (thread_x[i] - mean) * normalizer;
|
||||
gx[i] = static_cast<T>(normalizer * (thread_w[i] * thread_g[i] - meanwg) -
|
||||
thread_x[i] * meanwgxc * normalizer2);
|
||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N_READS = RMS_N_READS>
|
||||
[[kernel]] void vjp_layer_norm_looped(
|
||||
const device T* x,
|
||||
const device T* w,
|
||||
const device T* g,
|
||||
device T* gx,
|
||||
device T* gw,
|
||||
constant float& eps,
|
||||
constant uint& axis_size,
|
||||
constant uint& w_stride,
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
// Advance the input pointers
|
||||
x += gid * axis_size + lid * N_READS;
|
||||
g += gid * axis_size + lid * N_READS;
|
||||
w += w_stride * lid * N_READS;
|
||||
|
||||
// Allocate registers for the accumulators
|
||||
float sumx = 0;
|
||||
float sumx2 = 0;
|
||||
float sumwg = 0;
|
||||
float sumwgx = 0;
|
||||
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
|
||||
threadgroup float local_sumx[SIMD_SIZE];
|
||||
threadgroup float local_sumx2[SIMD_SIZE];
|
||||
threadgroup float local_sumwg[SIMD_SIZE];
|
||||
threadgroup float local_sumwgx[SIMD_SIZE];
|
||||
threadgroup float local_mean[1];
|
||||
threadgroup float local_normalizer[1];
|
||||
threadgroup float local_meanwg[1];
|
||||
threadgroup float local_meanwgx[1];
|
||||
|
||||
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
||||
if (r + lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = x[i + r];
|
||||
float wi = w[(i + r) * w_stride];
|
||||
float gi = g[i + r];
|
||||
float wg = wi * gi;
|
||||
sumx += xi;
|
||||
sumx2 += xi * xi;
|
||||
sumwg += wg;
|
||||
sumwgx += wg * xi;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((r + lid * N_READS + i) < axis_size) {
|
||||
float xi = x[i + r];
|
||||
float wi = w[(i + r) * w_stride];
|
||||
float gi = g[i + r];
|
||||
float wg = wi * gi;
|
||||
sumx += xi;
|
||||
sumx2 += xi * xi;
|
||||
sumwg += wg;
|
||||
sumwgx += wg * xi;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sumx = simd_sum(sumx);
|
||||
sumx2 = simd_sum(sumx2);
|
||||
sumwg = simd_sum(sumwg);
|
||||
sumwgx = simd_sum(sumwgx);
|
||||
|
||||
// Initialize shared memory
|
||||
if (simd_group_id == 0) {
|
||||
local_sumx[simd_lane_id] = 0;
|
||||
local_sumx2[simd_lane_id] = 0;
|
||||
local_sumwg[simd_lane_id] = 0;
|
||||
local_sumwgx[simd_lane_id] = 0;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Write simd accumulations into shared memory
|
||||
if (simd_lane_id == 0) {
|
||||
local_sumx[simd_group_id] = sumx;
|
||||
local_sumx2[simd_group_id] = sumx2;
|
||||
local_sumwg[simd_group_id] = sumwg;
|
||||
local_sumwgx[simd_group_id] = sumwgx;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Accumulate over simd groups
|
||||
if (simd_group_id == 0) {
|
||||
sumx = simd_sum(local_sumx[simd_lane_id]);
|
||||
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
|
||||
sumwg = simd_sum(local_sumwg[simd_lane_id]);
|
||||
sumwgx = simd_sum(local_sumwgx[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
float mean = sumx / axis_size;
|
||||
float variance = sumx2 / axis_size - mean * mean;
|
||||
|
||||
local_mean[0] = mean;
|
||||
local_normalizer[0] = metal::precise::rsqrt(variance + eps);
|
||||
local_meanwg[0] = sumwg / axis_size;
|
||||
local_meanwgx[0] = sumwgx / axis_size;
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float mean = local_mean[0];
|
||||
float normalizer = local_normalizer[0];
|
||||
float meanwg = local_meanwg[0];
|
||||
float meanwgxc = local_meanwgx[0] - meanwg * mean;
|
||||
float normalizer2 = normalizer * normalizer;
|
||||
|
||||
// Write the outputs
|
||||
gx += gid * axis_size + lid * N_READS;
|
||||
gw += gid * axis_size + lid * N_READS;
|
||||
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
||||
if (r + lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = (x[i + r] - mean) * normalizer;
|
||||
float wi = w[(i + r) * w_stride];
|
||||
float gi = g[i + r];
|
||||
gx[i + r] = static_cast<T>(normalizer * (wi * gi - meanwg) -
|
||||
xi * meanwgxc * normalizer2);
|
||||
gw[i + r] = static_cast<T>(gi * xi);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((r + lid * N_READS + i) < axis_size) {
|
||||
float xi = (x[i + r] - mean) * normalizer;
|
||||
float wi = w[(i + r) * w_stride];
|
||||
float gi = g[i + r];
|
||||
gx[i + r] = static_cast<T>(normalizer * (wi * gi - meanwg) -
|
||||
xi * meanwgxc * normalizer2);
|
||||
gw[i + r] = static_cast<T>(gi * xi);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_layer_norm_single_row(name, itype) \
|
||||
template [[host_name("layer_norm" #name)]] [[kernel]] void \
|
||||
layer_norm_single_row<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* b, \
|
||||
device itype* out, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
constant uint& b_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
#define instantiate_layer_norm_single_row(name, itype) \
|
||||
template [[host_name("layer_norm" #name)]] [[kernel]] void \
|
||||
layer_norm_single_row<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* b, \
|
||||
device itype* out, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
constant uint& b_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
template [[host_name("vjp_layer_norm" #name)]] [[kernel]] void \
|
||||
vjp_layer_norm_single_row<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* g, \
|
||||
device itype* gx, \
|
||||
device itype* gw, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_layer_norm_looped(name, itype) \
|
||||
template [[host_name("layer_norm_looped" #name)]] [[kernel]] void \
|
||||
layer_norm_looped<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* b, \
|
||||
device itype* out, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
constant uint& b_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
#define instantiate_layer_norm_looped(name, itype) \
|
||||
template [[host_name("layer_norm_looped" #name)]] [[kernel]] void \
|
||||
layer_norm_looped<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* b, \
|
||||
device itype* out, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
constant uint& b_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
template [[host_name("vjp_layer_norm_looped" #name)]] [[kernel]] void \
|
||||
vjp_layer_norm_looped<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* g, \
|
||||
device itype* gx, \
|
||||
device itype* gb, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_layer_norm(name, itype) \
|
||||
|
@@ -15,14 +15,6 @@ using namespace metal;
|
||||
|
||||
MLX_MTL_CONST int SIMD_SIZE = 32;
|
||||
|
||||
template <typename T> struct AccT {
|
||||
typedef T acc_t;
|
||||
};
|
||||
|
||||
template <> struct AccT<bfloat16_t> {
|
||||
typedef float acc_t;
|
||||
};
|
||||
|
||||
|
||||
template <typename T, typename U, int values_per_thread, int bits>
|
||||
inline U load_vector(const device T *x, thread U *x_thread) {
|
||||
@@ -60,6 +52,51 @@ inline U load_vector(const device T *x, thread U *x_thread) {
|
||||
return sum;
|
||||
}
|
||||
|
||||
template <typename T, typename U, int values_per_thread, int bits>
|
||||
inline U load_vector_safe(const device T *x, thread U *x_thread, int N) {
|
||||
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
|
||||
|
||||
U sum = 0;
|
||||
|
||||
if (bits == 2) {
|
||||
for (int i = 0; i < N; i += 4) {
|
||||
sum += x[i] + x[i+1] + x[i+2] + x[i+3];
|
||||
x_thread[i] = x[i];
|
||||
x_thread[i+1] = x[i+1] / 4.0f;
|
||||
x_thread[i+2] = x[i+2] / 16.0f;
|
||||
x_thread[i+3] = x[i+3] / 64.0f;
|
||||
}
|
||||
for (int i=N; i<values_per_thread; i++) {
|
||||
x_thread[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
else if (bits == 4) {
|
||||
for (int i = 0; i < N; i += 4) {
|
||||
sum += x[i] + x[i+1] + x[i+2] + x[i+3];
|
||||
x_thread[i] = x[i];
|
||||
x_thread[i+1] = x[i+1] / 16.0f;
|
||||
x_thread[i+2] = x[i+2] / 256.0f;
|
||||
x_thread[i+3] = x[i+3] / 4096.0f;
|
||||
}
|
||||
for (int i=N; i<values_per_thread; i++) {
|
||||
x_thread[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
else if (bits == 8) {
|
||||
for (int i = 0; i < N; i++) {
|
||||
sum += x[i];
|
||||
x_thread[i] = x[i];
|
||||
}
|
||||
for (int i=N; i<values_per_thread; i++) {
|
||||
x_thread[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
template <typename U, int values_per_thread, int bits>
|
||||
inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias, U sum) {
|
||||
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
|
||||
@@ -96,6 +133,74 @@ inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias
|
||||
return scale * accum + sum * bias;
|
||||
}
|
||||
|
||||
template <typename U, int values_per_thread, int bits>
|
||||
inline U qdot_safe(const device uint8_t* w, const thread U *x_thread, U scale, U bias, U sum, int N) {
|
||||
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
|
||||
|
||||
U accum = 0;
|
||||
|
||||
if (bits == 2) {
|
||||
for (int i = 0; i < (N / 4); i++) {
|
||||
accum += (
|
||||
x_thread[4*i] * (w[i] & 0x03)
|
||||
+ x_thread[4*i+1] * (w[i] & 0x0c)
|
||||
+ x_thread[4*i+2] * (w[i] & 0x30)
|
||||
+ x_thread[4*i+3] * (w[i] & 0xc0));
|
||||
}
|
||||
}
|
||||
|
||||
else if (bits == 4) {
|
||||
const device uint16_t* ws = (const device uint16_t*)w;
|
||||
for (int i = 0; i < (N / 4); i++) {
|
||||
accum += (
|
||||
x_thread[4*i] * (ws[i] & 0x000f)
|
||||
+ x_thread[4*i+1] * (ws[i] & 0x00f0)
|
||||
+ x_thread[4*i+2] * (ws[i] & 0x0f00)
|
||||
+ x_thread[4*i+3] * (ws[i] & 0xf000));
|
||||
}
|
||||
}
|
||||
|
||||
else if (bits == 8) {
|
||||
for (int i = 0; i < N; i++) {
|
||||
accum += x_thread[i] * w[i];
|
||||
}
|
||||
}
|
||||
|
||||
return scale * accum + sum * bias;
|
||||
}
|
||||
|
||||
template <typename U, int values_per_thread, int bits>
|
||||
inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
|
||||
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
|
||||
|
||||
if (bits == 2) {
|
||||
U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
|
||||
for (int i = 0; i < (values_per_thread / 4); i++) {
|
||||
result[4*i] += x * (s[0] * (w[i] & 0x03) + bias);
|
||||
result[4*i+1] += x * (s[1] * (w[i] & 0x0c) + bias);
|
||||
result[4*i+2] += x * (s[2] * (w[i] & 0x30) + bias);
|
||||
result[4*i+3] += x * (s[3] * (w[i] & 0xc0) + bias);
|
||||
}
|
||||
}
|
||||
|
||||
else if (bits == 4) {
|
||||
const thread uint16_t* ws = (const thread uint16_t*)w;
|
||||
U s[4] = {scale, scale / 16.0f, scale / 256.0f, scale / 4096.0f};
|
||||
for (int i = 0; i < (values_per_thread / 4); i++) {
|
||||
result[4*i] += x * (s[0] * (ws[i] & 0x000f) + bias);
|
||||
result[4*i+1] += x * (s[1] * (ws[i] & 0x00f0) + bias);
|
||||
result[4*i+2] += x * (s[2] * (ws[i] & 0x0f00) + bias);
|
||||
result[4*i+3] += x * (s[3] * (ws[i] & 0xf000) + bias);
|
||||
}
|
||||
}
|
||||
|
||||
else if (bits == 8) {
|
||||
for (int i = 0; i < values_per_thread; i++) {
|
||||
result[i] += x * (scale * w[i] + bias);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits, int packs_per_thread>
|
||||
[[kernel]] void qmv_fast(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
@@ -204,7 +309,8 @@ template <typename T, const int group_size, const int bits>
|
||||
x += tid.z * in_vec_size + simd_lid * values_per_thread;
|
||||
y += tid.z * out_vec_size + out_row;
|
||||
|
||||
for (int k = 0; k < in_vec_size; k += block_size) {
|
||||
int k = 0;
|
||||
for (; k < in_vec_size-block_size; k += block_size) {
|
||||
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
||||
|
||||
for (int row = 0; out_row + row < out_vec_size; row++) {
|
||||
@@ -222,6 +328,18 @@ template <typename T, const int group_size, const int bits>
|
||||
biases += block_size / group_size;
|
||||
x += block_size;
|
||||
}
|
||||
const int remaining = clamp(static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread);
|
||||
U sum = load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
|
||||
|
||||
for (int row = 0; out_row + row < out_vec_size; row++) {
|
||||
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
|
||||
const device T* sl = scales + row * in_vec_size_g;
|
||||
const device T* bl = biases + row * in_vec_size_g;
|
||||
|
||||
U s = sl[0];
|
||||
U b = bl[0];
|
||||
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
|
||||
}
|
||||
|
||||
for (int row = 0; out_row + row < out_vec_size; row++) {
|
||||
result[row] = simd_sum(result[row]);
|
||||
@@ -239,7 +357,8 @@ template <typename T, const int group_size, const int bits>
|
||||
x += tid.z * in_vec_size + simd_lid * values_per_thread;
|
||||
y += tid.z * out_vec_size + used_out_row;
|
||||
|
||||
for (int k = 0; k < in_vec_size; k += block_size) {
|
||||
int k = 0;
|
||||
for (; k < in_vec_size-block_size; k += block_size) {
|
||||
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
||||
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
@@ -257,6 +376,18 @@ template <typename T, const int group_size, const int bits>
|
||||
biases += block_size / group_size;
|
||||
x += block_size;
|
||||
}
|
||||
const int remaining = clamp(static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread);
|
||||
U sum = load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
|
||||
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
|
||||
const device T* sl = scales + row * in_vec_size_g;
|
||||
const device T* bl = biases + row * in_vec_size_g;
|
||||
|
||||
U s = sl[0];
|
||||
U b = bl[0];
|
||||
result[row] += qdot_safe<U, values_per_thread, bits>(wl, x_thread, s, b, sum, remaining);
|
||||
}
|
||||
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
result[row] = simd_sum(result[row]);
|
||||
@@ -268,7 +399,7 @@ template <typename T, const int group_size, const int bits>
|
||||
}
|
||||
|
||||
|
||||
template <typename T, const int BM, const int BN, const int group_size, const int bits>
|
||||
template <typename T, const int group_size, const int bits>
|
||||
[[kernel]] void qvm(
|
||||
const device T* x [[buffer(0)]],
|
||||
const device uint32_t* w [[buffer(1)]],
|
||||
@@ -278,39 +409,28 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
||||
const constant int& in_vec_size [[buffer(5)]],
|
||||
const constant int& out_vec_size [[buffer(6)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_index_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
static_assert(BM == SIMD_SIZE, "qvm expects BM to be equal to SIMD_SIZE");
|
||||
static_assert(BN == BM, "qvm expects a block size of 32x32");
|
||||
constexpr int num_simdgroups = 8;
|
||||
constexpr int pack_factor = 32 / bits;
|
||||
constexpr int blocksize = SIMD_SIZE;
|
||||
|
||||
(void)lid;
|
||||
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
constexpr int el_per_int = 32 / bits;
|
||||
constexpr int colgroup = BN * el_per_int;
|
||||
constexpr int groups_per_block = colgroup / group_size;
|
||||
|
||||
typedef typename AccT<T>::acc_t U;
|
||||
threadgroup U scales_block[BM * groups_per_block];
|
||||
threadgroup U biases_block[BM * groups_per_block];
|
||||
threadgroup U x_block[BM];
|
||||
typedef float U;
|
||||
|
||||
thread uint32_t w_local;
|
||||
thread U result[el_per_int] = {0};
|
||||
thread U result[pack_factor] = {0};
|
||||
thread U scale = 1;
|
||||
thread U bias = 0;
|
||||
thread U x_local = 0;
|
||||
|
||||
// Adjust positions
|
||||
const int out_vec_size_w = out_vec_size / el_per_int;
|
||||
const int out_vec_size_w = out_vec_size / pack_factor;
|
||||
const int out_vec_size_g = out_vec_size / group_size;
|
||||
int out_col_start = tid.y * (BN * el_per_int);
|
||||
int out_col = out_col_start + simd_gid * el_per_int;
|
||||
w += out_col / el_per_int;
|
||||
scales += out_col_start / group_size;
|
||||
biases += out_col_start / group_size;
|
||||
int out_col = tid.y * (num_simdgroups * pack_factor) + simd_gid * pack_factor;
|
||||
w += out_col / pack_factor;
|
||||
scales += out_col / group_size;
|
||||
biases += out_col / group_size;
|
||||
x += tid.z * in_vec_size;
|
||||
y += tid.z * out_vec_size + out_col;
|
||||
|
||||
@@ -318,53 +438,39 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
||||
return;
|
||||
}
|
||||
|
||||
// Loop over in_vec in blocks of colgroup
|
||||
for (int i=0; i<in_vec_size; i+=BM) {
|
||||
int offset_lid = simd_lid + i;
|
||||
int offset_gid = simd_gid + i;
|
||||
bool thread_in_bounds = offset_lid < in_vec_size;
|
||||
bool group_in_bounds = offset_gid < in_vec_size;
|
||||
// Loop over in_vec in blocks of blocksize
|
||||
int i = 0;
|
||||
for (; i + blocksize <= in_vec_size; i += blocksize) {
|
||||
x_local = x[i + simd_lid];
|
||||
scale = scales[(i + simd_lid) * out_vec_size_g];
|
||||
bias = biases[(i + simd_lid) * out_vec_size_g];
|
||||
w_local = w[(i + simd_lid) * out_vec_size_w];
|
||||
|
||||
// Load the vec to shared memory
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_gid == 0) {
|
||||
x_block[simd_lid] = (thread_in_bounds) ? x[offset_lid] : 0;
|
||||
}
|
||||
|
||||
// Load the scales and biases to shared memory
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_lid < groups_per_block && group_in_bounds) {
|
||||
scales_block[simd_gid * groups_per_block + simd_lid] = scales[offset_gid * out_vec_size_g + simd_lid];
|
||||
biases_block[simd_gid * groups_per_block + simd_lid] = biases[offset_gid * out_vec_size_g + simd_lid];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load in_vec, scale, bias to registers
|
||||
x_local = x_block[simd_lid];
|
||||
scale = scales_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];
|
||||
bias = biases_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];
|
||||
|
||||
// Load the matrix elements
|
||||
w_local = (thread_in_bounds) ? w[offset_lid * out_vec_size_w] : 0;
|
||||
|
||||
// Do all the work.
|
||||
#pragma clang loop unroll(full)
|
||||
for (int k=0; k<el_per_int; k++) {
|
||||
result[k] += (scale * static_cast<U>(w_local & bitmask) + bias) * x_local;
|
||||
w_local >>= bits;
|
||||
}
|
||||
qouter<U, pack_factor, bits>((thread uint8_t *)&w_local, x_local, scale, bias, result);
|
||||
}
|
||||
if (static_cast<int>(i + simd_lid) < in_vec_size) {
|
||||
x_local = x[i + simd_lid];
|
||||
scale = scales[(i + simd_lid) * out_vec_size_g];
|
||||
bias = biases[(i + simd_lid) * out_vec_size_g];
|
||||
w_local = w[(i + simd_lid) * out_vec_size_w];
|
||||
} else {
|
||||
x_local = 0;
|
||||
scale = 0;
|
||||
bias = 0;
|
||||
w_local = 0;
|
||||
}
|
||||
qouter<U, pack_factor, bits>((thread uint8_t *)&w_local, x_local, scale, bias, result);
|
||||
|
||||
// Accumulate in the simdgroup
|
||||
#pragma clang loop unroll(full)
|
||||
for (int k=0; k<el_per_int; k++) {
|
||||
for (int k=0; k<pack_factor; k++) {
|
||||
result[k] = simd_sum(result[k]);
|
||||
}
|
||||
|
||||
// Store the result
|
||||
if (simd_lid == 0) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int k=0; k<el_per_int; k++) {
|
||||
for (int k=0; k<pack_factor; k++) {
|
||||
y[k] = static_cast<T>(result[k]);
|
||||
}
|
||||
}
|
||||
@@ -414,6 +520,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
||||
const int K_g = K / group_size;
|
||||
const int y_row = tid.y * BM;
|
||||
const int y_col = tid.x * BN;
|
||||
|
||||
x += y_row * K;
|
||||
w += y_col * K_w;
|
||||
scales += y_col * K_g;
|
||||
@@ -466,7 +573,10 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
||||
const device uint32_t * w_local = w + offset_row * K_w + offset_col;
|
||||
threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
|
||||
|
||||
if (y_row + offset_row < N) {
|
||||
// y_col corresponds to the row of the weight matrix and added to
|
||||
// offset_row it should be less than the total number of rows
|
||||
// otherwise skip.
|
||||
if (y_col + offset_row < N) {
|
||||
uint32_t wi = *w_local;
|
||||
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
|
||||
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
|
||||
@@ -619,7 +729,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
||||
const device uint32_t * w_local = w + offset_row * N_w + offset_col;
|
||||
threadgroup T * Ws_local = Ws + offset_row * BN + offset_col * el_per_int;
|
||||
|
||||
if (y_row + offset_row < K) {
|
||||
if (k + offset_row < K) {
|
||||
uint32_t wi = *w_local;
|
||||
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
|
||||
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
|
||||
@@ -738,7 +848,7 @@ instantiate_qmv_types( 32, 8)
|
||||
|
||||
#define instantiate_qvm(name, itype, group_size, bits) \
|
||||
template [[host_name("qvm_" #name "_gs_" #group_size "_b_" #bits)]] \
|
||||
[[kernel]] void qvm<itype, 32, 32, group_size, bits>( \
|
||||
[[kernel]] void qvm<itype, group_size, bits>( \
|
||||
const device itype* x [[buffer(0)]], \
|
||||
const device uint32_t* w [[buffer(1)]], \
|
||||
const device itype* scales [[buffer(2)]], \
|
||||
@@ -747,7 +857,6 @@ instantiate_qmv_types( 32, 8)
|
||||
const constant int& in_vec_size [[buffer(5)]], \
|
||||
const constant int& out_vec_size [[buffer(6)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint lid [[thread_index_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
|
@@ -150,6 +150,216 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N_READS = RMS_N_READS>
|
||||
[[kernel]] void vjp_rms_single_row(
|
||||
const device T* x,
|
||||
const device T* w,
|
||||
const device T* g,
|
||||
device T* gx,
|
||||
device T* gw,
|
||||
constant float& eps,
|
||||
constant uint& axis_size,
|
||||
constant uint& w_stride,
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
// Advance the input pointers
|
||||
x += gid * axis_size + lid * N_READS;
|
||||
g += gid * axis_size + lid * N_READS;
|
||||
w += w_stride * lid * N_READS;
|
||||
|
||||
// Allocate registers for the computation and accumulators
|
||||
float thread_x[N_READS];
|
||||
float thread_w[N_READS];
|
||||
float thread_g[N_READS];
|
||||
float sumx2 = 0;
|
||||
float sumgwx = 0;
|
||||
|
||||
// Allocate shared memory to implement the reduction
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
threadgroup float local_sumx2[SIMD_SIZE];
|
||||
threadgroup float local_sumgwx[SIMD_SIZE];
|
||||
threadgroup float local_normalizer[1];
|
||||
threadgroup float local_meangwx[1];
|
||||
|
||||
// Read and accumulate locally
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
thread_x[i] = x[i];
|
||||
thread_w[i] = w[w_stride * i];
|
||||
thread_g[i] = g[i];
|
||||
|
||||
sumx2 += thread_x[i] * thread_x[i];
|
||||
sumgwx += thread_x[i] * thread_w[i] * thread_g[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((lid * N_READS + i) < axis_size) {
|
||||
thread_x[i] = x[i];
|
||||
thread_w[i] = w[w_stride * i];
|
||||
thread_g[i] = g[i];
|
||||
|
||||
sumx2 += thread_x[i] * thread_x[i];
|
||||
sumgwx += thread_x[i] * thread_w[i] * thread_g[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Accumulate across threads
|
||||
sumx2 = simd_sum(sumx2);
|
||||
sumgwx = simd_sum(sumgwx);
|
||||
if (simd_group_id == 0) {
|
||||
local_sumx2[simd_lane_id] = 0;
|
||||
local_sumgwx[simd_lane_id] = 0;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_lane_id == 0) {
|
||||
local_sumx2[simd_group_id] = sumx2;
|
||||
local_sumgwx[simd_group_id] = sumgwx;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_group_id == 0) {
|
||||
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
|
||||
sumgwx = simd_sum(local_sumgwx[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
local_meangwx[0] = sumgwx / axis_size;
|
||||
local_normalizer[0] = metal::precise::rsqrt(sumx2 / axis_size + eps);
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
float meangwx = local_meangwx[0];
|
||||
float normalizer = local_normalizer[0];
|
||||
float normalizer3 = normalizer * normalizer * normalizer;
|
||||
|
||||
// Write the outputs
|
||||
gx += gid * axis_size + lid * N_READS;
|
||||
gw += gid * axis_size + lid * N_READS;
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
gx[i] = static_cast<T>(thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3);
|
||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((lid * N_READS + i) < axis_size) {
|
||||
gx[i] = static_cast<T>(thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3);
|
||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N_READS = RMS_N_READS>
|
||||
[[kernel]] void vjp_rms_looped(
|
||||
const device T* x,
|
||||
const device T* w,
|
||||
const device T* g,
|
||||
device T* gx,
|
||||
device T* gw,
|
||||
constant float& eps,
|
||||
constant uint& axis_size,
|
||||
constant uint& w_stride,
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
// Advance the input pointers
|
||||
x += gid * axis_size + lid * N_READS;
|
||||
g += gid * axis_size + lid * N_READS;
|
||||
w += w_stride * lid * N_READS;
|
||||
|
||||
// Allocate registers for the accumulators
|
||||
float sumx2 = 0;
|
||||
float sumgwx = 0;
|
||||
|
||||
// Allocate shared memory to implement the reduction
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
threadgroup float local_sumx2[SIMD_SIZE];
|
||||
threadgroup float local_sumgwx[SIMD_SIZE];
|
||||
threadgroup float local_normalizer[1];
|
||||
threadgroup float local_meangwx[1];
|
||||
|
||||
// Read and accumulate locally
|
||||
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
||||
if (r + lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = x[i + r];
|
||||
float wi = w[w_stride * (i + r)];
|
||||
float gi = g[i + r];
|
||||
|
||||
sumx2 += xi * xi;
|
||||
sumgwx += xi * wi * gi;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((r + lid * N_READS + i) < axis_size) {
|
||||
float xi = x[i + r];
|
||||
float wi = w[w_stride * (i + r)];
|
||||
float gi = g[i + r];
|
||||
|
||||
sumx2 += xi * xi;
|
||||
sumgwx += xi * wi * gi;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Accumulate across threads
|
||||
sumx2 = simd_sum(sumx2);
|
||||
sumgwx = simd_sum(sumgwx);
|
||||
if (simd_group_id == 0) {
|
||||
local_sumx2[simd_lane_id] = 0;
|
||||
local_sumgwx[simd_lane_id] = 0;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_lane_id == 0) {
|
||||
local_sumx2[simd_group_id] = sumx2;
|
||||
local_sumgwx[simd_group_id] = sumgwx;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_group_id == 0) {
|
||||
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
|
||||
sumgwx = simd_sum(local_sumgwx[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
local_meangwx[0] = sumgwx / axis_size;
|
||||
local_normalizer[0] = metal::precise::rsqrt(sumx2 / axis_size + eps);
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
float meangwx = local_meangwx[0];
|
||||
float normalizer = local_normalizer[0];
|
||||
float normalizer3 = normalizer * normalizer * normalizer;
|
||||
|
||||
// Write the outputs
|
||||
gx += gid * axis_size + lid * N_READS;
|
||||
gw += gid * axis_size + lid * N_READS;
|
||||
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
||||
if (r + lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = x[i + r];
|
||||
float wi = w[w_stride * (i + r)];
|
||||
float gi = g[i + r];
|
||||
|
||||
gx[i + r] = static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
|
||||
gw[i + r] = static_cast<T>(gi * xi * normalizer);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((r + lid * N_READS + i) < axis_size) {
|
||||
float xi = x[i + r];
|
||||
float wi = w[w_stride * (i + r)];
|
||||
float gi = g[i + r];
|
||||
|
||||
gx[i + r] = static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
|
||||
gw[i + r] = static_cast<T>(gi * xi * normalizer);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_rms_single_row(name, itype) \
|
||||
template [[host_name("rms" #name)]] [[kernel]] void \
|
||||
@@ -165,25 +375,56 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_rms_looped(name, itype) \
|
||||
template [[host_name("rms_looped" #name)]] [[kernel]] void \
|
||||
rms_looped<itype>( \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
\
|
||||
template [[host_name("vjp_rms" #name)]] [[kernel]] void \
|
||||
vjp_rms_single_row<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
device itype* out, \
|
||||
const device itype* g, \
|
||||
device itype* gx, \
|
||||
device itype* gw, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
threadgroup float* local_inv_mean [[threadgroup(0)]], \
|
||||
threadgroup float* local_sums [[threadgroup(1)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_rms_looped(name, itype) \
|
||||
template [[host_name("rms_looped" #name)]] [[kernel]] void \
|
||||
rms_looped<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
device itype* out, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
threadgroup float* local_inv_mean [[threadgroup(0)]], \
|
||||
threadgroup float* local_sums [[threadgroup(1)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
\
|
||||
template [[host_name("vjp_rms_looped" #name)]] [[kernel]] void \
|
||||
vjp_rms_looped<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* g, \
|
||||
device itype* gx, \
|
||||
device itype* gw, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_rms(name, itype) \
|
||||
instantiate_rms_single_row(name, itype) \
|
||||
instantiate_rms_looped(name, itype)
|
||||
|
@@ -5,7 +5,7 @@
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
template <typename T, bool traditional>
|
||||
template <typename T, bool traditional, bool forward>
|
||||
[[kernel]] void rope(
|
||||
const device T *in [[buffer(0)]],
|
||||
device T * out [[buffer(1)]],
|
||||
@@ -43,15 +43,22 @@ template <typename T, bool traditional>
|
||||
// Read and write the output
|
||||
float x1 = static_cast<float>(in[in_index_1]);
|
||||
float x2 = static_cast<float>(in[in_index_2]);
|
||||
float rx1 = x1 * costheta - x2 * sintheta;
|
||||
float rx2 = x1 * sintheta + x2 * costheta;
|
||||
float rx1;
|
||||
float rx2;
|
||||
if (forward) {
|
||||
rx1 = x1 * costheta - x2 * sintheta;
|
||||
rx2 = x1 * sintheta + x2 * costheta;
|
||||
} else {
|
||||
rx1 = x2 * sintheta + x1 * costheta;
|
||||
rx2 = x2 * costheta - x1 * sintheta;
|
||||
}
|
||||
out[out_index_1] = static_cast<T>(rx1);
|
||||
out[out_index_2] = static_cast<T>(rx2);
|
||||
}
|
||||
|
||||
#define instantiate_rope(name, type, traditional) \
|
||||
#define instantiate_rope(name, type, traditional, forward) \
|
||||
template [[host_name("rope_" #name)]] \
|
||||
[[kernel]] void rope<type, traditional>( \
|
||||
[[kernel]] void rope<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const size_t strides[3], \
|
||||
@@ -62,9 +69,15 @@ template <typename T, bool traditional>
|
||||
uint3 pos [[thread_position_in_grid]], \
|
||||
uint3 grid [[threads_per_grid]]);
|
||||
|
||||
instantiate_rope(traditional_float16, half, true)
|
||||
instantiate_rope(traditional_bfloat16, bfloat16_t, true)
|
||||
instantiate_rope(traditional_float32, float, true)
|
||||
instantiate_rope(float16, half, false)
|
||||
instantiate_rope(bfloat16, bfloat16_t, false)
|
||||
instantiate_rope(float32, float, false)
|
||||
instantiate_rope(traditional_float16, half, true, true)
|
||||
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
|
||||
instantiate_rope(traditional_float32, float, true, true)
|
||||
instantiate_rope(float16, half, false, true)
|
||||
instantiate_rope(bfloat16, bfloat16_t, false, true)
|
||||
instantiate_rope(float32, float, false, true)
|
||||
instantiate_rope(vjp_traditional_float16, half, true, false)
|
||||
instantiate_rope(vjp_traditional_bfloat16, bfloat16_t, true, false)
|
||||
instantiate_rope(vjp_traditional_float32, float, true, false)
|
||||
instantiate_rope(vjp_float16, half, false, false)
|
||||
instantiate_rope(vjp_bfloat16, bfloat16_t, false, false)
|
||||
instantiate_rope(vjp_float32, float, false, false)
|
||||
|
@@ -451,7 +451,7 @@ instantiate_scan_helper(sum_int32_int32, int32_t, int32_t, CumSu
|
||||
//instantiate_scan_helper(sum_int64_int64, int64_t, int64_t, CumSum, 2)
|
||||
instantiate_scan_helper(sum_float16_float16, half, half, CumSum, 4)
|
||||
instantiate_scan_helper(sum_float32_float32, float, float, CumSum, 4)
|
||||
//instantiate_scan_helper(sum_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumSum, 4)
|
||||
instantiate_scan_helper(sum_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumSum, 4)
|
||||
//instantiate_scan_helper(sum_complex64_complex64, complex64_t, complex64_t, CumSum)
|
||||
//instantiate_scan_helper(prod_bool__bool_, bool, bool, CumProd, 4)
|
||||
instantiate_scan_helper(prod_uint8_uint8, uint8_t, uint8_t, CumProd, 4)
|
||||
@@ -464,7 +464,7 @@ instantiate_scan_helper(prod_int32_int32, int32_t, int32_t, CumP
|
||||
//instantiate_scan_helper(prod_int64_int64, int64_t, int64_t, CumProd, 2)
|
||||
instantiate_scan_helper(prod_float16_float16, half, half, CumProd, 4)
|
||||
instantiate_scan_helper(prod_float32_float32, float, float, CumProd, 4)
|
||||
//instantiate_scan_helper(prod_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumProd, 4)
|
||||
instantiate_scan_helper(prod_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumProd, 4)
|
||||
//instantiate_scan_helper(prod_complex64_complex64, complex64_t, complex64_t, CumProd)
|
||||
//instantiate_scan_helper(max_bool__bool_, bool, bool, CumMax, 4)
|
||||
instantiate_scan_helper(max_uint8_uint8, uint8_t, uint8_t, CumMax, 4)
|
||||
@@ -477,7 +477,7 @@ instantiate_scan_helper(max_int32_int32, int32_t, int32_t, CumMa
|
||||
//instantiate_scan_helper(max_int64_int64, int64_t, int64_t, CumMax, 2)
|
||||
instantiate_scan_helper(max_float16_float16, half, half, CumMax, 4)
|
||||
instantiate_scan_helper(max_float32_float32, float, float, CumMax, 4)
|
||||
//instantiate_scan_helper(max_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMax, 4)
|
||||
instantiate_scan_helper(max_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMax, 4)
|
||||
//instantiate_scan_helper(max_complex64_complex64, complex64_t, complex64_t, CumMax)
|
||||
//instantiate_scan_helper(min_bool__bool_, bool, bool, CumMin, 4)
|
||||
instantiate_scan_helper(min_uint8_uint8, uint8_t, uint8_t, CumMin, 4)
|
||||
@@ -490,5 +490,5 @@ instantiate_scan_helper(min_int32_int32, int32_t, int32_t, CumMi
|
||||
//instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMin, 2)
|
||||
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
|
||||
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
|
||||
//instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
|
||||
instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
|
||||
//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin)
|
||||
|
@@ -11,46 +11,48 @@ using namespace metal;
|
||||
|
||||
template <typename T>
|
||||
inline T softmax_exp(T x) {
|
||||
// Softmax doesn't need high precision exponential cause it is gonna be x
|
||||
// will be in (-oo, 0] anyway and subsequently it will be divided by
|
||||
// sum(exp(x_i)).
|
||||
// Softmax doesn't need high precision exponential cause x is gonna be in
|
||||
// (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
|
||||
return fast::exp(x);
|
||||
}
|
||||
|
||||
template <typename T, int N_READS = SOFTMAX_N_READS>
|
||||
template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
|
||||
[[kernel]] void softmax_single_row(
|
||||
const device T* in,
|
||||
device T* out,
|
||||
constant int& axis_size,
|
||||
threadgroup T* local_max [[threadgroup(0)]],
|
||||
threadgroup T* local_normalizer [[threadgroup(1)]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint _lid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
int lid = _lid;
|
||||
|
||||
T ld[N_READS];
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
|
||||
threadgroup AccT local_max[SIMD_SIZE];
|
||||
threadgroup AccT local_normalizer[SIMD_SIZE];
|
||||
|
||||
AccT ld[N_READS];
|
||||
|
||||
in += gid * axis_size + lid * N_READS;
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
ld[i] = in[i];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
ld[i] = AccT(in[i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
ld[i] =
|
||||
((lid * N_READS + i) < axis_size) ? in[i] : T(Limits<T>::finite_min);
|
||||
}
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
ld[i] = ((lid * N_READS + i) < axis_size) ? AccT(in[i])
|
||||
: Limits<AccT>::finite_min;
|
||||
}
|
||||
}
|
||||
if (simd_group_id == 0) {
|
||||
local_max[simd_lane_id] = Limits<T>::finite_min;
|
||||
local_max[simd_lane_id] = Limits<AccT>::finite_min;
|
||||
local_normalizer[simd_lane_id] = 0;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Get the max
|
||||
T maxval = Limits<T>::finite_min;
|
||||
AccT maxval = Limits<AccT>::finite_min;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
maxval = (maxval < ld[i]) ? ld[i] : maxval;
|
||||
}
|
||||
@@ -69,9 +71,9 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
|
||||
maxval = local_max[0];
|
||||
|
||||
// Compute exp(x_i - maxval) and store the partial sums in local_normalizer
|
||||
T normalizer = 0;
|
||||
AccT normalizer = 0;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
T exp_x = softmax_exp(ld[i] - maxval);
|
||||
AccT exp_x = softmax_exp(ld[i] - maxval);
|
||||
ld[i] = exp_x;
|
||||
normalizer += exp_x;
|
||||
}
|
||||
@@ -92,25 +94,23 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
|
||||
// Normalize and write to the output
|
||||
out += gid * axis_size + lid * N_READS;
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
out[i] = ld[i] * normalizer;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
out[i] = T(ld[i] * normalizer);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((lid * N_READS + i) < axis_size) {
|
||||
out[i] = ld[i] * normalizer;
|
||||
}
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((lid * N_READS + i) < axis_size) {
|
||||
out[i] = T(ld[i] * normalizer);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N_READS = SOFTMAX_N_READS>
|
||||
template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
|
||||
[[kernel]] void softmax_looped(
|
||||
const device T* in,
|
||||
device T* out,
|
||||
constant int& axis_size,
|
||||
threadgroup T* local_max [[threadgroup(0)]],
|
||||
threadgroup T* local_normalizer [[threadgroup(1)]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
@@ -118,22 +118,27 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
in += gid * axis_size;
|
||||
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
|
||||
threadgroup AccT local_max[SIMD_SIZE];
|
||||
threadgroup AccT local_normalizer[SIMD_SIZE];
|
||||
|
||||
// Get the max and the normalizer in one go
|
||||
T prevmax;
|
||||
T maxval = Limits<T>::finite_min;
|
||||
T normalizer = 0;
|
||||
AccT prevmax;
|
||||
AccT maxval = Limits<AccT>::finite_min;
|
||||
AccT normalizer = 0;
|
||||
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
|
||||
r++) {
|
||||
int offset = r * lsize * N_READS + lid * N_READS;
|
||||
T vals[N_READS];
|
||||
AccT vals[N_READS];
|
||||
if (offset + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = in[offset + i];
|
||||
vals[i] = AccT(in[offset + i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] =
|
||||
(offset + i < axis_size) ? in[offset + i] : T(Limits<T>::finite_min);
|
||||
vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
|
||||
: Limits<AccT>::finite_min;
|
||||
}
|
||||
}
|
||||
prevmax = maxval;
|
||||
@@ -179,50 +184,66 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
|
||||
r++) {
|
||||
int offset = r * lsize * N_READS + lid * N_READS;
|
||||
if (offset + N_READS <= axis_size) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
out[offset + i] = T(softmax_exp(in[offset + i] - maxval) * normalizer);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if (offset + i < axis_size) {
|
||||
out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
|
||||
out[offset + i] =
|
||||
T(softmax_exp(in[offset + i] - maxval) * normalizer);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define instantiate_softmax_single_row(name, itype) \
|
||||
// clang-format off
|
||||
#define instantiate_softmax(name, itype) \
|
||||
template [[host_name("softmax_" #name)]] [[kernel]] void \
|
||||
softmax_single_row<itype>( \
|
||||
const device itype* in, \
|
||||
device itype* out, \
|
||||
constant int& axis_size, \
|
||||
threadgroup itype* local_max [[threadgroup(0)]], \
|
||||
threadgroup itype* local_normalizer [[threadgroup(1)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint _lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_softmax_looped(name, itype) \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
template [[host_name("softmax_looped_" #name)]] [[kernel]] void \
|
||||
softmax_looped<itype>( \
|
||||
const device itype* in, \
|
||||
device itype* out, \
|
||||
constant int& axis_size, \
|
||||
threadgroup itype* local_max [[threadgroup(0)]], \
|
||||
threadgroup itype* local_normalizer [[threadgroup(1)]], \
|
||||
uint gid [[threadgroup_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_softmax(name, itype) \
|
||||
instantiate_softmax_single_row(name, itype) \
|
||||
instantiate_softmax_looped(name, itype)
|
||||
#define instantiate_softmax_precise(name, itype) \
|
||||
template [[host_name("softmax_precise_" #name)]] [[kernel]] void \
|
||||
softmax_single_row<itype, float>( \
|
||||
const device itype* in, \
|
||||
device itype* out, \
|
||||
constant int& axis_size, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint _lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
template [[host_name("softmax_looped_precise_" #name)]] [[kernel]] void \
|
||||
softmax_looped<itype, float>( \
|
||||
const device itype* in, \
|
||||
device itype* out, \
|
||||
constant int& axis_size, \
|
||||
uint gid [[threadgroup_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
instantiate_softmax(float32, float)
|
||||
instantiate_softmax(float16, half)
|
||||
instantiate_softmax(bfloat16, bfloat16_t)
|
||||
instantiate_softmax_precise(float16, half)
|
||||
instantiate_softmax_precise(bfloat16, bfloat16_t)
|
||||
// clang-format on
|
||||
|
@@ -394,7 +394,7 @@ struct Conv2DWeightBlockLoader {
|
||||
const constant ImplicitGemmConv2DParams* gemm_params_,
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(params_->wt_strides[0]),
|
||||
: src_ld(params_ -> wt_strides[0]),
|
||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
bi(thread_idx / TCOLS),
|
||||
bj(vec_size * (thread_idx % TCOLS)),
|
||||
|
@@ -244,7 +244,7 @@ struct Conv2DWeightBlockLoaderSmallChannels {
|
||||
const constant ImplicitGemmConv2DParams* gemm_params_,
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(params_->wt_strides[0]),
|
||||
: src_ld(params_ -> wt_strides[0]),
|
||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
bi(thread_idx / TCOLS),
|
||||
bj(vec_size * (thread_idx % TCOLS)),
|
||||
|
@@ -220,7 +220,7 @@ struct Conv2DWeightBlockLoaderGeneral {
|
||||
const short base_ww_,
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(params_->wt_strides[0]),
|
||||
: src_ld(params_ -> wt_strides[0]),
|
||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
bi(thread_idx / TCOLS),
|
||||
bj(vec_size * (thread_idx % TCOLS)),
|
||||
|
@@ -7,6 +7,7 @@
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/erf.h"
|
||||
#include "mlx/backend/metal/kernels/expm1f.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
namespace {
|
||||
@@ -183,6 +184,13 @@ struct Exp {
|
||||
}
|
||||
};
|
||||
|
||||
struct Expm1 {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return static_cast<T>(expm1f(static_cast<float>(x)));
|
||||
};
|
||||
};
|
||||
|
||||
struct Floor {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
|
@@ -71,6 +71,7 @@ instantiate_unary_types(ceil, Ceil)
|
||||
instantiate_unary_float(cos, Cos)
|
||||
instantiate_unary_float(cosh, Cosh)
|
||||
instantiate_unary_float(exp, Exp)
|
||||
instantiate_unary_float(expm1, Expm1)
|
||||
instantiate_unary_types(floor, Floor)
|
||||
instantiate_unary_float(log, Log)
|
||||
instantiate_unary_float(log2, Log2)
|
||||
|
@@ -197,8 +197,8 @@ inline auto collapse_batches(const array& a, const array& b) {
|
||||
std::vector<int> B_bshape{b.shape().begin(), b.shape().end() - 2};
|
||||
if (A_bshape != B_bshape) {
|
||||
std::ostringstream msg;
|
||||
msg << "[matmul] Got matrices with incorrectly broadcasted shapes: "
|
||||
<< "A " << a.shape() << ", B " << b.shape() << ".";
|
||||
msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A "
|
||||
<< a.shape() << ", B " << b.shape() << ".";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
@@ -227,9 +227,8 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
|
||||
std::vector<int> C_bshape{c.shape().begin(), c.shape().end() - 2};
|
||||
if (A_bshape != B_bshape || A_bshape != C_bshape) {
|
||||
std::ostringstream msg;
|
||||
msg << "[addmm] Got matrices with incorrectly broadcasted shapes: "
|
||||
<< "A " << a.shape() << ", B " << b.shape() << ", B " << c.shape()
|
||||
<< ".";
|
||||
msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A "
|
||||
<< a.shape() << ", B " << b.shape() << ", B " << c.shape() << ".";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
@@ -332,11 +331,11 @@ void steel_matmul(
|
||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
||||
<< type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn << "_MN_"
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
|
||||
<< "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
|
||||
<< ((K % bk == 0) ? "t" : "n") << "aligned";
|
||||
|
||||
// Encode and dispatch gemm kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
@@ -360,9 +359,9 @@ void steel_matmul(
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions);
|
||||
|
||||
set_array_buffer(compute_encoder, a, 0);
|
||||
set_array_buffer(compute_encoder, b, 1);
|
||||
set_array_buffer(compute_encoder, C_split, 2);
|
||||
compute_encoder.set_input_array(a, 0);
|
||||
compute_encoder.set_input_array(b, 1);
|
||||
compute_encoder.set_output_array(C_split, 2);
|
||||
|
||||
compute_encoder->setBytes(¶ms, sizeof(GEMMSpiltKParams), 3);
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
@@ -380,8 +379,8 @@ void steel_matmul(
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Set the arguments for the kernel
|
||||
set_array_buffer(compute_encoder, C_split, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder.set_input_array(C_split, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&N, sizeof(int), 4);
|
||||
@@ -422,11 +421,11 @@ void steel_matmul(
|
||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
||||
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn << "_MN_"
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
|
||||
<< "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
|
||||
<< ((K % bk == 0) ? "t" : "n") << "aligned";
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
@@ -467,9 +466,9 @@ void steel_matmul(
|
||||
batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
|
||||
|
||||
// Launch kernel
|
||||
set_array_buffer(compute_encoder, a, 0);
|
||||
set_array_buffer(compute_encoder, b, 1);
|
||||
set_array_buffer(compute_encoder, out, 3);
|
||||
compute_encoder.set_input_array(a, 0);
|
||||
compute_encoder.set_input_array(b, 1);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
|
||||
compute_encoder->setBytes(¶ms, sizeof(GEMMParams), 4);
|
||||
|
||||
@@ -488,7 +487,7 @@ void steel_matmul(
|
||||
|
||||
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
if (!is_floating_point(out.dtype())) {
|
||||
if (!issubdtype(out.dtype(), floating)) {
|
||||
throw std::runtime_error(
|
||||
"[matmul] Does not yet support non-floating point types.");
|
||||
}
|
||||
@@ -622,7 +621,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
kname << "_nc" << !contiguous_kernel << "_axpby0";
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
@@ -630,9 +629,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MTL::Size group_dims = MTL::Size(bn, bm, 1);
|
||||
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
|
||||
|
||||
set_array_buffer(compute_encoder, mat, 0);
|
||||
set_array_buffer(compute_encoder, vec, 1);
|
||||
set_array_buffer(compute_encoder, out, 3);
|
||||
compute_encoder.set_input_array(mat, 0);
|
||||
compute_encoder.set_input_array(vec, 1);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
|
||||
compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
|
||||
@@ -696,7 +695,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 3);
|
||||
if (!is_floating_point(out.dtype())) {
|
||||
if (!issubdtype(out.dtype(), floating)) {
|
||||
throw std::runtime_error(
|
||||
"[matmul] Does not yet support non-floating point types.");
|
||||
}
|
||||
@@ -834,7 +833,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
kname << "_nc" << !contiguous_kernel << "_axpby1";
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
@@ -842,10 +841,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MTL::Size group_dims = MTL::Size(bn, bm, 1);
|
||||
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
|
||||
|
||||
set_array_buffer(compute_encoder, mat, 0);
|
||||
set_array_buffer(compute_encoder, vec, 1);
|
||||
set_array_buffer(compute_encoder, c, 2);
|
||||
set_array_buffer(compute_encoder, out, 3);
|
||||
compute_encoder.set_input_array(mat, 0);
|
||||
compute_encoder.set_input_array(vec, 1);
|
||||
compute_encoder.set_input_array(c, 2);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
|
||||
compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
|
||||
@@ -903,11 +902,11 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
||||
<< type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn << "_MN_"
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
|
||||
<< "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
|
||||
<< ((K % bk == 0) ? "t" : "n") << "aligned";
|
||||
|
||||
// Encode and dispatch gemm kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
@@ -931,9 +930,9 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions);
|
||||
|
||||
set_array_buffer(compute_encoder, a, 0);
|
||||
set_array_buffer(compute_encoder, b, 1);
|
||||
set_array_buffer(compute_encoder, C_split, 2);
|
||||
compute_encoder.set_input_array(a, 0);
|
||||
compute_encoder.set_input_array(b, 1);
|
||||
compute_encoder.set_output_array(C_split, 2);
|
||||
|
||||
compute_encoder->setBytes(¶ms, sizeof(GEMMSpiltKParams), 3);
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
@@ -946,12 +945,12 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Set the arguments for the kernel
|
||||
set_array_buffer(compute_encoder, C_split, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder.set_input_array(C_split, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&N, sizeof(int), 4);
|
||||
set_array_buffer(compute_encoder, c, 5);
|
||||
compute_encoder.set_input_array(c, 5);
|
||||
compute_encoder->setBytes(&ldc, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&fdc, sizeof(int), 7);
|
||||
compute_encoder->setBytes(&alpha_, sizeof(float), 8);
|
||||
@@ -992,12 +991,12 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
||||
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn << "_MN_"
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
|
||||
<< "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned"
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
|
||||
<< ((K % bk == 0) ? "t" : "n") << "aligned"
|
||||
<< ((alpha_ == 1. && beta_ == 1.) ? "_add" : "_axpby");
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
@@ -1045,10 +1044,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end());
|
||||
|
||||
// Launch kernel
|
||||
set_array_buffer(compute_encoder, a, 0);
|
||||
set_array_buffer(compute_encoder, b, 1);
|
||||
set_array_buffer(compute_encoder, c, 2);
|
||||
set_array_buffer(compute_encoder, out, 3);
|
||||
compute_encoder.set_input_array(a, 0);
|
||||
compute_encoder.set_input_array(b, 1);
|
||||
compute_encoder.set_input_array(c, 2);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
|
||||
compute_encoder->setBytes(&gemm_params, sizeof(GEMMParams), 4);
|
||||
compute_encoder->setBytes(¶ms, sizeof(GEMMAddMMParams), 5);
|
||||
|
@@ -1,10 +1,10 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <future>
|
||||
#include <memory>
|
||||
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
@@ -74,6 +74,8 @@ std::function<void()> make_task(
|
||||
if (arr.is_tracer()) {
|
||||
inputs = arr.inputs();
|
||||
}
|
||||
|
||||
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
}
|
||||
std::vector<std::shared_ptr<array::Data>> buffers;
|
||||
@@ -86,7 +88,6 @@ std::function<void()> make_task(
|
||||
if (!arr.is_tracer()) {
|
||||
arr.detach();
|
||||
}
|
||||
|
||||
if (p) {
|
||||
metal::device(s.device).end_encoding(s.index);
|
||||
scheduler::notify_new_task(s);
|
||||
@@ -108,4 +109,31 @@ std::function<void()> make_task(
|
||||
return task;
|
||||
}
|
||||
|
||||
bool start_capture(std::string path, id object) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
|
||||
auto descriptor = MTL::CaptureDescriptor::alloc()->init();
|
||||
descriptor->setCaptureObject(object);
|
||||
|
||||
if (path.length() > 0) {
|
||||
auto string = NS::String::string(path.c_str(), NS::UTF8StringEncoding);
|
||||
auto url = NS::URL::fileURLWithPath(string);
|
||||
descriptor->setDestination(MTL::CaptureDestinationGPUTraceDocument);
|
||||
descriptor->setOutputURL(url);
|
||||
}
|
||||
|
||||
auto manager = MTL::CaptureManager::sharedCaptureManager();
|
||||
return manager->startCapture(descriptor, nullptr);
|
||||
}
|
||||
|
||||
bool start_capture(std::string path) {
|
||||
auto& device = metal::device(mlx::core::Device::gpu);
|
||||
return start_capture(path, device.mtl_device());
|
||||
}
|
||||
|
||||
void stop_capture() {
|
||||
auto manager = MTL::CaptureManager::sharedCaptureManager();
|
||||
manager->stopCapture();
|
||||
}
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@@ -2,15 +2,11 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <future>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
/* Check if the Metal backend is available. */
|
||||
bool is_available();
|
||||
|
||||
/* Get the actively used memory in bytes.
|
||||
@@ -58,12 +54,8 @@ size_t set_memory_limit(size_t limit, bool relaxed = true);
|
||||
* */
|
||||
size_t set_cache_limit(size_t limit);
|
||||
|
||||
void new_stream(Stream stream);
|
||||
std::shared_ptr<void> new_scoped_memory_pool();
|
||||
|
||||
std::function<void()> make_task(
|
||||
array& arr,
|
||||
std::vector<std::shared_future<void>> deps,
|
||||
std::shared_ptr<std::promise<void>> p);
|
||||
/** Capture a GPU trace, saving it to an absolute file `path` */
|
||||
bool start_capture(std::string path = "");
|
||||
void stop_capture();
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
22
mlx/backend/metal/metal_impl.h
Normal file
22
mlx/backend/metal/metal_impl.h
Normal file
@@ -0,0 +1,22 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <future>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
void new_stream(Stream stream);
|
||||
std::shared_ptr<void> new_scoped_memory_pool();
|
||||
|
||||
std::function<void()> make_task(
|
||||
array& arr,
|
||||
std::vector<std::shared_future<void>> deps,
|
||||
std::shared_ptr<std::promise<void>> p);
|
||||
|
||||
} // namespace mlx::core::metal
|
@@ -4,6 +4,7 @@
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/reduce.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
@@ -18,7 +19,7 @@ void RMSNorm::eval_gpu(
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
std::vector<array> copies;
|
||||
auto check_input = [&copies, &s](const array& x) {
|
||||
auto check_input = [&copies, &s](const array& x) -> const array& {
|
||||
bool no_copy = x.strides()[x.ndim() - 1] == 1;
|
||||
if (x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
@@ -27,10 +28,9 @@ void RMSNorm::eval_gpu(
|
||||
if (no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
copies.push_back(x_copy);
|
||||
return x_copy;
|
||||
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
|
||||
copy_gpu(x, copies.back(), CopyType::General, s);
|
||||
return copies.back();
|
||||
}
|
||||
};
|
||||
const array& x = check_input(inputs[0]);
|
||||
@@ -57,7 +57,7 @@ void RMSNorm::eval_gpu(
|
||||
op_name += "_looped";
|
||||
}
|
||||
op_name += type_to_name(out);
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
{
|
||||
auto kernel = d.get_kernel(op_name);
|
||||
|
||||
@@ -79,10 +79,10 @@ void RMSNorm::eval_gpu(
|
||||
|
||||
uint32_t w_stride = w.strides()[0];
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(
|
||||
compute_encoder, x.data_shared_ptr() == nullptr ? out : x, 0);
|
||||
set_array_buffer(compute_encoder, w, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
compute_encoder.set_input_array(
|
||||
x.data_shared_ptr() == nullptr ? out : x, 0);
|
||||
compute_encoder.set_input_array(w, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
compute_encoder->setBytes(&eps_, sizeof(float), 3);
|
||||
compute_encoder->setBytes(&axis_size, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 5);
|
||||
@@ -95,6 +95,113 @@ void RMSNorm::eval_gpu(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
|
||||
void RMSNormVJP::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
// Ensure row contiguity. We could relax this step by checking that the array
|
||||
// is contiguous (no broadcasts or holes) and that the input strides are the
|
||||
// same as the cotangent strides but for now this is simpler.
|
||||
std::vector<array> copies;
|
||||
auto check_input = [&copies, &s](const array& x) -> const array& {
|
||||
if (x.flags().row_contiguous) {
|
||||
return x;
|
||||
}
|
||||
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
|
||||
copy_gpu(x, copies.back(), CopyType::General, s);
|
||||
return copies.back();
|
||||
};
|
||||
const array& x = check_input(inputs[0]);
|
||||
const array& w = inputs[1];
|
||||
const array& g = check_input(inputs[2]);
|
||||
array& gx = outputs[0];
|
||||
array& gw = outputs[1];
|
||||
|
||||
// Allocate space for the outputs
|
||||
bool x_in_gx = false;
|
||||
bool g_in_gx = false;
|
||||
if (x.is_donatable()) {
|
||||
gx.move_shared_buffer(x);
|
||||
x_in_gx = true;
|
||||
} else if (g.is_donatable()) {
|
||||
gx.move_shared_buffer(g);
|
||||
g_in_gx = true;
|
||||
} else {
|
||||
gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
|
||||
}
|
||||
|
||||
auto axis_size = static_cast<uint32_t>(x.shape().back());
|
||||
int n_rows = x.data_size() / axis_size;
|
||||
|
||||
// Allocate a temporary to store the gradients for w and initialize the
|
||||
// gradient accumulator to 0.
|
||||
array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {});
|
||||
bool g_in_gw = false;
|
||||
if (!g_in_gx && g.is_donatable()) {
|
||||
gw_temp.move_shared_buffer(g);
|
||||
g_in_gw = true;
|
||||
} else {
|
||||
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
|
||||
}
|
||||
copies.push_back(gw_temp);
|
||||
{
|
||||
array zero(0, gw.dtype());
|
||||
copy_gpu(zero, gw, CopyType::Scalar, s);
|
||||
copies.push_back(std::move(zero));
|
||||
}
|
||||
|
||||
const int simd_size = 32;
|
||||
const int n_reads = RMS_N_READS;
|
||||
const int looped_limit = RMS_LOOPED_LIMIT;
|
||||
std::string op_name = "vjp_rms";
|
||||
if (axis_size > looped_limit) {
|
||||
op_name += "_looped";
|
||||
}
|
||||
op_name += type_to_name(gx);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
{
|
||||
auto kernel = d.get_kernel(op_name);
|
||||
|
||||
MTL::Size grid_dims, group_dims;
|
||||
if (axis_size <= looped_limit) {
|
||||
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
|
||||
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
|
||||
size_t threadgroup_size = simd_size * simds_needed;
|
||||
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
|
||||
size_t n_threads = n_rows * threadgroup_size;
|
||||
grid_dims = MTL::Size(n_threads, 1, 1);
|
||||
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||
} else {
|
||||
size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
size_t n_threads = n_rows * threadgroup_size;
|
||||
grid_dims = MTL::Size(n_threads, 1, 1);
|
||||
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||
}
|
||||
|
||||
uint32_t w_stride = w.strides()[0];
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_input_array(x_in_gx ? gx : x, 0);
|
||||
compute_encoder.set_input_array(w, 1);
|
||||
compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
|
||||
compute_encoder.set_output_array(gx, 3);
|
||||
compute_encoder.set_output_array(gw_temp, 4);
|
||||
compute_encoder->setBytes(&eps_, sizeof(float), 5);
|
||||
compute_encoder->setBytes(&axis_size, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
ReductionPlan plan(
|
||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||
strided_reduce_general_dispatch(
|
||||
gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
|
||||
void LayerNorm::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
@@ -104,7 +211,7 @@ void LayerNorm::eval_gpu(
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
std::vector<array> copies;
|
||||
auto check_input = [&copies, &s](const array& x) {
|
||||
auto check_input = [&copies, &s](const array& x) -> const array& {
|
||||
bool no_copy = x.strides()[x.ndim() - 1] == 1;
|
||||
if (x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
@@ -113,10 +220,9 @@ void LayerNorm::eval_gpu(
|
||||
if (no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
copies.push_back(x_copy);
|
||||
return x_copy;
|
||||
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
|
||||
copy_gpu(x, copies.back(), CopyType::General, s);
|
||||
return copies.back();
|
||||
}
|
||||
};
|
||||
const array& x = check_input(inputs[0]);
|
||||
@@ -144,7 +250,7 @@ void LayerNorm::eval_gpu(
|
||||
op_name += "_looped";
|
||||
}
|
||||
op_name += type_to_name(out);
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
{
|
||||
auto kernel = d.get_kernel(op_name);
|
||||
|
||||
@@ -167,11 +273,11 @@ void LayerNorm::eval_gpu(
|
||||
uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||
uint32_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(
|
||||
compute_encoder, x.data_shared_ptr() == nullptr ? out : x, 0);
|
||||
set_array_buffer(compute_encoder, w, 1);
|
||||
set_array_buffer(compute_encoder, b, 2);
|
||||
set_array_buffer(compute_encoder, out, 3);
|
||||
compute_encoder.set_input_array(
|
||||
x.data_shared_ptr() == nullptr ? out : x, 0);
|
||||
compute_encoder.set_input_array(w, 1);
|
||||
compute_encoder.set_input_array(b, 2);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
compute_encoder->setBytes(&eps_, sizeof(float), 4);
|
||||
compute_encoder->setBytes(&axis_size, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 6);
|
||||
@@ -182,4 +288,131 @@ void LayerNorm::eval_gpu(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
|
||||
void LayerNormVJP::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
// Ensure row contiguity. We could relax this step by checking that the array
|
||||
// is contiguous (no broadcasts or holes) and that the input strides are the
|
||||
// same as the cotangent strides but for now this is simpler.
|
||||
std::vector<array> copies;
|
||||
auto check_input = [&copies, &s](const array& x) -> const array& {
|
||||
if (x.flags().row_contiguous) {
|
||||
return x;
|
||||
}
|
||||
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
|
||||
copy_gpu(x, copies.back(), CopyType::General, s);
|
||||
return copies.back();
|
||||
};
|
||||
const array& x = check_input(inputs[0]);
|
||||
const array& w = inputs[1];
|
||||
const array& b = inputs[2];
|
||||
const array& g = check_input(inputs[3]);
|
||||
array& gx = outputs[0];
|
||||
array& gw = outputs[1];
|
||||
array& gb = outputs[2];
|
||||
|
||||
// Allocate space for the outputs
|
||||
bool x_in_gx = false;
|
||||
bool g_in_gx = false;
|
||||
if (x.is_donatable()) {
|
||||
gx.move_shared_buffer(x);
|
||||
x_in_gx = true;
|
||||
} else if (g.is_donatable()) {
|
||||
gx.move_shared_buffer(g);
|
||||
g_in_gx = true;
|
||||
} else {
|
||||
gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
|
||||
}
|
||||
|
||||
auto axis_size = static_cast<uint32_t>(x.shape().back());
|
||||
int n_rows = x.data_size() / axis_size;
|
||||
|
||||
// Allocate a temporary to store the gradients for w and initialize the
|
||||
// gradient accumulator to 0.
|
||||
array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {});
|
||||
bool g_in_gw = false;
|
||||
if (!g_in_gx && g.is_donatable()) {
|
||||
gw_temp.move_shared_buffer(g);
|
||||
g_in_gw = true;
|
||||
} else {
|
||||
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
|
||||
}
|
||||
copies.push_back(gw_temp);
|
||||
{
|
||||
array zero(0, gw.dtype());
|
||||
copy_gpu(zero, gw, CopyType::Scalar, s);
|
||||
copy_gpu(zero, gb, CopyType::Scalar, s);
|
||||
copies.push_back(std::move(zero));
|
||||
}
|
||||
|
||||
// Finish with the gradient for b in case we had a b
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
if (gb.ndim() == 1 && gb.size() == axis_size) {
|
||||
ReductionPlan plan(
|
||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||
strided_reduce_general_dispatch(
|
||||
g_in_gx ? gx : (g_in_gw ? gw_temp : g),
|
||||
gb,
|
||||
"sum",
|
||||
plan,
|
||||
{0},
|
||||
compute_encoder,
|
||||
d,
|
||||
s);
|
||||
}
|
||||
|
||||
const int simd_size = 32;
|
||||
const int n_reads = RMS_N_READS;
|
||||
const int looped_limit = RMS_LOOPED_LIMIT;
|
||||
std::string op_name = "vjp_layer_norm";
|
||||
if (axis_size > looped_limit) {
|
||||
op_name += "_looped";
|
||||
}
|
||||
op_name += type_to_name(gx);
|
||||
{
|
||||
auto kernel = d.get_kernel(op_name);
|
||||
|
||||
MTL::Size grid_dims, group_dims;
|
||||
if (axis_size <= looped_limit) {
|
||||
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
|
||||
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
|
||||
size_t threadgroup_size = simd_size * simds_needed;
|
||||
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
|
||||
size_t n_threads = n_rows * threadgroup_size;
|
||||
grid_dims = MTL::Size(n_threads, 1, 1);
|
||||
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||
} else {
|
||||
size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
size_t n_threads = n_rows * threadgroup_size;
|
||||
grid_dims = MTL::Size(n_threads, 1, 1);
|
||||
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||
}
|
||||
|
||||
uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_input_array(x_in_gx ? gx : x, 0);
|
||||
compute_encoder.set_input_array(w, 1);
|
||||
compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
|
||||
compute_encoder.set_output_array(gx, 3);
|
||||
compute_encoder.set_output_array(gw_temp, 4);
|
||||
compute_encoder->setBytes(&eps_, sizeof(float), 5);
|
||||
compute_encoder->setBytes(&axis_size, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
if (gw.ndim() == 1 && gw.size() == axis_size) {
|
||||
ReductionPlan plan(
|
||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||
strided_reduce_general_dispatch(
|
||||
gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
|
||||
}
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fast
|
||||
|
@@ -68,18 +68,18 @@ void binary_op(
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
// - If a is donated it goes to the first output
|
||||
// - If b is donated it goes to the first output if a was not donated
|
||||
// otherwise it goes to the second output
|
||||
bool donate_a = a.data_shared_ptr() == nullptr;
|
||||
bool donate_b = b.data_shared_ptr() == nullptr;
|
||||
set_array_buffer(compute_encoder, donate_a ? outputs[0] : a, 0);
|
||||
set_array_buffer(
|
||||
compute_encoder, donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
|
||||
set_array_buffer(compute_encoder, outputs[0], 2);
|
||||
set_array_buffer(compute_encoder, outputs[1], 3);
|
||||
compute_encoder.set_input_array(donate_a ? outputs[0] : a, 0);
|
||||
compute_encoder.set_input_array(
|
||||
donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
|
||||
compute_encoder.set_output_array(outputs[0], 2);
|
||||
compute_encoder.set_output_array(outputs[1], 3);
|
||||
|
||||
if (bopt == BinaryOpType::General) {
|
||||
auto ndim = shape.size();
|
||||
@@ -167,13 +167,13 @@ void binary_op(
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
bool donate_a = a.data_shared_ptr() == nullptr;
|
||||
bool donate_b = b.data_shared_ptr() == nullptr;
|
||||
set_array_buffer(compute_encoder, donate_a ? out : a, 0);
|
||||
set_array_buffer(compute_encoder, donate_b ? out : b, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
compute_encoder.set_input_array(donate_a ? out : a, 0);
|
||||
compute_encoder.set_input_array(donate_b ? out : b, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
if (bopt == BinaryOpType::General) {
|
||||
auto ndim = shape.size();
|
||||
@@ -253,12 +253,12 @@ void ternary_op(
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, a, 0);
|
||||
set_array_buffer(compute_encoder, b, 1);
|
||||
set_array_buffer(compute_encoder, c, 2);
|
||||
set_array_buffer(compute_encoder, out, 3);
|
||||
compute_encoder.set_input_array(a, 0);
|
||||
compute_encoder.set_input_array(b, 1);
|
||||
compute_encoder.set_input_array(c, 2);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
|
||||
if (topt == TernaryOpType::General) {
|
||||
auto ndim = shape.size();
|
||||
@@ -339,11 +339,11 @@ void unary_op(
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(
|
||||
compute_encoder, in.data_shared_ptr() == nullptr ? out : in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder.set_input_array(
|
||||
in.data_shared_ptr() == nullptr ? out : in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
if (!contig) {
|
||||
compute_encoder->setBytes(in.shape().data(), in.ndim() * sizeof(int), 2);
|
||||
compute_encoder->setBytes(
|
||||
@@ -365,7 +365,7 @@ void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void arange_set_scalars(T start, T next, MTL::ComputeCommandEncoder* enc) {
|
||||
void arange_set_scalars(T start, T next, CommandEncoder& enc) {
|
||||
enc->setBytes(&start, sizeof(T), 0);
|
||||
T step = next - start;
|
||||
enc->setBytes(&step, sizeof(T), 1);
|
||||
@@ -384,7 +384,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size group_dims = MTL::Size(
|
||||
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
switch (out.dtype()) {
|
||||
@@ -427,7 +427,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
throw std::runtime_error("[Arange::eval_gpu] Does not support complex64");
|
||||
}
|
||||
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
@@ -487,7 +487,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// ArgReduce
|
||||
int simd_size = 32;
|
||||
int n_reads = 4;
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
{
|
||||
auto kernel = d.get_kernel(op_name + type_to_name(in));
|
||||
NS::UInteger thread_group_size = std::min(
|
||||
@@ -502,8 +502,8 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MTL::Size grid_dims = MTL::Size(n_threads, 1, 1);
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
if (ndim == 0) {
|
||||
// Pass place holders so metal doesn't complain
|
||||
int shape_ = 0;
|
||||
@@ -552,6 +552,9 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
flags.row_contiguous = false;
|
||||
flags.col_contiguous = false;
|
||||
flags.contiguous = false;
|
||||
auto& d = metal::device(stream().device);
|
||||
auto& compute_encoder = d.get_command_encoder(stream().index);
|
||||
auto concurrent_ctx = compute_encoder.start_concurrent();
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
|
||||
size_t data_offset = strides[axis_] * sizes[i];
|
||||
@@ -615,6 +618,10 @@ void Exp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "exp");
|
||||
}
|
||||
|
||||
void Expm1::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "expm1");
|
||||
}
|
||||
|
||||
void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto in = inputs[0];
|
||||
CopyType ctype;
|
||||
@@ -787,10 +794,10 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, keys, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder.set_input_array(keys, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder->setBytes(&odd, sizeof(bool), 2);
|
||||
compute_encoder->setBytes(&bytes_per_key, sizeof(size_t), 3);
|
||||
|
||||
@@ -822,7 +829,7 @@ void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (not is_integral(in.dtype())) {
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_op(inputs, out, "round");
|
||||
} else {
|
||||
// No-op integer types
|
||||
|
@@ -48,7 +48,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
<< bits_ << "_fast";
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
@@ -57,11 +57,11 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MTL::Size group_dims = MTL::Size(bd, 2, 1);
|
||||
MTL::Size grid_dims = MTL::Size(1, O / bo, B);
|
||||
|
||||
set_array_buffer(compute_encoder, w, 0);
|
||||
set_array_buffer(compute_encoder, scales, 1);
|
||||
set_array_buffer(compute_encoder, biases, 2);
|
||||
set_array_buffer(compute_encoder, x, 3);
|
||||
set_array_buffer(compute_encoder, out, 4);
|
||||
compute_encoder.set_input_array(w, 0);
|
||||
compute_encoder.set_input_array(scales, 1);
|
||||
compute_encoder.set_input_array(biases, 2);
|
||||
compute_encoder.set_input_array(x, 3);
|
||||
compute_encoder.set_output_array(out, 4);
|
||||
compute_encoder->setBytes(&D, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&O, sizeof(int), 6);
|
||||
|
||||
@@ -75,7 +75,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
<< bits_;
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
@@ -84,11 +84,11 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MTL::Size group_dims = MTL::Size(bd, 2, 1);
|
||||
MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
|
||||
|
||||
set_array_buffer(compute_encoder, w, 0);
|
||||
set_array_buffer(compute_encoder, scales, 1);
|
||||
set_array_buffer(compute_encoder, biases, 2);
|
||||
set_array_buffer(compute_encoder, x, 3);
|
||||
set_array_buffer(compute_encoder, out, 4);
|
||||
compute_encoder.set_input_array(w, 0);
|
||||
compute_encoder.set_input_array(scales, 1);
|
||||
compute_encoder.set_input_array(biases, 2);
|
||||
compute_encoder.set_input_array(x, 3);
|
||||
compute_encoder.set_output_array(out, 4);
|
||||
compute_encoder->setBytes(&D, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&O, sizeof(int), 6);
|
||||
|
||||
@@ -102,7 +102,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
<< bits_ << "_alN_" << std::boolalpha << ((O % 32) == 0);
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
@@ -114,11 +114,11 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, 1);
|
||||
|
||||
set_array_buffer(compute_encoder, x, 0);
|
||||
set_array_buffer(compute_encoder, w, 1);
|
||||
set_array_buffer(compute_encoder, scales, 2);
|
||||
set_array_buffer(compute_encoder, biases, 3);
|
||||
set_array_buffer(compute_encoder, out, 4);
|
||||
compute_encoder.set_input_array(x, 0);
|
||||
compute_encoder.set_input_array(w, 1);
|
||||
compute_encoder.set_input_array(scales, 2);
|
||||
compute_encoder.set_input_array(biases, 3);
|
||||
compute_encoder.set_output_array(out, 4);
|
||||
compute_encoder->setBytes(&B, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&O, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&D, sizeof(int), 7);
|
||||
@@ -133,20 +133,20 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
<< bits_;
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
int bo = std::min(32, O);
|
||||
int bo = 8;
|
||||
int bd = 32;
|
||||
MTL::Size group_dims = MTL::Size(bd, bo, 1);
|
||||
MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
|
||||
|
||||
set_array_buffer(compute_encoder, x, 0);
|
||||
set_array_buffer(compute_encoder, w, 1);
|
||||
set_array_buffer(compute_encoder, scales, 2);
|
||||
set_array_buffer(compute_encoder, biases, 3);
|
||||
set_array_buffer(compute_encoder, out, 4);
|
||||
compute_encoder.set_input_array(x, 0);
|
||||
compute_encoder.set_input_array(w, 1);
|
||||
compute_encoder.set_input_array(scales, 2);
|
||||
compute_encoder.set_input_array(biases, 3);
|
||||
compute_encoder.set_output_array(out, 4);
|
||||
compute_encoder->setBytes(&D, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&O, sizeof(int), 6);
|
||||
|
||||
@@ -160,7 +160,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
<< bits_;
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
@@ -179,11 +179,11 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
set_array_buffer(compute_encoder, x, 0);
|
||||
set_array_buffer(compute_encoder, w, 1);
|
||||
set_array_buffer(compute_encoder, scales, 2);
|
||||
set_array_buffer(compute_encoder, biases, 3);
|
||||
set_array_buffer(compute_encoder, out, 4);
|
||||
compute_encoder.set_input_array(x, 0);
|
||||
compute_encoder.set_input_array(w, 1);
|
||||
compute_encoder.set_input_array(scales, 2);
|
||||
compute_encoder.set_input_array(biases, 3);
|
||||
compute_encoder.set_output_array(out, 4);
|
||||
compute_encoder->setBytes(&B, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&O, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&D, sizeof(int), 7);
|
||||
|
@@ -4,10 +4,10 @@
|
||||
#include <cassert>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/reduce.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
@@ -18,8 +18,6 @@ namespace mlx::core {
|
||||
// Case wise reduce dispatch
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace {
|
||||
|
||||
inline auto safe_div(size_t n, size_t m) {
|
||||
return m == 0 ? 0 : (n + m - 1) / m;
|
||||
}
|
||||
@@ -37,7 +35,7 @@ void all_reduce_dispatch(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::string& op_name,
|
||||
MTL::ComputeCommandEncoder* compute_encoder,
|
||||
CommandEncoder& compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
Dtype out_dtype = out.dtype();
|
||||
@@ -73,8 +71,8 @@ void all_reduce_dispatch(
|
||||
|
||||
// Encode buffers and dispatch
|
||||
if (is_out_64b_int == false || n_thread_groups == 1) {
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
@@ -87,14 +85,14 @@ void all_reduce_dispatch(
|
||||
std::vector<array> intermediates = {intermediate};
|
||||
|
||||
// First dispatch
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, intermediate, 1);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(intermediate, 1);
|
||||
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
// Second pass to reduce intermediate reduction results written to DRAM
|
||||
set_array_buffer(compute_encoder, intermediate, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder.set_input_array(intermediate, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder->setBytes(&intermediate_size, sizeof(size_t), 2);
|
||||
|
||||
mod_in_size = (intermediate_size + n_reads - 1) / n_reads;
|
||||
@@ -125,7 +123,7 @@ void row_reduce_general_dispatch(
|
||||
const std::string& op_name,
|
||||
const ReductionPlan& plan,
|
||||
const std::vector<int>& axes,
|
||||
MTL::ComputeCommandEncoder* compute_encoder,
|
||||
CommandEncoder& compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
Dtype out_dtype = out.dtype();
|
||||
@@ -210,8 +208,8 @@ void row_reduce_general_dispatch(
|
||||
// Dispatch kernel
|
||||
if (!is_out_64b_int || non_row_reductions == 1) {
|
||||
// Set the arguments for the kernel
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
|
||||
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 4);
|
||||
@@ -232,8 +230,8 @@ void row_reduce_general_dispatch(
|
||||
std::vector<array> intermediates = {intermediate};
|
||||
|
||||
// Set the arguments for the kernel
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, intermediate, 1);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(intermediate, 1);
|
||||
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
|
||||
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 4);
|
||||
@@ -260,8 +258,8 @@ void row_reduce_general_dispatch(
|
||||
ndim = new_shape.size();
|
||||
|
||||
// Set the arguments for the kernel
|
||||
set_array_buffer(compute_encoder, intermediate, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder.set_input_array(intermediate, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
|
||||
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 4);
|
||||
@@ -303,7 +301,7 @@ void strided_reduce_general_dispatch(
|
||||
const std::string& op_name,
|
||||
const ReductionPlan& plan,
|
||||
const std::vector<int>& axes,
|
||||
MTL::ComputeCommandEncoder* compute_encoder,
|
||||
CommandEncoder& compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
Dtype out_dtype = out.dtype();
|
||||
@@ -351,8 +349,8 @@ void strided_reduce_general_dispatch(
|
||||
}
|
||||
|
||||
// Encode arrays
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
|
||||
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
|
||||
@@ -417,8 +415,8 @@ void strided_reduce_general_dispatch(
|
||||
|
||||
if (is_out_64b_int == false) {
|
||||
// Set the arguments for the kernel
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
|
||||
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
|
||||
@@ -452,8 +450,8 @@ void strided_reduce_general_dispatch(
|
||||
std::vector<array> intermediates = {intermediate};
|
||||
|
||||
// Set the arguments for the kernel
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, intermediate, 1);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(intermediate, 1);
|
||||
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
|
||||
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
|
||||
@@ -496,8 +494,8 @@ void strided_reduce_general_dispatch(
|
||||
"row_reduce_general_no_atomics_" + op_name +
|
||||
type_to_name(intermediate));
|
||||
compute_encoder->setComputePipelineState(row_reduce_kernel);
|
||||
set_array_buffer(compute_encoder, intermediate, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder.set_input_array(intermediate, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
|
||||
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 4);
|
||||
@@ -534,8 +532,6 @@ void strided_reduce_general_dispatch(
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
// Main reduce dispatch
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
@@ -577,7 +573,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Initialize output
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
{
|
||||
auto kernel = d.get_kernel("i" + op_name + type_to_name(out));
|
||||
size_t nthreads = out.size();
|
||||
@@ -588,7 +584,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, out, 0);
|
||||
compute_encoder.set_output_array(out, 0);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
|
41
mlx/backend/metal/reduce.h
Normal file
41
mlx/backend/metal/reduce.h
Normal file
@@ -0,0 +1,41 @@
|
||||
// Copyright @ 2023 - 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
using metal::CommandEncoder;
|
||||
|
||||
void all_reduce_dispatch(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::string& op_name,
|
||||
CommandEncoder& compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s);
|
||||
|
||||
void row_reduce_general_dispatch(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::string& op_name,
|
||||
const ReductionPlan& plan,
|
||||
const std::vector<int>& axes,
|
||||
CommandEncoder& compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s);
|
||||
|
||||
void strided_reduce_general_dispatch(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::string& op_name,
|
||||
const ReductionPlan& plan,
|
||||
const std::vector<int>& axes,
|
||||
CommandEncoder& compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s);
|
||||
|
||||
} // namespace mlx::core
|
@@ -63,14 +63,15 @@ void RoPE::eval_gpu(
|
||||
out_strides[2] = out.strides()[ndim - 1];
|
||||
|
||||
std::ostringstream kname;
|
||||
kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in);
|
||||
kname << "rope_" << (forward_ ? "" : "vjp_")
|
||||
<< (traditional_ ? "traditional_" : "") << type_to_name(in);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
||||
float base = std::log2(base_);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, donated ? out : in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder.set_input_array(donated ? out : in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 2);
|
||||
compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(&offset_, sizeof(int), 4);
|
||||
|
@@ -71,7 +71,7 @@ void sdpa_metal(
|
||||
|
||||
std::string kname_suffix = kname_suffix_tile_size + kname_suffix_nsimdgroups;
|
||||
kname_partials << kname_suffix;
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname_partials.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
@@ -87,15 +87,15 @@ void sdpa_metal(
|
||||
MLXScaledDotProductAttentionParams params{
|
||||
query_sequence_length, n_q_heads, n_kv_heads, n_tiles, alpha};
|
||||
|
||||
set_array_buffer(compute_encoder, q, 0);
|
||||
set_array_buffer(compute_encoder, k, 1);
|
||||
set_array_buffer(compute_encoder, v, 2);
|
||||
compute_encoder.set_input_array(q, 0);
|
||||
compute_encoder.set_input_array(k, 1);
|
||||
compute_encoder.set_input_array(v, 2);
|
||||
compute_encoder->setBytes(&KV_sequence_length, sizeof(KV_sequence_length), 3);
|
||||
compute_encoder->setBytes(
|
||||
¶ms, sizeof(MLXScaledDotProductAttentionParams), 4);
|
||||
set_array_buffer(compute_encoder, o_partial, 5);
|
||||
set_array_buffer(compute_encoder, p_lse, 6);
|
||||
set_array_buffer(compute_encoder, p_rowmaxes, 7);
|
||||
compute_encoder.set_input_array(o_partial, 5);
|
||||
compute_encoder.set_input_array(p_lse, 6);
|
||||
compute_encoder.set_input_array(p_rowmaxes, 7);
|
||||
|
||||
constexpr const uint tgroupMemorySize = 32768;
|
||||
compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0);
|
||||
@@ -104,12 +104,12 @@ void sdpa_metal(
|
||||
{
|
||||
auto kernel_accum = d.get_kernel(kname_reduce.str());
|
||||
compute_encoder->setComputePipelineState(kernel_accum);
|
||||
set_array_buffer(compute_encoder, o_partial, 0);
|
||||
set_array_buffer(compute_encoder, p_lse, 1);
|
||||
set_array_buffer(compute_encoder, p_rowmaxes, 2);
|
||||
compute_encoder.set_input_array(o_partial, 0);
|
||||
compute_encoder.set_input_array(p_lse, 1);
|
||||
compute_encoder.set_input_array(p_rowmaxes, 2);
|
||||
compute_encoder->setBytes(
|
||||
¶ms, sizeof(MLXScaledDotProductAttentionParams), 3);
|
||||
set_array_buffer(compute_encoder, out, 4);
|
||||
compute_encoder.set_output_array(out, 4);
|
||||
|
||||
MTL::Size grid_dims_reduce = MTL::Size(heads, 1, batch);
|
||||
MTL::Size group_dims_reduce = MTL::Size(128, 1, 1);
|
||||
@@ -127,7 +127,7 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out) {
|
||||
assert(inputs.size() >= 3);
|
||||
if (!is_floating_point(out.dtype())) {
|
||||
if (!issubdtype(out.dtype(), floating)) {
|
||||
throw std::runtime_error(
|
||||
"[ScaledDotProductAttention] Does not yet support non-floating point types.");
|
||||
}
|
||||
|
@@ -52,10 +52,10 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
kname << type_to_name(in) << "_" << type_to_name(out);
|
||||
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
size_t size = in.shape(axis_);
|
||||
compute_encoder->setBytes(&size, sizeof(size_t), 2);
|
||||
|
||||
@@ -101,10 +101,10 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
kname << type_to_name(in) << "_" << type_to_name(out);
|
||||
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
size_t size = in.shape(axis_);
|
||||
size_t stride = in.strides()[axis_];
|
||||
compute_encoder->setBytes(&size, sizeof(size_t), 2);
|
||||
|
@@ -12,7 +12,7 @@ namespace mlx::core {
|
||||
|
||||
void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
if (!is_floating_point(out.dtype())) {
|
||||
if (!issubdtype(out.dtype(), floating)) {
|
||||
throw std::runtime_error(
|
||||
"[softmax] Does not support non-floating point types.");
|
||||
}
|
||||
@@ -21,7 +21,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
std::vector<array> copies;
|
||||
auto check_input = [&copies, &s](const array& x) {
|
||||
auto check_input = [&copies, &s](const array& x) -> const array& {
|
||||
bool no_copy = x.strides()[x.ndim() - 1] == 1;
|
||||
if (x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
@@ -30,10 +30,9 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
copies.push_back(x_copy);
|
||||
return x_copy;
|
||||
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
|
||||
copy_gpu(x, copies.back(), CopyType::General, s);
|
||||
return copies.back();
|
||||
}
|
||||
};
|
||||
const array& in = check_input(inputs[0]);
|
||||
@@ -57,8 +56,11 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (axis_size > looped_limit) {
|
||||
op_name += "looped_";
|
||||
}
|
||||
if (in.dtype() != float32 && precise_) {
|
||||
op_name += "precise_";
|
||||
}
|
||||
op_name += type_to_name(out);
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
{
|
||||
auto kernel = d.get_kernel(op_name);
|
||||
|
||||
@@ -79,13 +81,10 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(
|
||||
compute_encoder, in.data_shared_ptr() == nullptr ? out : in, 0);
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder.set_input_array(
|
||||
in.data_shared_ptr() == nullptr ? out : in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder->setBytes(&axis_size, sizeof(int), 2);
|
||||
compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 0);
|
||||
compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 1);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
@@ -57,13 +57,13 @@ void single_block_sort(
|
||||
}
|
||||
|
||||
// Prepare command encoder
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Set inputs
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 3);
|
||||
|
||||
@@ -102,6 +102,11 @@ void multi_block_sort(
|
||||
|
||||
int nc_dim = nc_shape.size();
|
||||
|
||||
if (nc_dim == 0) {
|
||||
nc_shape = {0};
|
||||
nc_str = {1};
|
||||
}
|
||||
|
||||
int size_sorted_axis = in.shape(axis);
|
||||
int stride_sorted_axis = in.strides()[axis];
|
||||
|
||||
@@ -126,7 +131,7 @@ void multi_block_sort(
|
||||
dev_vals_0, dev_vals_1, dev_idxs_0, dev_idxs_1, block_partitions};
|
||||
|
||||
// Prepare command encoder
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
||||
// Do blockwise sort
|
||||
{
|
||||
@@ -137,14 +142,15 @@ void multi_block_sort(
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, dev_vals_0, 1);
|
||||
set_array_buffer(compute_encoder, dev_idxs_0, 2);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(dev_vals_0, 1);
|
||||
compute_encoder.set_output_array(dev_idxs_0, 2);
|
||||
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
|
||||
compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 6);
|
||||
compute_encoder->setBytes(nc_str.data(), nc_dim * sizeof(size_t), 7);
|
||||
compute_encoder->setBytes(
|
||||
nc_shape.data(), nc_shape.size() * sizeof(int), 6);
|
||||
compute_encoder->setBytes(nc_str.data(), nc_str.size() * sizeof(size_t), 7);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(bn, 1, 1);
|
||||
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
|
||||
@@ -158,7 +164,8 @@ void multi_block_sort(
|
||||
array dev_idxs_in = dev_idxs_0;
|
||||
array dev_vals_out = dev_vals_1;
|
||||
array dev_idxs_out = dev_idxs_1;
|
||||
for (int merge_tiles = 2; merge_tiles <= n_blocks; merge_tiles *= 2) {
|
||||
|
||||
for (int merge_tiles = 2; (merge_tiles / 2) < n_blocks; merge_tiles *= 2) {
|
||||
dev_vals_in = ping ? dev_vals_1 : dev_vals_0;
|
||||
dev_idxs_in = ping ? dev_idxs_1 : dev_idxs_0;
|
||||
dev_vals_out = ping ? dev_vals_0 : dev_vals_1;
|
||||
@@ -174,9 +181,9 @@ void multi_block_sort(
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
set_array_buffer(compute_encoder, block_partitions, 0);
|
||||
set_array_buffer(compute_encoder, dev_vals_in, 1);
|
||||
set_array_buffer(compute_encoder, dev_idxs_in, 2);
|
||||
compute_encoder.set_output_array(block_partitions, 0);
|
||||
compute_encoder.set_input_array(dev_vals_in, 1);
|
||||
compute_encoder.set_input_array(dev_idxs_in, 2);
|
||||
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&merge_tiles, sizeof(int), 4);
|
||||
|
||||
@@ -195,11 +202,11 @@ void multi_block_sort(
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
set_array_buffer(compute_encoder, block_partitions, 0);
|
||||
set_array_buffer(compute_encoder, dev_vals_in, 1);
|
||||
set_array_buffer(compute_encoder, dev_idxs_in, 2);
|
||||
set_array_buffer(compute_encoder, dev_vals_out, 3);
|
||||
set_array_buffer(compute_encoder, dev_idxs_out, 4);
|
||||
compute_encoder.set_input_array(block_partitions, 0);
|
||||
compute_encoder.set_input_array(dev_vals_in, 1);
|
||||
compute_encoder.set_input_array(dev_idxs_in, 2);
|
||||
compute_encoder.set_output_array(dev_vals_out, 3);
|
||||
compute_encoder.set_output_array(dev_idxs_out, 4);
|
||||
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&merge_tiles, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&n_blocks, sizeof(int), 7);
|
||||
|
@@ -1,37 +1,20 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
inline void
|
||||
set_array_buffer(MTL::ComputeCommandEncoder* enc, const array& a, int idx) {
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto offset = a.data<char>() -
|
||||
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||
enc->setBuffer(a_buf, offset, idx);
|
||||
}
|
||||
|
||||
inline void set_array_buffer(
|
||||
MTL::ComputeCommandEncoder* enc,
|
||||
const array& a,
|
||||
int64_t offset,
|
||||
int idx) {
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto base_offset = a.data<char>() -
|
||||
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||
base_offset += offset;
|
||||
enc->setBuffer(a_buf, base_offset, idx);
|
||||
}
|
||||
using metal::CommandEncoder;
|
||||
|
||||
template <typename T>
|
||||
inline void set_vector_bytes(
|
||||
MTL::ComputeCommandEncoder* enc,
|
||||
CommandEncoder& enc,
|
||||
const std::vector<T>& vec,
|
||||
size_t nelems,
|
||||
int idx) {
|
||||
@@ -39,10 +22,8 @@ inline void set_vector_bytes(
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void set_vector_bytes(
|
||||
MTL::ComputeCommandEncoder* enc,
|
||||
const std::vector<T>& vec,
|
||||
int idx) {
|
||||
inline void
|
||||
set_vector_bytes(CommandEncoder& enc, const std::vector<T>& vec, int idx) {
|
||||
return set_vector_bytes(enc, vec, vec.size(), idx);
|
||||
}
|
||||
|
||||
@@ -123,6 +104,32 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
|
||||
return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
|
||||
}
|
||||
|
||||
inline NS::String* make_string(std::ostringstream& os) {
|
||||
std::string string = os.str();
|
||||
return NS::String::string(string.c_str(), NS::UTF8StringEncoding);
|
||||
}
|
||||
|
||||
inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) {
|
||||
#ifdef MLX_METAL_DEBUG
|
||||
std::ostringstream label;
|
||||
label << "Stream " << index;
|
||||
queue->setLabel(make_string(label));
|
||||
#endif
|
||||
}
|
||||
|
||||
inline void debug_set_primitive_buffer_label(
|
||||
MTL::CommandBuffer* command_buffer,
|
||||
Primitive& primitive) {
|
||||
#ifdef MLX_METAL_DEBUG
|
||||
std::ostringstream label;
|
||||
if (auto cbuf_label = command_buffer->label(); cbuf_label) {
|
||||
label << cbuf_label->utf8String();
|
||||
}
|
||||
primitive.print(label);
|
||||
command_buffer->setLabel(make_string(label));
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -3,6 +3,7 @@
|
||||
#include <stdexcept>
|
||||
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
@@ -39,5 +40,9 @@ size_t set_memory_limit(size_t, bool) {
|
||||
size_t set_cache_limit(size_t) {
|
||||
return 0;
|
||||
}
|
||||
bool start_capture(std::string path) {
|
||||
return false;
|
||||
}
|
||||
void stop_capture() {}
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@@ -49,6 +49,7 @@ NO_GPU(Equal)
|
||||
NO_GPU(Erf)
|
||||
NO_GPU(ErfInv)
|
||||
NO_GPU(Exp)
|
||||
NO_GPU(Expm1)
|
||||
NO_GPU(FFT)
|
||||
NO_GPU(Floor)
|
||||
NO_GPU(Full)
|
||||
@@ -103,7 +104,9 @@ NO_GPU(Inverse)
|
||||
|
||||
namespace fast {
|
||||
NO_GPU_MULTI(LayerNorm)
|
||||
NO_GPU_MULTI(LayerNormVJP)
|
||||
NO_GPU_MULTI(RMSNorm)
|
||||
NO_GPU_MULTI(RMSNormVJP)
|
||||
NO_GPU_MULTI(RoPE)
|
||||
NO_GPU(ScaledDotProductAttention)
|
||||
} // namespace fast
|
||||
|
103
mlx/compile.cpp
103
mlx/compile.cpp
@@ -32,7 +32,7 @@ bool is_unary(const Primitive& p) {
|
||||
typeid(p) == typeid(Sign) || typeid(p) == typeid(Sin) ||
|
||||
typeid(p) == typeid(Sinh) || typeid(p) == typeid(Square) ||
|
||||
typeid(p) == typeid(Sqrt) || typeid(p) == typeid(Tan) ||
|
||||
typeid(p) == typeid(Tanh));
|
||||
typeid(p) == typeid(Tanh) || typeid(p) == typeid(Expm1));
|
||||
}
|
||||
|
||||
bool is_binary(const Primitive& p) {
|
||||
@@ -162,44 +162,51 @@ CompileMode& compile_mode() {
|
||||
return compile_mode_;
|
||||
}
|
||||
|
||||
using CompileFn = std::function<std::vector<array>(const std::vector<array>&)>;
|
||||
using ParentsMap =
|
||||
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
|
||||
|
||||
// Helper like below but only merges the two provided arrays. If the src has
|
||||
// siblings then these won't be merged to the dst.
|
||||
void merge_one(array& dst, array& src, ParentsMap& parents_map) {
|
||||
auto src_parents = parents_map.find(src.id());
|
||||
if (src_parents == parents_map.end()) {
|
||||
return;
|
||||
}
|
||||
auto& pairs = parents_map[dst.id()];
|
||||
for (auto& parent : src_parents->second) {
|
||||
parent.first.inputs()[parent.second] = dst;
|
||||
pairs.push_back(parent);
|
||||
}
|
||||
// Remove the source from the map to avoid fusing with it again
|
||||
parents_map.erase(src_parents);
|
||||
};
|
||||
|
||||
// Helper that merges two arrays in the graph by setting the parents of the
|
||||
// source to point to the destination
|
||||
// source to point to the destination. The arrays are assumed to be coming from
|
||||
// equivalent primitives so their siblings are merged as well.
|
||||
void merge(array& dst, array& src, ParentsMap& parents_map) {
|
||||
// Canonicalize the order of the primitives outputs
|
||||
auto sources = src.outputs();
|
||||
auto dests = dst.outputs();
|
||||
// For each src parent, point it to the corresponding dst
|
||||
for (int i = 0; i < sources.size(); ++i) {
|
||||
auto src_parents = parents_map.find(sources[i].id());
|
||||
if (src_parents == parents_map.end()) {
|
||||
continue;
|
||||
}
|
||||
auto& pairs = parents_map[dests[i].id()];
|
||||
for (auto& parent : src_parents->second) {
|
||||
parent.first.inputs()[parent.second] = dests[i];
|
||||
pairs.push_back(parent);
|
||||
}
|
||||
// Remove the source from the map to avoid fusing with it again
|
||||
parents_map.erase(src_parents);
|
||||
merge_one(dests[i], sources[i], parents_map);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename... U>
|
||||
size_t getAddress(std::function<T(U...)> f) {
|
||||
typedef T(fnType)(U...);
|
||||
fnType** fnPointer = f.template target<fnType*>();
|
||||
if (fnPointer == nullptr) {
|
||||
std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
|
||||
using FunType = T (*)(U...);
|
||||
const FunType* fun_ptr = fun.template target<FunType>();
|
||||
if (fun_ptr == nullptr) {
|
||||
throw std::invalid_argument(
|
||||
"[compile] Cannot compile a non-addressable function.");
|
||||
}
|
||||
return (size_t)*fnPointer;
|
||||
return reinterpret_cast<std::uintptr_t>(*fun_ptr);
|
||||
}
|
||||
|
||||
struct CompilerCache {
|
||||
class CompilerCache {
|
||||
public:
|
||||
struct CacheEntry {
|
||||
std::vector<array> inputs;
|
||||
std::vector<array> outputs;
|
||||
@@ -211,20 +218,20 @@ struct CompilerCache {
|
||||
// Returns a reference to a CacheEntry which can be updated
|
||||
// by the caller to avoid copying large tapes / inputs / outputs
|
||||
CacheEntry& find(
|
||||
size_t fun_id,
|
||||
std::uintptr_t fun_id,
|
||||
const std::vector<array>& inputs,
|
||||
bool shapeless,
|
||||
const std::vector<uint64_t>& constants) {
|
||||
// Try to find the entry
|
||||
auto [entry_it, inserted] = cache_.insert({fun_id, {}});
|
||||
auto& entries = entry_it->second;
|
||||
auto is_match = [shapeless](
|
||||
const std::vector<array>& in1,
|
||||
const std::vector<array>& in2) {
|
||||
// Find the cache entries for |fun_id|.
|
||||
std::vector<CacheEntry>& entries = cache_[fun_id];
|
||||
// Compare if 2 arrays have same shape and dtype.
|
||||
auto has_same_shape_and_dtype = [shapeless](
|
||||
const std::vector<array>& in1,
|
||||
const std::vector<array>& in2) {
|
||||
if (in1.size() != in2.size()) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < in1.size(); ++i) {
|
||||
for (size_t i = 0; i < in1.size(); ++i) {
|
||||
if (in1[i].ndim() != in2[i].ndim()) {
|
||||
return false;
|
||||
}
|
||||
@@ -237,14 +244,14 @@ struct CompilerCache {
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
// Loop over entries and check inputs match i.e. shapes and types must be
|
||||
// equal. Note this could get really slow if one compiles the same
|
||||
// function with many different shapes. May want to store entries in a
|
||||
// more easily searchable structure.
|
||||
for (auto& entry : entries) {
|
||||
for (CacheEntry& entry : entries) {
|
||||
// Check the inputs match and return if so
|
||||
if (is_match(inputs, entry.inputs) && constants == entry.constants) {
|
||||
if (has_same_shape_and_dtype(inputs, entry.inputs) &&
|
||||
constants == entry.constants) {
|
||||
return entry;
|
||||
}
|
||||
}
|
||||
@@ -253,7 +260,7 @@ struct CompilerCache {
|
||||
return entries.back();
|
||||
};
|
||||
|
||||
void erase(size_t fun_id) {
|
||||
void erase(std::uintptr_t fun_id) {
|
||||
cache_.erase(fun_id);
|
||||
}
|
||||
|
||||
@@ -263,8 +270,9 @@ struct CompilerCache {
|
||||
// initialized before the compiler cache
|
||||
allocator::allocator();
|
||||
}
|
||||
|
||||
friend CompilerCache& compiler_cache();
|
||||
std::unordered_map<size_t, std::vector<CacheEntry>> cache_;
|
||||
std::unordered_map<std::uintptr_t, std::vector<CacheEntry>> cache_;
|
||||
};
|
||||
|
||||
CompilerCache& compiler_cache() {
|
||||
@@ -523,9 +531,14 @@ void compile_fuse(
|
||||
// - Collect inputs to the new compiled primitive
|
||||
// - Add fusable primitives to a tape in the correct order
|
||||
|
||||
std::function<void(const array&, int, const Stream&)> recurse;
|
||||
std::function<void(
|
||||
const array&, int, const Stream&, const std::vector<int>&)>
|
||||
recurse;
|
||||
std::unordered_set<uintptr_t> cache;
|
||||
recurse = [&](const array& a, int depth, const Stream& s) {
|
||||
recurse = [&](const array& a,
|
||||
int depth,
|
||||
const Stream& s,
|
||||
const std::vector<int>& shape) {
|
||||
if (cache.find(a.id()) != cache.end()) {
|
||||
return;
|
||||
}
|
||||
@@ -535,8 +548,10 @@ void compile_fuse(
|
||||
// - Constant input
|
||||
// - Stream mismatch
|
||||
// - Non fusable primitive
|
||||
// - Is global output but has a different shape
|
||||
if (depth >= max_compile_depth || !a.has_primitive() ||
|
||||
a.primitive().stream() != s || !is_fusable(a.primitive())) {
|
||||
a.primitive().stream() != s || !is_fusable(a.primitive()) ||
|
||||
(output_map.find(a.id()) != output_map.end() && a.shape() != shape)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -563,13 +578,13 @@ void compile_fuse(
|
||||
cache.insert({a.id()});
|
||||
|
||||
for (auto& in : a.inputs()) {
|
||||
recurse(in, depth + 1, s);
|
||||
recurse(in, depth + 1, s, shape);
|
||||
}
|
||||
};
|
||||
|
||||
if (arr.has_primitive()) {
|
||||
Stream s = arr.primitive().stream();
|
||||
recurse(arr, 0, s);
|
||||
recurse(arr, 0, s, arr.shape());
|
||||
}
|
||||
|
||||
// Not worth fusing a single primitive
|
||||
@@ -633,6 +648,10 @@ void compile_fuse(
|
||||
std::vector<std::vector<int>> shapes;
|
||||
std::vector<Dtype> types;
|
||||
for (auto& o : old_outputs) {
|
||||
if (o.shape() != old_outputs.back().shape()) {
|
||||
throw std::runtime_error(
|
||||
"[compile] Compilation failed. Tried to fuse operations with different output shapes");
|
||||
}
|
||||
shapes.push_back(o.shape());
|
||||
types.push_back(o.dtype());
|
||||
}
|
||||
@@ -675,7 +694,7 @@ void compile_fuse(
|
||||
// - Update outputs parents to point to compiled outputs
|
||||
// - Update any overall graph outputs to be compiled outputs
|
||||
for (int o = 0; o < old_outputs.size(); ++o) {
|
||||
merge(compiled_outputs[o], old_outputs[o], parents_map);
|
||||
merge_one(compiled_outputs[o], old_outputs[o], parents_map);
|
||||
if (auto it = output_map.find(old_outputs[o].id());
|
||||
it != output_map.end()) {
|
||||
it->second = compiled_outputs[o];
|
||||
@@ -774,7 +793,7 @@ void compile_validate_shapeless(const std::vector<array>& tape) {
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
size_t fun_id,
|
||||
std::uintptr_t fun_id,
|
||||
bool shapeless /* = false */,
|
||||
std::vector<uint64_t> constants /* = {} */) {
|
||||
if (compile_mode() == CompileMode::disabled ||
|
||||
@@ -833,7 +852,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
};
|
||||
}
|
||||
|
||||
void compile_erase(size_t fun_id) {
|
||||
void compile_erase(std::uintptr_t fun_id) {
|
||||
detail::compiler_cache().erase(fun_id);
|
||||
}
|
||||
|
||||
@@ -845,7 +864,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
if (detail::compile_mode() == CompileMode::disabled) {
|
||||
return fun;
|
||||
}
|
||||
auto fun_id = detail::getAddress(fun);
|
||||
auto fun_id = detail::get_function_address(fun);
|
||||
return detail::compile(fun, fun_id, shapeless);
|
||||
}
|
||||
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cstdint>
|
||||
#include <sstream>
|
||||
@@ -12,6 +12,7 @@ namespace mlx::core {
|
||||
namespace {
|
||||
|
||||
constexpr int num_types = 13;
|
||||
constexpr int num_cats = 8;
|
||||
|
||||
constexpr Dtype::Kind type_kinds[num_types] = {
|
||||
Dtype::Kind::b, // bool_,
|
||||
@@ -49,18 +50,37 @@ constexpr Dtype type_rules[num_types][num_types] = {
|
||||
{complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64}, // complex64
|
||||
};
|
||||
|
||||
|
||||
constexpr bool subcategory_to_category[num_cats][num_cats] = {
|
||||
// complexfloating floating inexact signedinteger unsignedinteger integer number generic
|
||||
{true, false, true, false, false, false, true, true}, // complexfloating
|
||||
{false, true, true, false, false, false, true, true}, // floating
|
||||
{false, false, true, false, false, false, true, true}, // inexact
|
||||
{false, false, false, true, false, true, true, true}, // signedinteger
|
||||
{false, false, false, false, true, true, true, true}, // unsignedinteger
|
||||
{false, false, false, false, false, true, true, true}, // integer
|
||||
{false, false, false, false, false, false, true, true}, // number
|
||||
{false, false, false, false, false, false, false, true}, // generic
|
||||
};
|
||||
|
||||
constexpr Dtype::Category type_to_category[num_types] = {
|
||||
Dtype::Category::generic, // bool_,
|
||||
Dtype::Category::unsignedinteger, // uint8,
|
||||
Dtype::Category::unsignedinteger, // uint16,
|
||||
Dtype::Category::unsignedinteger, // uint32,
|
||||
Dtype::Category::unsignedinteger, // uint64,
|
||||
Dtype::Category::signedinteger, // int8,
|
||||
Dtype::Category::signedinteger, // int16,
|
||||
Dtype::Category::signedinteger, // int32,
|
||||
Dtype::Category::signedinteger, // int64,
|
||||
Dtype::Category::floating, // float16,
|
||||
Dtype::Category::floating, // float32,
|
||||
Dtype::Category::floating, // bfloat16,
|
||||
Dtype::Category::complexfloating, // complex64,
|
||||
};
|
||||
|
||||
// clang-format on
|
||||
|
||||
inline bool is_big_endian() {
|
||||
union ByteOrder {
|
||||
int32_t i;
|
||||
uint8_t c[4];
|
||||
};
|
||||
ByteOrder b = {0x01234567};
|
||||
|
||||
return b.c[0] == 0x01;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Dtype promote_types(const Dtype& t1, const Dtype& t2) {
|
||||
@@ -141,6 +161,23 @@ TypeToDtype<complex64_t>::operator Dtype() {
|
||||
return complex64;
|
||||
}
|
||||
|
||||
bool issubdtype(const Dtype& a, const Dtype& b) {
|
||||
return a == b;
|
||||
}
|
||||
|
||||
bool issubdtype(const Dtype::Category& cat, const Dtype& type) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool issubdtype(const Dtype& type, const Dtype::Category& cat) {
|
||||
return issubdtype(type_to_category[static_cast<uint32_t>(type.val)], cat);
|
||||
}
|
||||
|
||||
bool issubdtype(const Dtype::Category& a, const Dtype::Category& b) {
|
||||
return subcategory_to_category[static_cast<uint32_t>(a)]
|
||||
[static_cast<uint32_t>(b)];
|
||||
}
|
||||
|
||||
// Array protocol typestring for Dtype
|
||||
std::string dtype_to_array_protocol(const Dtype& t) {
|
||||
std::ostringstream r;
|
||||
@@ -153,9 +190,9 @@ std::string dtype_to_array_protocol(const Dtype& t) {
|
||||
}
|
||||
|
||||
// Dtype from array protocol type string
|
||||
Dtype dtype_from_array_protocol(const std::string& t) {
|
||||
Dtype dtype_from_array_protocol(std::string_view t) {
|
||||
if (t.length() == 2 || t.length() == 3) {
|
||||
std::string r = t.length() == 3 ? t.substr(1, 2) : t;
|
||||
std::string_view r = t.length() == 3 ? t.substr(1, 2) : t;
|
||||
|
||||
if (r == "V2") {
|
||||
return bfloat16;
|
||||
@@ -201,7 +238,7 @@ Dtype dtype_from_array_protocol(const std::string& t) {
|
||||
}
|
||||
|
||||
throw std::invalid_argument(
|
||||
"[from_str] Invalid array protocol type-string: " + t);
|
||||
"[from_str] Invalid array protocol type-string: " + std::string(t));
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
48
mlx/dtype.h
48
mlx/dtype.h
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -38,6 +38,17 @@ struct Dtype {
|
||||
V, /* void - used for brain float */
|
||||
};
|
||||
|
||||
enum class Category {
|
||||
complexfloating,
|
||||
floating,
|
||||
inexact,
|
||||
signedinteger,
|
||||
unsignedinteger,
|
||||
integer,
|
||||
number,
|
||||
generic
|
||||
};
|
||||
|
||||
Val val;
|
||||
const uint8_t size;
|
||||
constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size){};
|
||||
@@ -63,6 +74,22 @@ inline constexpr Dtype float32{Dtype::Val::float32, sizeof(float)};
|
||||
inline constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)};
|
||||
inline constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)};
|
||||
|
||||
inline constexpr Dtype::Category complexfloating =
|
||||
Dtype::Category::complexfloating;
|
||||
inline constexpr Dtype::Category floating = Dtype::Category::floating;
|
||||
inline constexpr Dtype::Category inexact = Dtype::Category::inexact;
|
||||
inline constexpr Dtype::Category signedinteger = Dtype::Category::signedinteger;
|
||||
inline constexpr Dtype::Category unsignedinteger =
|
||||
Dtype::Category::unsignedinteger;
|
||||
inline constexpr Dtype::Category integer = Dtype::Category::integer;
|
||||
inline constexpr Dtype::Category number = Dtype::Category::number;
|
||||
inline constexpr Dtype::Category generic = Dtype::Category::generic;
|
||||
|
||||
bool issubdtype(const Dtype& a, const Dtype& b);
|
||||
bool issubdtype(const Dtype::Category& a, const Dtype& b);
|
||||
bool issubdtype(const Dtype& a, const Dtype::Category& b);
|
||||
bool issubdtype(const Dtype::Category& a, const Dtype::Category& b);
|
||||
|
||||
Dtype promote_types(const Dtype& t1, const Dtype& t2);
|
||||
|
||||
inline uint8_t size_of(const Dtype& t) {
|
||||
@@ -71,23 +98,6 @@ inline uint8_t size_of(const Dtype& t) {
|
||||
|
||||
Dtype::Kind kindof(const Dtype& t);
|
||||
|
||||
inline bool is_unsigned(const Dtype& t) {
|
||||
return kindof(t) == Dtype::Kind::u || kindof(t) == Dtype::Kind::b;
|
||||
}
|
||||
|
||||
inline bool is_floating_point(const Dtype& t) {
|
||||
return kindof(t) == Dtype::Kind::f || kindof(t) == Dtype::Kind::V ||
|
||||
kindof(t) == Dtype::Kind::c;
|
||||
}
|
||||
|
||||
inline bool is_complex(const Dtype& t) {
|
||||
return kindof(t) == Dtype::Kind::c;
|
||||
}
|
||||
|
||||
inline bool is_integral(const Dtype& t) {
|
||||
return !(is_floating_point(t));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct TypeToDtype {
|
||||
operator Dtype();
|
||||
@@ -96,6 +106,6 @@ struct TypeToDtype {
|
||||
// Array protocol typestring for Dtype
|
||||
std::string dtype_to_array_protocol(const Dtype& t);
|
||||
// Dtype from array protocol type string
|
||||
Dtype dtype_from_array_protocol(const std::string& t);
|
||||
Dtype dtype_from_array_protocol(std::string_view t);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
239
mlx/fast.cpp
239
mlx/fast.cpp
@@ -1,5 +1,8 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/ops.h"
|
||||
@@ -64,7 +67,7 @@ array rms_norm(
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
auto out_type = result_type(x, weight);
|
||||
if (!is_floating_point(out_type) || is_complex(out_type)) {
|
||||
if (!issubdtype(out_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rms_norm] Received unsupported type " << out_type << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
@@ -94,11 +97,69 @@ array rms_norm(
|
||||
return fallback({x, weight})[0];
|
||||
}
|
||||
|
||||
std::vector<array> RMSNorm::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
assert(primals.size() == 2);
|
||||
assert(outputs.size() == 1);
|
||||
assert(cotangents.size() == 1);
|
||||
|
||||
auto s = stream();
|
||||
auto fallback = [eps = eps_, s](const std::vector<array>& inputs) {
|
||||
auto& x = inputs[0];
|
||||
auto& w = inputs[1];
|
||||
auto& g = inputs[2];
|
||||
|
||||
std::vector<array> vjps;
|
||||
|
||||
auto n = rsqrt(
|
||||
add(mean(square(x, s), /* axis= */ -1, /* keepdims= */ true, s),
|
||||
array(eps, x.dtype()),
|
||||
s),
|
||||
s);
|
||||
auto n3 = power(n, array(3, x.dtype()), s);
|
||||
|
||||
// df/dx
|
||||
auto gw = multiply(g, w, s);
|
||||
auto t = mean(multiply(gw, x, s), /* axis= */ -1, /* keepdims= */ true, s);
|
||||
t = multiply(multiply(x, t, s), n3, s);
|
||||
vjps.push_back(subtract(multiply(gw, n, s), t, s));
|
||||
|
||||
// df/dw
|
||||
std::vector<int> axes(g.ndim() - 1);
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
vjps.push_back(
|
||||
sum(multiply(g, multiply(x, n, s), s), axes, /* keepdims= */ false, s));
|
||||
|
||||
return vjps;
|
||||
};
|
||||
|
||||
auto vjps = array::make_arrays(
|
||||
{primals[0].shape(), primals[1].shape()},
|
||||
{primals[0].dtype(), primals[1].dtype()},
|
||||
std::make_shared<RMSNormVJP>(s, fallback, eps_),
|
||||
{primals[0], primals[1], cotangents[0]});
|
||||
|
||||
std::vector<array> returned_vjps;
|
||||
for (auto& arg : argnums) {
|
||||
returned_vjps.push_back(std::move(vjps[arg]));
|
||||
}
|
||||
|
||||
return returned_vjps;
|
||||
}
|
||||
|
||||
bool RMSNorm::is_equivalent(const Primitive& other) const {
|
||||
const RMSNorm& a_other = static_cast<const RMSNorm&>(other);
|
||||
return eps_ == a_other.eps_;
|
||||
}
|
||||
|
||||
bool RMSNormVJP::is_equivalent(const Primitive& other) const {
|
||||
const RMSNormVJP& a_other = static_cast<const RMSNormVJP&>(other);
|
||||
return eps_ == a_other.eps_;
|
||||
}
|
||||
|
||||
array layer_norm(
|
||||
const array& x,
|
||||
const std::optional<array>& weight,
|
||||
@@ -128,7 +189,7 @@ array layer_norm(
|
||||
? ((bias.has_value()) ? result_type(x, *weight, *bias)
|
||||
: result_type(x, *weight))
|
||||
: x.dtype();
|
||||
if (!is_floating_point(out_type) || is_complex(out_type)) {
|
||||
if (!issubdtype(out_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[layer_norm] Received unsupported type " << out_type << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
@@ -176,11 +237,90 @@ array layer_norm(
|
||||
return fallback({x, passed_weight, passed_bias})[0];
|
||||
}
|
||||
|
||||
std::vector<array> LayerNorm::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
assert(primals.size() == 3);
|
||||
assert(outputs.size() == 1);
|
||||
assert(cotangents.size() == 1);
|
||||
|
||||
auto s = stream();
|
||||
auto fallback = [eps = eps_, s](const std::vector<array>& inputs) {
|
||||
auto& x = inputs[0];
|
||||
auto& w = inputs[1];
|
||||
auto& b = inputs[2];
|
||||
auto& g = inputs[3];
|
||||
|
||||
std::vector<array> vjps;
|
||||
|
||||
auto norm = number_of_elements(x, {-1}, true, x.dtype(), s);
|
||||
auto sumx = sum(x, /* axis= */ -1, /* keepdims= */ true, s);
|
||||
auto sumx2 = sum(square(x, s), /* axis= */ -1, /* keepdims= */ true, s);
|
||||
auto mu = multiply(sumx, norm, s);
|
||||
auto mu2 = multiply(sumx2, norm, s);
|
||||
auto var = subtract(mu2, square(mu, s), s);
|
||||
auto n = rsqrt(add(var, array(eps, x.dtype()), s));
|
||||
auto n3 = power(n, array(3, x.dtype()), s);
|
||||
auto x_c = subtract(x, mu, s);
|
||||
|
||||
// df/dx
|
||||
auto wg = multiply(w, g, s);
|
||||
auto sumwg =
|
||||
multiply(sum(wg, /* axis= */ -1, /* keepdims= */ true, s), norm, s);
|
||||
auto sumwgxc = multiply(
|
||||
sum(multiply(wg, x_c, s), /* axis= */ -1, /* keepdims= */ true, s),
|
||||
norm,
|
||||
s);
|
||||
auto t1 = multiply(multiply(x_c, sumwgxc, s), n3, s);
|
||||
auto t2 = multiply(subtract(wg, sumwg, s), n, s);
|
||||
vjps.push_back(subtract(t2, t1, s));
|
||||
|
||||
// df/dw
|
||||
std::vector<int> axes(g.ndim() - 1);
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
if (w.ndim() == 0) {
|
||||
vjps.push_back(zeros_like(w, s));
|
||||
} else {
|
||||
vjps.push_back(sum(
|
||||
multiply(g, multiply(x_c, n, s), s), axes, /* keepdims= */ false, s));
|
||||
}
|
||||
|
||||
// df/db
|
||||
if (b.ndim() == 0) {
|
||||
vjps.push_back(zeros_like(w, s));
|
||||
} else {
|
||||
vjps.push_back(sum(g, axes, /* keepdims= */ false, s));
|
||||
}
|
||||
|
||||
return vjps;
|
||||
};
|
||||
|
||||
auto vjps = array::make_arrays(
|
||||
{primals[0].shape(), primals[1].shape(), primals[2].shape()},
|
||||
{primals[0].dtype(), primals[1].dtype(), primals[2].dtype()},
|
||||
std::make_shared<LayerNormVJP>(s, fallback, eps_),
|
||||
{primals[0], primals[1], primals[2], cotangents[0]});
|
||||
|
||||
std::vector<array> returned_vjps;
|
||||
for (auto& arg : argnums) {
|
||||
returned_vjps.push_back(std::move(vjps[arg]));
|
||||
}
|
||||
|
||||
return returned_vjps;
|
||||
}
|
||||
|
||||
bool LayerNorm::is_equivalent(const Primitive& other) const {
|
||||
const LayerNorm& a_other = static_cast<const LayerNorm&>(other);
|
||||
return eps_ == a_other.eps_;
|
||||
}
|
||||
|
||||
bool LayerNormVJP::is_equivalent(const Primitive& other) const {
|
||||
const LayerNormVJP& a_other = static_cast<const LayerNormVJP&>(other);
|
||||
return eps_ == a_other.eps_;
|
||||
}
|
||||
|
||||
array rope(
|
||||
const array& x,
|
||||
int dims,
|
||||
@@ -188,19 +328,16 @@ array rope(
|
||||
float base,
|
||||
float scale,
|
||||
int offset,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
bool forward,
|
||||
StreamOrDevice s) {
|
||||
if (x.ndim() < 3) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rope] Input must have at least 3 dimensions but got input with "
|
||||
<< x.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (traditional && x.shape(-1) != dims) {
|
||||
throw std::invalid_argument(
|
||||
"[rope] Does not support partial traditional application.");
|
||||
}
|
||||
|
||||
auto fallback = [dims, traditional, base, scale, offset, s](
|
||||
auto fallback = [dims, traditional, base, scale, offset, forward, s](
|
||||
const std::vector<array>& inputs) {
|
||||
auto& shape = inputs[0].shape();
|
||||
int ndim = shape.size();
|
||||
@@ -217,16 +354,39 @@ array rope(
|
||||
auto coss = cos(theta, s);
|
||||
auto sins = sin(theta, s);
|
||||
|
||||
if (traditional) {
|
||||
auto x1 = slice(x, {0, 0, 0}, x.shape(), {1, 1, 2}, s);
|
||||
auto x2 = slice(x, {0, 0, 1}, x.shape(), {1, 1, 2}, s);
|
||||
auto apply_rope = [forward, s](
|
||||
const array& x1,
|
||||
const array& x2,
|
||||
const array& coss,
|
||||
const array& sins) {
|
||||
std::vector<array> outs;
|
||||
outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
|
||||
outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
|
||||
if (forward) {
|
||||
outs.push_back(
|
||||
subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
|
||||
outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
|
||||
} else {
|
||||
outs.push_back(add(multiply(x2, sins, s), multiply(x1, coss, s), s));
|
||||
outs.push_back(
|
||||
subtract(multiply(x2, coss, s), multiply(x1, sins, s), s));
|
||||
}
|
||||
return outs;
|
||||
};
|
||||
|
||||
if (traditional) {
|
||||
auto x1 =
|
||||
slice(x, {0, 0, 0}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
|
||||
auto x2 =
|
||||
slice(x, {0, 0, 1}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
|
||||
auto outs = apply_rope(x1, x2, coss, sins);
|
||||
for (auto& o : outs) {
|
||||
o = expand_dims(o, 3, s);
|
||||
}
|
||||
return std::vector<array>{reshape(concatenate(outs, 3, s), shape, s)};
|
||||
auto out = concatenate(outs, 3, s);
|
||||
if (dims < x.shape(-1)) {
|
||||
out = reshape(out, {x.shape(0), x.shape(1), dims});
|
||||
out = concatenate({out, slice(x, {0, 0, dims}, x.shape(), s)}, 2, s);
|
||||
}
|
||||
return std::vector<array>{reshape(out, shape, s)};
|
||||
} else {
|
||||
auto out_s = x.shape();
|
||||
out_s.back() = half_dims;
|
||||
@@ -234,9 +394,7 @@ array rope(
|
||||
out_s.back() = dims;
|
||||
auto x2 = slice(x, {0, 0, half_dims}, out_s, s);
|
||||
|
||||
std::vector<array> outs;
|
||||
outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
|
||||
outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
|
||||
auto outs = apply_rope(x1, x2, coss, sins);
|
||||
if (dims < x.shape(-1)) {
|
||||
outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
|
||||
}
|
||||
@@ -249,18 +407,54 @@ array rope(
|
||||
x.shape(),
|
||||
x.dtype(),
|
||||
std::make_shared<RoPE>(
|
||||
stream, fallback, dims, traditional, base, scale, offset),
|
||||
stream, fallback, dims, traditional, base, scale, offset, forward),
|
||||
{x});
|
||||
}
|
||||
return fallback({x})[0];
|
||||
}
|
||||
|
||||
array rope(
|
||||
const array& x,
|
||||
int dims,
|
||||
bool traditional,
|
||||
float base,
|
||||
float scale,
|
||||
int offset,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return rope(x, dims, traditional, base, scale, offset, true, s);
|
||||
}
|
||||
|
||||
std::vector<array> RoPE::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
auto s = stream();
|
||||
auto fallback = [dims = dims_,
|
||||
traditional = traditional_,
|
||||
base = base_,
|
||||
scale = scale_,
|
||||
offset = offset_,
|
||||
forward = forward_,
|
||||
s](std::vector<array> inputs) {
|
||||
return std::vector<array>{
|
||||
rope(inputs[0], dims, traditional, base, scale, offset, !forward, s)};
|
||||
};
|
||||
|
||||
return {array(
|
||||
cotangents[0].shape(),
|
||||
cotangents[0].dtype(),
|
||||
std::make_shared<RoPE>(
|
||||
s, fallback, dims_, traditional_, base_, scale_, offset_, !forward_),
|
||||
cotangents)};
|
||||
}
|
||||
|
||||
bool RoPE::is_equivalent(const Primitive& other) const {
|
||||
const RoPE& a_other = static_cast<const RoPE&>(other);
|
||||
return (
|
||||
dims_ == a_other.dims_ && base_ == a_other.base_ &&
|
||||
scale_ == a_other.scale_ && traditional_ == a_other.traditional_ &&
|
||||
offset_ == a_other.offset_);
|
||||
offset_ == a_other.offset_ && forward_ == a_other.forward_);
|
||||
}
|
||||
|
||||
/** Computes: O = softmax(Q @ K.T) @ V **/
|
||||
@@ -319,7 +513,7 @@ array scaled_dot_product_attention(
|
||||
}
|
||||
|
||||
auto final_type = result_type(queries, keys, values);
|
||||
if (!is_floating_point(final_type) || is_complex(final_type)) {
|
||||
if (!issubdtype(final_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[scaled_dot_product_attention] Received unsupported type "
|
||||
<< final_type << ".";
|
||||
@@ -356,10 +550,7 @@ array scaled_dot_product_attention(
|
||||
if (needs_mask) {
|
||||
scores = add(scores, inputs[3], s);
|
||||
}
|
||||
scores = astype(
|
||||
softmax(astype(scores, float32, s), std::vector<int>{-1}, s),
|
||||
final_type,
|
||||
s);
|
||||
scores = softmax(scores, std::vector<int>{-1}, true, s);
|
||||
auto out = matmul(scores, v, s);
|
||||
if (n_repeats > 1) {
|
||||
out = reshape(out, {B, n_q_heads, L, -1}, s);
|
||||
|
@@ -48,6 +48,12 @@ class RMSNorm : public Custom {
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
DEFINE_PRINT(RMSNorm)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
@@ -56,6 +62,29 @@ class RMSNorm : public Custom {
|
||||
float eps_;
|
||||
};
|
||||
|
||||
class RMSNormVJP : public Custom {
|
||||
public:
|
||||
RMSNormVJP(
|
||||
Stream stream,
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||
float eps)
|
||||
: Custom(stream, fallback), eps_(eps){};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
throw std::runtime_error("NYI");
|
||||
};
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
DEFINE_PRINT(RMSNormVJP)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||
float eps_;
|
||||
};
|
||||
|
||||
class LayerNorm : public Custom {
|
||||
public:
|
||||
LayerNorm(
|
||||
@@ -71,6 +100,12 @@ class LayerNorm : public Custom {
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
DEFINE_PRINT(LayerNorm)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
@@ -79,6 +114,29 @@ class LayerNorm : public Custom {
|
||||
float eps_;
|
||||
};
|
||||
|
||||
class LayerNormVJP : public Custom {
|
||||
public:
|
||||
LayerNormVJP(
|
||||
Stream stream,
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||
float eps)
|
||||
: Custom(stream, fallback), eps_(eps){};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
throw std::runtime_error("NYI");
|
||||
};
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
DEFINE_PRINT(LayerNormVJP)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||
float eps_;
|
||||
};
|
||||
|
||||
class RoPE : public Custom {
|
||||
public:
|
||||
RoPE(
|
||||
@@ -88,13 +146,15 @@ class RoPE : public Custom {
|
||||
bool traditional,
|
||||
float base,
|
||||
float scale,
|
||||
int offset)
|
||||
int offset,
|
||||
bool forward)
|
||||
: Custom(stream, fallback),
|
||||
dims_(dims),
|
||||
traditional_(traditional),
|
||||
base_(base),
|
||||
scale_(scale),
|
||||
offset_(offset){};
|
||||
offset_(offset),
|
||||
forward_(forward){};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
@@ -103,6 +163,12 @@ class RoPE : public Custom {
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
DEFINE_PRINT(RoPE)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
@@ -113,6 +179,7 @@ class RoPE : public Custom {
|
||||
float base_;
|
||||
float scale_;
|
||||
int offset_;
|
||||
bool forward_;
|
||||
};
|
||||
|
||||
class ScaledDotProductAttention : public Custom {
|
||||
@@ -126,7 +193,7 @@ class ScaledDotProductAttention : public Custom {
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
outputs[0] = fallback_(inputs)[0];
|
||||
throw std::runtime_error("NYI");
|
||||
};
|
||||
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
|
@@ -95,7 +95,7 @@ array fft_impl(
|
||||
return array(
|
||||
out_shape,
|
||||
out_type,
|
||||
std::make_unique<FFT>(to_stream(s), valid_axes, inverse, real),
|
||||
std::make_shared<FFT>(to_stream(s), valid_axes, inverse, real),
|
||||
{astype(in, in_type, s)});
|
||||
}
|
||||
|
||||
|
@@ -6,12 +6,10 @@
|
||||
|
||||
#include "array.h"
|
||||
#include "device.h"
|
||||
#include "stream.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace mlx::core::fft {
|
||||
|
||||
using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
|
||||
|
||||
/** Compute the n-dimensional Fourier Transform. */
|
||||
array fftn(
|
||||
const array& a,
|
||||
|
@@ -23,8 +23,7 @@ const std::string& NodeNamer::get_name(const array& x) {
|
||||
letters.push_back('A' + (var_num - 1) % 26);
|
||||
var_num = (var_num - 1) / 26;
|
||||
}
|
||||
std::string name(letters.rbegin(), letters.rend());
|
||||
names.insert({x.id(), name});
|
||||
names.emplace(x.id(), std::string(letters.rbegin(), letters.rend()));
|
||||
|
||||
return get_name(x);
|
||||
}
|
||||
|
@@ -14,15 +14,15 @@ struct NodeNamer {
|
||||
|
||||
void print_graph(std::ostream& os, const std::vector<array>& outputs);
|
||||
|
||||
template <typename... Arrays>
|
||||
void print_graph(std::ostream& os, Arrays... outputs) {
|
||||
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
||||
void print_graph(std::ostream& os, Arrays&&... outputs) {
|
||||
print_graph(os, std::vector<array>{std::forward<Arrays>(outputs)...});
|
||||
}
|
||||
|
||||
void export_to_dot(std::ostream& os, const std::vector<array>& outputs);
|
||||
|
||||
template <typename... Arrays>
|
||||
void export_to_dot(std::ostream& os, Arrays... outputs) {
|
||||
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
||||
void export_to_dot(std::ostream& os, Arrays&&... outputs) {
|
||||
export_to_dot(os, std::vector<array>{std::forward<Arrays>(outputs)...});
|
||||
}
|
||||
|
||||
|
6
mlx/io.h
6
mlx/io.h
@@ -23,13 +23,13 @@ using SafetensorsLoad = std::pair<
|
||||
void save(std::shared_ptr<io::Writer> out_stream, array a);
|
||||
|
||||
/** Save array to file in .npy format */
|
||||
void save(const std::string& file, array a);
|
||||
void save(std::string file, array a);
|
||||
|
||||
/** Load array from reader in .npy format */
|
||||
array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
|
||||
|
||||
/** Load array from file in .npy format */
|
||||
array load(const std::string& file, StreamOrDevice s = {});
|
||||
array load(std::string file, StreamOrDevice s = {});
|
||||
|
||||
/** Load array map from .safetensors file format */
|
||||
SafetensorsLoad load_safetensors(
|
||||
@@ -44,7 +44,7 @@ void save_safetensors(
|
||||
std::unordered_map<std::string, array>,
|
||||
std::unordered_map<std::string, std::string> metadata = {});
|
||||
void save_safetensors(
|
||||
const std::string& file,
|
||||
std::string file,
|
||||
std::unordered_map<std::string, array>,
|
||||
std::unordered_map<std::string, std::string> metadata = {});
|
||||
|
||||
|
@@ -206,7 +206,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
|
||||
std::unordered_map<std::string, array> array_map;
|
||||
gguf_tensor tensor;
|
||||
|
||||
auto check_insert = [](auto inserted) {
|
||||
auto check_insert = [](const auto& inserted) {
|
||||
if (!inserted.second) {
|
||||
std::ostringstream msg;
|
||||
msg << "[load_gguf] Duplicate parameter name " << inserted.first->second
|
||||
@@ -216,6 +216,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
|
||||
};
|
||||
|
||||
while (gguf_get_tensor(ctx, &tensor)) {
|
||||
std::string name(tensor.name, tensor.namelen);
|
||||
if (tensor.type == GGUF_TYPE_Q4_0 || tensor.type == GGUF_TYPE_Q4_1 ||
|
||||
tensor.type == GGUF_TYPE_Q8_0) {
|
||||
gguf_load_quantized(array_map, tensor);
|
||||
@@ -224,14 +225,14 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
|
||||
|
||||
const auto& [data, dtype] = extract_tensor_data(&tensor);
|
||||
array loaded_array = array(data, get_shape(tensor), dtype);
|
||||
array_map.insert({name, loaded_array});
|
||||
check_insert(array_map.insert({name, loaded_array}));
|
||||
}
|
||||
}
|
||||
return array_map;
|
||||
}
|
||||
|
||||
GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) {
|
||||
gguf_ctx* ctx = gguf_open(file.c_str());
|
||||
gguf_ctx* ctx = gguf_open(file.data());
|
||||
if (!ctx) {
|
||||
throw std::runtime_error("[load_gguf] gguf_init failed");
|
||||
}
|
||||
|
@@ -105,7 +105,8 @@ void gguf_load_quantized(
|
||||
weights_per_byte = 1;
|
||||
}
|
||||
|
||||
std::string name = std::string(tensor.name, tensor.namelen);
|
||||
std::string name(tensor.name, tensor.namelen);
|
||||
|
||||
std::vector<int> shape = get_shape(tensor);
|
||||
const uint64_t weights_per_block = 32;
|
||||
if (shape[shape.size() - 1] % weights_per_block != 0) {
|
||||
@@ -136,9 +137,9 @@ void gguf_load_quantized(
|
||||
extract_q8_0_data(tensor, weights, scales, biases);
|
||||
}
|
||||
|
||||
a.insert({name, weights});
|
||||
a.emplace(name, std::move(weights));
|
||||
|
||||
auto check_insert = [](auto inserted) {
|
||||
auto check_insert = [](const auto& inserted) {
|
||||
if (!inserted.second) {
|
||||
std::ostringstream msg;
|
||||
msg << "[load_gguf] Duplicate parameter name " << inserted.first->second
|
||||
@@ -147,11 +148,11 @@ void gguf_load_quantized(
|
||||
}
|
||||
};
|
||||
|
||||
const std::string weight_suffix = ".weight";
|
||||
constexpr std::string_view weight_suffix = ".weight";
|
||||
const std::string name_prefix =
|
||||
name.substr(0, name.length() - weight_suffix.length());
|
||||
check_insert(a.insert({name_prefix + ".scales", scales}));
|
||||
check_insert(a.insert({name_prefix + ".biases", biases}));
|
||||
check_insert(a.emplace(name_prefix + ".scales", std::move(scales)));
|
||||
check_insert(a.emplace(name_prefix + ".biases", std::move(biases)));
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -26,16 +26,6 @@ constexpr uint8_t MAGIC[] = {
|
||||
0x59,
|
||||
};
|
||||
|
||||
inline bool is_big_endian_() {
|
||||
union ByteOrder {
|
||||
int32_t i;
|
||||
uint8_t c[4];
|
||||
};
|
||||
ByteOrder b = {0x01234567};
|
||||
|
||||
return b.c[0] == 0x01;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/** Save array to out stream in .npy format */
|
||||
@@ -73,8 +63,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
|
||||
std::string fortran_order = a.flags().col_contiguous ? "True" : "False";
|
||||
std::ostringstream header;
|
||||
header << "{'descr': '" << dtype_to_array_protocol(a.dtype()) << "',"
|
||||
<< " 'fortran_order': " << fortran_order << ","
|
||||
<< " 'shape': (";
|
||||
<< " 'fortran_order': " << fortran_order << "," << " 'shape': (";
|
||||
for (auto i : a.shape()) {
|
||||
header << i << ", ";
|
||||
}
|
||||
@@ -94,7 +83,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
|
||||
uint16_t v1_header_len = header.tellp();
|
||||
const char* len_bytes = reinterpret_cast<const char*>(&v1_header_len);
|
||||
|
||||
if (!is_big_endian_()) {
|
||||
if (!is_big_endian()) {
|
||||
magic_ver_len.write(len_bytes, 2);
|
||||
} else {
|
||||
magic_ver_len.write(len_bytes + 1, 1);
|
||||
@@ -106,7 +95,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
|
||||
uint32_t v2_header_len = header.tellp();
|
||||
const char* len_bytes = reinterpret_cast<const char*>(&v2_header_len);
|
||||
|
||||
if (!is_big_endian_()) {
|
||||
if (!is_big_endian()) {
|
||||
magic_ver_len.write(len_bytes, 4);
|
||||
} else {
|
||||
magic_ver_len.write(len_bytes + 3, 1);
|
||||
@@ -124,16 +113,13 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
|
||||
}
|
||||
|
||||
/** Save array to file in .npy format */
|
||||
void save(const std::string& file_, array a) {
|
||||
// Open and check file
|
||||
std::string file = file_;
|
||||
|
||||
void save(std::string file, array a) {
|
||||
// Add .npy to file name if it is not there
|
||||
if (file.length() < 4 || file.substr(file.length() - 4, 4) != ".npy")
|
||||
file += ".npy";
|
||||
|
||||
// Serialize array
|
||||
save(std::make_shared<io::FileWriter>(file), a);
|
||||
save(std::make_shared<io::FileWriter>(std::move(file)), a);
|
||||
}
|
||||
|
||||
/** Load array from reader in .npy format */
|
||||
@@ -219,7 +205,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
|
||||
// Build primitive
|
||||
|
||||
size_t offset = 8 + header_len_size + header.length();
|
||||
bool swap_endianness = read_is_big_endian != is_big_endian_();
|
||||
bool swap_endianness = read_is_big_endian != is_big_endian();
|
||||
|
||||
if (col_contiguous) {
|
||||
std::reverse(shape.begin(), shape.end());
|
||||
@@ -227,7 +213,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
|
||||
auto loaded_array = array(
|
||||
shape,
|
||||
dtype,
|
||||
std::make_unique<Load>(to_stream(s), in_stream, offset, swap_endianness),
|
||||
std::make_shared<Load>(to_stream(s), in_stream, offset, swap_endianness),
|
||||
std::vector<array>{});
|
||||
if (col_contiguous) {
|
||||
loaded_array = transpose(loaded_array, s);
|
||||
@@ -237,8 +223,8 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
|
||||
}
|
||||
|
||||
/** Load array from file in .npy format */
|
||||
array load(const std::string& file, StreamOrDevice s) {
|
||||
return load(std::make_shared<io::FileReader>(file), s);
|
||||
array load(std::string file, StreamOrDevice s) {
|
||||
return load(std::make_shared<io::FileReader>(std::move(file)), s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -60,7 +60,7 @@ std::string dtype_to_safetensor_str(Dtype t) {
|
||||
}
|
||||
}
|
||||
|
||||
Dtype dtype_from_safetensor_str(std::string str) {
|
||||
Dtype dtype_from_safetensor_str(std::string_view str) {
|
||||
if (str == ST_F32) {
|
||||
return float32;
|
||||
} else if (str == ST_F16) {
|
||||
@@ -88,7 +88,8 @@ Dtype dtype_from_safetensor_str(std::string str) {
|
||||
} else if (str == ST_C64) {
|
||||
return complex64;
|
||||
} else {
|
||||
throw std::runtime_error("[safetensor] unsupported dtype " + str);
|
||||
throw std::runtime_error(
|
||||
"[safetensor] unsupported dtype " + std::string(str));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -129,14 +130,14 @@ SafetensorsLoad load_safetensors(
|
||||
}
|
||||
continue;
|
||||
}
|
||||
std::string dtype = item.value().at("dtype");
|
||||
std::vector<int> shape = item.value().at("shape");
|
||||
std::vector<size_t> data_offsets = item.value().at("data_offsets");
|
||||
const std::string& dtype = item.value().at("dtype");
|
||||
const std::vector<int>& shape = item.value().at("shape");
|
||||
const std::vector<size_t>& data_offsets = item.value().at("data_offsets");
|
||||
Dtype type = dtype_from_safetensor_str(dtype);
|
||||
auto loaded_array = array(
|
||||
shape,
|
||||
type,
|
||||
std::make_unique<Load>(
|
||||
std::make_shared<Load>(
|
||||
to_stream(s), in_stream, offset + data_offsets.at(0), false),
|
||||
std::vector<array>{});
|
||||
res.insert({item.key(), loaded_array});
|
||||
@@ -207,19 +208,17 @@ void save_safetensors(
|
||||
}
|
||||
|
||||
void save_safetensors(
|
||||
const std::string& file_,
|
||||
std::string file,
|
||||
std::unordered_map<std::string, array> a,
|
||||
std::unordered_map<std::string, std::string> metadata /* = {} */) {
|
||||
// Open and check file
|
||||
std::string file = file_;
|
||||
|
||||
// Add .safetensors to file name if it is not there
|
||||
if (file.length() < 12 ||
|
||||
file.substr(file.length() - 12, 12) != ".safetensors")
|
||||
file += ".safetensors";
|
||||
|
||||
// Serialize array
|
||||
save_safetensors(std::make_shared<io::FileWriter>(file), a, metadata);
|
||||
save_safetensors(
|
||||
std::make_shared<io::FileWriter>(std::move(file)), a, metadata);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -11,7 +11,7 @@
|
||||
namespace mlx::core::linalg {
|
||||
|
||||
Dtype at_least_float(const Dtype& d) {
|
||||
return is_floating_point(d) ? d : promote_types(d, float32);
|
||||
return issubdtype(d, inexact) ? d : promote_types(d, float32);
|
||||
}
|
||||
|
||||
inline array l2_norm(
|
||||
@@ -19,7 +19,7 @@ inline array l2_norm(
|
||||
const std::vector<int>& axis,
|
||||
bool keepdims,
|
||||
StreamOrDevice s) {
|
||||
if (is_complex(a.dtype())) {
|
||||
if (issubdtype(a.dtype(), complexfloating)) {
|
||||
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s);
|
||||
} else {
|
||||
return sqrt(sum(square(a, s), axis, keepdims, s), s);
|
||||
|
285
mlx/ops.cpp
285
mlx/ops.cpp
@@ -47,7 +47,7 @@ std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
|
||||
}
|
||||
|
||||
Dtype at_least_float(const Dtype& d) {
|
||||
return is_floating_point(d) ? d : promote_types(d, float32);
|
||||
return issubdtype(d, inexact) ? d : promote_types(d, float32);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -158,50 +158,64 @@ array linspace(
|
||||
to_stream(s));
|
||||
}
|
||||
|
||||
array astype(const array& a, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||
array astype(array a, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||
if (dtype == a.dtype()) {
|
||||
return a;
|
||||
return std::move(a);
|
||||
}
|
||||
auto copied_shape = a.shape(); // |a| will be moved
|
||||
return array(
|
||||
a.shape(), dtype, std::make_shared<AsType>(to_stream(s), dtype), {a});
|
||||
std::move(copied_shape),
|
||||
dtype,
|
||||
std::make_shared<AsType>(to_stream(s), dtype),
|
||||
{std::move(a)});
|
||||
}
|
||||
|
||||
array as_strided(
|
||||
const array& a,
|
||||
array a,
|
||||
std::vector<int> shape,
|
||||
std::vector<size_t> strides,
|
||||
size_t offset,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
// Force the input array to be contiguous
|
||||
auto x = reshape(a, {-1}, s);
|
||||
auto copied_shape = shape; // |shape| will be moved
|
||||
auto dtype = a.dtype(); // |a| will be moved
|
||||
return array(
|
||||
shape,
|
||||
a.dtype(),
|
||||
std::make_shared<AsStrided>(to_stream(s), shape, strides, offset),
|
||||
{x});
|
||||
std::move(copied_shape),
|
||||
dtype,
|
||||
std::make_shared<AsStrided>(
|
||||
to_stream(s), std::move(shape), std::move(strides), offset),
|
||||
// Force the input array to be contiguous.
|
||||
{reshape(std::move(a), {-1}, s)});
|
||||
}
|
||||
|
||||
array copy(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return array(a.shape(), a.dtype(), std::make_shared<Copy>(to_stream(s)), {a});
|
||||
array copy(array a, StreamOrDevice s /* = {} */) {
|
||||
auto copied_shape = a.shape(); // |a| will be moved
|
||||
auto dtype = a.dtype();
|
||||
return array(
|
||||
std::move(copied_shape),
|
||||
dtype,
|
||||
std::make_shared<Copy>(to_stream(s)),
|
||||
{std::move(a)});
|
||||
}
|
||||
|
||||
array full(
|
||||
const std::vector<int>& shape,
|
||||
const array& vals,
|
||||
std::vector<int> shape,
|
||||
array vals,
|
||||
Dtype dtype,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (std::any_of(shape.begin(), shape.end(), [](auto i) { return i < 0; })) {
|
||||
if (std::any_of(shape.begin(), shape.end(), [](int i) { return i < 0; })) {
|
||||
throw std::invalid_argument("[full] Negative dimensions not allowed.");
|
||||
}
|
||||
auto in = broadcast_to(astype(vals, dtype, s), shape, s);
|
||||
return array(shape, dtype, std::make_shared<Full>(to_stream(s)), {in});
|
||||
auto copied_shape = shape; // |shape| will be moved
|
||||
return array(
|
||||
std::move(copied_shape),
|
||||
dtype,
|
||||
std::make_shared<Full>(to_stream(s)),
|
||||
{broadcast_to(astype(std::move(vals), dtype, s), std::move(shape), s)});
|
||||
}
|
||||
|
||||
array full(
|
||||
const std::vector<int>& shape,
|
||||
const array& vals,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return full(shape, vals, vals.dtype(), to_stream(s));
|
||||
array full(std::vector<int> shape, array vals, StreamOrDevice s /* = {} */) {
|
||||
auto dtype = vals.dtype(); // |vals| will be moved
|
||||
return full(std::move(shape), std::move(vals), dtype, to_stream(s));
|
||||
}
|
||||
|
||||
array zeros(
|
||||
@@ -682,6 +696,41 @@ split(const array& a, int num_splits, StreamOrDevice s /* = {} */) {
|
||||
return split(a, num_splits, 0, to_stream(s));
|
||||
}
|
||||
|
||||
std::vector<array> meshgrid(
|
||||
const std::vector<array>& arrays,
|
||||
bool sparse /* = false */,
|
||||
std::string indexing /* = "xy" */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (indexing != "xy" && indexing != "ij") {
|
||||
throw std::invalid_argument(
|
||||
"[meshgrid] Invalid indexing value. Valid values are 'xy' and 'ij'.");
|
||||
}
|
||||
|
||||
auto ndim = arrays.size();
|
||||
std::vector<array> outputs;
|
||||
for (int i = 0; i < ndim; ++i) {
|
||||
std::vector<int> shape(ndim, 1);
|
||||
shape[i] = -1;
|
||||
outputs.push_back(reshape(arrays[i], std::move(shape), s));
|
||||
}
|
||||
|
||||
if (indexing == "xy" and ndim > 1) {
|
||||
std::vector<int> shape(ndim, 1);
|
||||
|
||||
shape[1] = arrays[0].size();
|
||||
outputs[0] = reshape(arrays[0], shape, s);
|
||||
shape[1] = 1;
|
||||
shape[0] = arrays[1].size();
|
||||
outputs[1] = reshape(arrays[1], std::move(shape), s);
|
||||
}
|
||||
|
||||
if (!sparse) {
|
||||
outputs = broadcast_arrays(outputs, s);
|
||||
}
|
||||
|
||||
return outputs;
|
||||
}
|
||||
|
||||
array clip(
|
||||
const array& a,
|
||||
const std::optional<array>& a_min,
|
||||
@@ -883,15 +932,15 @@ array pad(
|
||||
if (low_pad_size[i] < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Invalid low padding size (" << low_pad_size[i]
|
||||
<< ") passed to pad"
|
||||
<< " for axis " << i << ". Padding sizes must be non-negative";
|
||||
<< ") passed to pad" << " for axis " << i
|
||||
<< ". Padding sizes must be non-negative";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (high_pad_size[i] < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Invalid high padding size (" << high_pad_size[i]
|
||||
<< ") passed to pad"
|
||||
<< " for axis " << i << ". Padding sizes must be non-negative";
|
||||
<< ") passed to pad" << " for axis " << i
|
||||
<< ". Padding sizes must be non-negative";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
@@ -1001,19 +1050,30 @@ array transpose(
|
||||
for (auto& ax : axes) {
|
||||
ax = ax < 0 ? ax + a.ndim() : ax;
|
||||
}
|
||||
std::set dims(axes.begin(), axes.end());
|
||||
if (dims.size() != axes.size()) {
|
||||
throw std::invalid_argument("Repeat axes not allowed in transpose.");
|
||||
if (axes.size() != a.ndim()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[transpose] Recived " << axes.size() << " axes for array with "
|
||||
<< a.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (dims.size() != a.ndim() ||
|
||||
a.ndim() > 0 &&
|
||||
(*dims.begin() != 0 || *dims.rbegin() != (a.ndim() - 1))) {
|
||||
throw std::invalid_argument("Transpose axes don't match array dimensions.");
|
||||
|
||||
// Check in bounds and for duplicates
|
||||
std::vector<int> shape(axes.size(), 0);
|
||||
for (auto& ax : axes) {
|
||||
if (ax < 0 || ax >= a.ndim()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[transpose] Invalid axis (" << ax << ") for array with "
|
||||
<< a.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (shape[ax] != 0) {
|
||||
throw std::invalid_argument("[transpose] Repeat axes not allowed.");
|
||||
}
|
||||
shape[ax] = 1;
|
||||
}
|
||||
std::vector<int> shape;
|
||||
shape.reserve(axes.size());
|
||||
for (auto ax : axes) {
|
||||
shape.push_back(a.shape()[ax]);
|
||||
|
||||
for (int i = 0; i < axes.size(); ++i) {
|
||||
shape[i] = a.shape()[axes[i]];
|
||||
}
|
||||
return array(
|
||||
std::move(shape),
|
||||
@@ -1140,7 +1200,7 @@ array array_equal(
|
||||
return array(false);
|
||||
} else {
|
||||
auto dtype = promote_types(a.dtype(), b.dtype());
|
||||
equal_nan &= is_floating_point(dtype);
|
||||
equal_nan &= issubdtype(dtype, inexact);
|
||||
return all(
|
||||
array(
|
||||
a.shape(),
|
||||
@@ -1153,7 +1213,7 @@ array array_equal(
|
||||
}
|
||||
|
||||
array isnan(const array& a, StreamOrDevice s /* = {} */) {
|
||||
if (is_integral(a.dtype())) {
|
||||
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
|
||||
return full(a.shape(), false, bool_, s);
|
||||
}
|
||||
return not_equal(a, a, s);
|
||||
@@ -1164,14 +1224,14 @@ array isinf(const array& a, StreamOrDevice s /* = {} */) {
|
||||
}
|
||||
|
||||
array isposinf(const array& a, StreamOrDevice s /* = {} */) {
|
||||
if (is_integral(a.dtype())) {
|
||||
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
|
||||
return full(a.shape(), false, bool_, s);
|
||||
}
|
||||
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
|
||||
}
|
||||
|
||||
array isneginf(const array& a, StreamOrDevice s /* = {} */) {
|
||||
if (is_integral(a.dtype())) {
|
||||
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
|
||||
return full(a.shape(), false, bool_, s);
|
||||
}
|
||||
return equal(a, array(-std::numeric_limits<float>::infinity(), a.dtype()), s);
|
||||
@@ -1416,6 +1476,34 @@ array var(
|
||||
return var(a, std::vector<int>{axis}, keepdims, ddof, to_stream(s));
|
||||
}
|
||||
|
||||
array std(
|
||||
const array& a,
|
||||
bool keepdims,
|
||||
int ddof /* = 0*/,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
std::vector<int> axes(a.ndim());
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
return std(a, axes, keepdims, ddof, to_stream(s));
|
||||
}
|
||||
|
||||
array std(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
bool keepdims /* = false */,
|
||||
int ddof /* = 0*/,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
return sqrt(var(a, axes, keepdims, ddof, s), s);
|
||||
}
|
||||
|
||||
array std(
|
||||
const array& a,
|
||||
int axis,
|
||||
bool keepdims /* = false */,
|
||||
int ddof /* = 0*/,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return std(a, std::vector<int>{axis}, keepdims, ddof, to_stream(s));
|
||||
}
|
||||
|
||||
array prod(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
|
||||
std::vector<int> axes(a.ndim());
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
@@ -1929,7 +2017,7 @@ array floor_divide(
|
||||
const array& b,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto dtype = promote_types(a.dtype(), b.dtype());
|
||||
if (is_floating_point(dtype)) {
|
||||
if (issubdtype(dtype, inexact)) {
|
||||
return floor(divide(a, b, s), s);
|
||||
}
|
||||
|
||||
@@ -1957,7 +2045,7 @@ array operator%(const array& a, const array& b) {
|
||||
std::vector<array>
|
||||
divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
auto dtype = promote_types(a.dtype(), b.dtype());
|
||||
if (is_complex(dtype)) {
|
||||
if (issubdtype(dtype, complexfloating)) {
|
||||
throw std::invalid_argument("[divmod] Complex type not supported.");
|
||||
}
|
||||
auto inputs =
|
||||
@@ -2019,6 +2107,13 @@ array exp(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return array(a.shape(), dtype, std::make_shared<Exp>(to_stream(s)), {input});
|
||||
}
|
||||
|
||||
array expm1(const array& a, StreamOrDevice s /* = {} */) {
|
||||
auto dtype = at_least_float(a.dtype());
|
||||
auto input = astype(a, dtype, s);
|
||||
return array(
|
||||
a.shape(), dtype, std::make_shared<Expm1>(to_stream(s)), {input});
|
||||
}
|
||||
|
||||
array sin(const array& a, StreamOrDevice s /* = {} */) {
|
||||
auto dtype = at_least_float(a.dtype());
|
||||
auto input = astype(a, dtype, s);
|
||||
@@ -2220,7 +2315,7 @@ array matmul(
|
||||
}
|
||||
// Type promotion
|
||||
auto out_type = promote_types(a.dtype(), b.dtype());
|
||||
if (!is_floating_point(out_type) || is_complex(out_type)) {
|
||||
if (!issubdtype(out_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[matmul] Only real floating point types are supported but "
|
||||
<< a.dtype() << " and " << b.dtype() << " were provided which results"
|
||||
@@ -2330,7 +2425,7 @@ array gather(
|
||||
|
||||
// Promote indices to the same type
|
||||
auto dtype = result_type(indices);
|
||||
if (!is_integral(dtype)) {
|
||||
if (issubdtype(dtype, inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"[gather] Got indices with invalid dtype. Indices must be integral.");
|
||||
}
|
||||
@@ -2413,8 +2508,8 @@ array take_along_axis(
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (axis + a.ndim() < 0 || axis >= static_cast<int>(a.ndim())) {
|
||||
std::ostringstream msg;
|
||||
msg << "[take_along_axis] Received invalid axis "
|
||||
<< " for array with " << a.ndim() << " dimensions.";
|
||||
msg << "[take_along_axis] Received invalid axis " << " for array with "
|
||||
<< a.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
@@ -2521,7 +2616,7 @@ array scatter(
|
||||
|
||||
// Promote indices to the same type
|
||||
auto dtype = result_type(indices);
|
||||
if (!is_integral(dtype)) {
|
||||
if (issubdtype(dtype, inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"[scatter] Got indices with invalid dtype. Indices must be integral.");
|
||||
}
|
||||
@@ -2605,25 +2700,34 @@ array rsqrt(const array& a, StreamOrDevice s /* = {} */) {
|
||||
array softmax(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
bool precise /* = false */,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
if (axes.size() == 1 && (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
|
||||
auto dtype = at_least_float(a.dtype());
|
||||
return array(
|
||||
a.shape(),
|
||||
dtype,
|
||||
std::make_shared<Softmax>(to_stream(s)),
|
||||
std::make_shared<Softmax>(to_stream(s), precise),
|
||||
{astype(a, dtype, s)});
|
||||
} else {
|
||||
auto a_max = stop_gradient(max(a, axes, /*keepdims = */ true, s), s);
|
||||
auto ex = exp(subtract(a, a_max, s), s);
|
||||
return divide(ex, sum(ex, axes, /*keepdims = */ true, s), s);
|
||||
auto in = a;
|
||||
if (precise) {
|
||||
in = astype(a, float32, s);
|
||||
}
|
||||
auto a_max = stop_gradient(max(in, axes, /*keepdims = */ true, s), s);
|
||||
auto ex = exp(subtract(in, a_max, s), s);
|
||||
return astype(
|
||||
divide(ex, sum(ex, axes, /*keepdims = */ true, s), s), a.dtype(), s);
|
||||
}
|
||||
}
|
||||
|
||||
array softmax(const array& a, StreamOrDevice s /* = {}*/) {
|
||||
array softmax(
|
||||
const array& a,
|
||||
bool precise /* = false */,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
std::vector<int> axes(a.ndim());
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
return softmax(a, axes, s);
|
||||
return softmax(a, axes, precise, s);
|
||||
}
|
||||
|
||||
array power(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
@@ -2800,15 +2904,15 @@ inline std::vector<int> conv_out_shape(
|
||||
|
||||
if (pads_lo[i - 1] < 0 || pads_hi[i - 1] < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[conv] Padding sizes must be non-negative."
|
||||
<< " Got padding " << pads_lo << " | " << pads_hi << ".";
|
||||
msg << "[conv] Padding sizes must be non-negative." << " Got padding "
|
||||
<< pads_lo << " | " << pads_hi << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (strides[i - 1] <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[conv] Stride sizes must be positive."
|
||||
<< " Got strides " << strides << ".";
|
||||
msg << "[conv] Stride sizes must be positive." << " Got strides "
|
||||
<< strides << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
@@ -2834,7 +2938,7 @@ inline std::vector<int> conv_out_shape(
|
||||
}
|
||||
|
||||
inline void run_conv_checks(const array& in, const array& wt, int n_dim) {
|
||||
if (!is_floating_point(in.dtype()) && kindof(in.dtype()) != Dtype::Kind::c) {
|
||||
if (!issubdtype(in.dtype(), floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[conv] Invalid input array with type " << in.dtype() << "."
|
||||
<< " Convolution currently only supports floating point types";
|
||||
@@ -2844,8 +2948,7 @@ inline void run_conv_checks(const array& in, const array& wt, int n_dim) {
|
||||
if (in.ndim() != n_dim + 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[conv] Invalid input array with " << in.ndim() << " dimensions for "
|
||||
<< n_dim << "D convolution."
|
||||
<< " Expected an array with " << n_dim + 2
|
||||
<< n_dim << "D convolution." << " Expected an array with " << n_dim + 2
|
||||
<< " dimensions following the format [N, ..., C_in].";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
@@ -2971,6 +3074,35 @@ array conv_general(
|
||||
input_dilation = std::vector<int>(spatial_dims, input_dilation_int);
|
||||
}
|
||||
|
||||
// Check for negative padding
|
||||
bool has_neg_padding = false;
|
||||
for (auto& pd : padding_lo) {
|
||||
has_neg_padding = (pd < 0);
|
||||
}
|
||||
for (auto& pd : padding_hi) {
|
||||
has_neg_padding = (pd < 0);
|
||||
}
|
||||
|
||||
// Handle negative padding
|
||||
if (has_neg_padding) {
|
||||
std::vector<int> starts(in.ndim(), 0);
|
||||
std::vector<int> stops = in.shape();
|
||||
|
||||
for (int i = 0; i < spatial_dims; i++) {
|
||||
if (padding_lo[i] < 0) {
|
||||
starts[i + 1] -= padding_lo[i];
|
||||
padding_lo[i] = 0;
|
||||
}
|
||||
|
||||
if (padding_hi[i] < 0) {
|
||||
stops[i + 1] += padding_hi[i];
|
||||
padding_hi[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
in = slice(in, std::move(starts), std::move(stops), s);
|
||||
}
|
||||
|
||||
// Get output shapes
|
||||
std::vector<int> out_shape = conv_out_shape(
|
||||
in.shape(),
|
||||
@@ -3005,7 +3137,6 @@ array quantized_matmul(
|
||||
int bits /* = 4 */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
array x = in_x;
|
||||
|
||||
if (w.dtype() != uint32) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantized_matmul] The weight matrix should be uint32 "
|
||||
@@ -3022,12 +3153,6 @@ array quantized_matmul(
|
||||
// Keep x's batch dimensions to reshape it back after the matmul
|
||||
auto original_shape = x.shape();
|
||||
int x_inner_dims = original_shape.back();
|
||||
original_shape.pop_back();
|
||||
|
||||
// Reshape x into a matrix if it isn't already one
|
||||
if (x.ndim() != 2) {
|
||||
x = reshape(x, {-1, x_inner_dims}, s);
|
||||
}
|
||||
|
||||
if (scales.ndim() != 2 || scales.shape() != biases.shape()) {
|
||||
std::ostringstream msg;
|
||||
@@ -3062,7 +3187,7 @@ array quantized_matmul(
|
||||
}
|
||||
|
||||
auto dtype = result_type(x, scales, biases);
|
||||
if (!is_floating_point(dtype) || is_complex(dtype)) {
|
||||
if (!issubdtype(dtype, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantized_matmul] Only real floating types are supported but "
|
||||
<< "the passed types where x.dtype() == " << x.dtype()
|
||||
@@ -3070,9 +3195,10 @@ array quantized_matmul(
|
||||
<< " and biases.dtype() == " << biases.dtype();
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto out = array(
|
||||
{x.shape(0), w_outer_dims},
|
||||
std::vector<array> inputs;
|
||||
original_shape.back() = w_outer_dims;
|
||||
return array(
|
||||
std::move(original_shape),
|
||||
dtype,
|
||||
std::make_shared<QuantizedMatmul>(
|
||||
to_stream(s), group_size, bits, transpose),
|
||||
@@ -3080,14 +3206,6 @@ array quantized_matmul(
|
||||
w,
|
||||
astype(scales, dtype, s),
|
||||
astype(biases, dtype, s)});
|
||||
|
||||
// If needed reshape x to the original batch shape
|
||||
if (original_shape.size() != 1) {
|
||||
original_shape.push_back(w_outer_dims);
|
||||
out = reshape(out, std::move(original_shape), s);
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
std::tuple<array, array, array> quantize(
|
||||
@@ -3117,8 +3235,7 @@ std::tuple<array, array, array> quantize(
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The last dimension of the matrix needs to be divisible by "
|
||||
<< "the quantization group size " << group_size
|
||||
<< ". However the provided "
|
||||
<< " matrix has shape " << w.shape();
|
||||
<< ". However the provided " << " matrix has shape " << w.shape();
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
@@ -3364,7 +3481,7 @@ array addmm(
|
||||
|
||||
// Type promotion
|
||||
auto out_type = result_type(a, b, c);
|
||||
if (!is_floating_point(out_type) || is_complex(out_type)) {
|
||||
if (!issubdtype(out_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[addmm] Only real floating point types are supported but "
|
||||
<< c.dtype() << ", " << a.dtype() << " and " << b.dtype()
|
||||
|
77
mlx/ops.h
77
mlx/ops.h
@@ -41,40 +41,33 @@ array linspace(
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Convert an array to the given data type. */
|
||||
array astype(const array& a, Dtype dtype, StreamOrDevice s = {});
|
||||
array astype(array a, Dtype dtype, StreamOrDevice s = {});
|
||||
|
||||
/** Create a view of an array with the given shape and strides. */
|
||||
array as_strided(
|
||||
const array& a,
|
||||
array a,
|
||||
std::vector<int> shape,
|
||||
std::vector<size_t> strides,
|
||||
size_t offset,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Copy another array. */
|
||||
array copy(const array& a, StreamOrDevice s = {});
|
||||
array copy(array a, StreamOrDevice s = {});
|
||||
|
||||
/** Fill an array of the given shape with the given value(s). */
|
||||
array full(
|
||||
const std::vector<int>& shape,
|
||||
const array& vals,
|
||||
std::vector<int> shape,
|
||||
array vals,
|
||||
Dtype dtype,
|
||||
StreamOrDevice s = {});
|
||||
array full(
|
||||
const std::vector<int>& shape,
|
||||
const array& vals,
|
||||
StreamOrDevice s = {});
|
||||
array full(std::vector<int> shape, array vals, StreamOrDevice s = {});
|
||||
template <typename T>
|
||||
array full(
|
||||
const std::vector<int>& shape,
|
||||
T val,
|
||||
Dtype dtype,
|
||||
StreamOrDevice s = {}) {
|
||||
return full(shape, array(val, dtype), to_stream(s));
|
||||
array full(std::vector<int> shape, T val, Dtype dtype, StreamOrDevice s = {}) {
|
||||
return full(std::move(shape), array(val, dtype), to_stream(s));
|
||||
}
|
||||
template <typename T>
|
||||
array full(const std::vector<int>& shape, T val, StreamOrDevice s = {}) {
|
||||
return full(shape, array(val), to_stream(s));
|
||||
array full(std::vector<int> shape, T val, StreamOrDevice s = {}) {
|
||||
return full(std::move(shape), array(val), to_stream(s));
|
||||
}
|
||||
|
||||
/** Fill an array of the given shape with zeros. */
|
||||
@@ -204,6 +197,13 @@ std::vector<array> split(
|
||||
std::vector<array>
|
||||
split(const array& a, const std::vector<int>& indices, StreamOrDevice s = {});
|
||||
|
||||
/** A vector of coordinate arrays from coordinate vectors. */
|
||||
std::vector<array> meshgrid(
|
||||
const std::vector<array>& arrays,
|
||||
bool sparse = false,
|
||||
std::string indexing = "xy",
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/**
|
||||
* Clip (limit) the values in an array.
|
||||
*/
|
||||
@@ -514,13 +514,14 @@ array mean(
|
||||
bool keepdims = false,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Computes the mean of the elements of an array. */
|
||||
/** Computes the variance of the elements of an array. */
|
||||
array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
|
||||
inline array var(const array& a, StreamOrDevice s = {}) {
|
||||
return var(a, false, 0, to_stream(s));
|
||||
}
|
||||
|
||||
/** Computes the var of the elements of an array along the given axes */
|
||||
/** Computes the variance of the elements of an array along the given
|
||||
* axes */
|
||||
array var(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
@@ -528,7 +529,8 @@ array var(
|
||||
int ddof = 0,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Computes the var of the elements of an array along the given axis */
|
||||
/** Computes the variance of the elements of an array along the given
|
||||
* axis */
|
||||
array var(
|
||||
const array& a,
|
||||
int axis,
|
||||
@@ -536,6 +538,30 @@ array var(
|
||||
int ddof = 0,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Computes the standard deviation of the elements of an array. */
|
||||
array std(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
|
||||
inline array std(const array& a, StreamOrDevice s = {}) {
|
||||
return std(a, false, 0, to_stream(s));
|
||||
}
|
||||
|
||||
/** Computes the standard deviatoin of the elements of an array along the given
|
||||
* axes */
|
||||
array std(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
bool keepdims = false,
|
||||
int ddof = 0,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Computes the standard deviation of the elements of an array along the given
|
||||
* axis */
|
||||
array std(
|
||||
const array& a,
|
||||
int axis,
|
||||
bool keepdims = false,
|
||||
int ddof = 0,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** The product of all elements of the array. */
|
||||
array prod(const array& a, bool keepdims, StreamOrDevice s = {});
|
||||
inline array prod(const array& a, StreamOrDevice s = {}) {
|
||||
@@ -849,6 +875,9 @@ array erf(const array& a, StreamOrDevice s = {});
|
||||
/** Computes the inverse error function of the elements of an array. */
|
||||
array erfinv(const array& a, StreamOrDevice s = {});
|
||||
|
||||
/** Computes the expm1 function of the elements of an array. */
|
||||
array expm1(const array& a, StreamOrDevice s = {});
|
||||
|
||||
/** Stop the flow of gradients. */
|
||||
array stop_gradient(const array& a, StreamOrDevice s = {});
|
||||
|
||||
@@ -983,14 +1012,16 @@ array rsqrt(const array& a, StreamOrDevice s = {});
|
||||
array softmax(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
bool precise = false,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Softmax of an array. */
|
||||
array softmax(const array& a, StreamOrDevice s = {});
|
||||
array softmax(const array& a, bool precise = false, StreamOrDevice s = {});
|
||||
|
||||
/** Softmax of an array. */
|
||||
inline array softmax(const array& a, int axis, StreamOrDevice s = {}) {
|
||||
return softmax(a, std::vector<int>{axis}, s);
|
||||
inline array
|
||||
softmax(const array& a, int axis, bool precise = false, StreamOrDevice s = {}) {
|
||||
return softmax(a, std::vector<int>{axis}, precise, s);
|
||||
}
|
||||
|
||||
/** Raise elements of a to the power of b element-wise */
|
||||
|
@@ -1239,6 +1239,34 @@ std::pair<std::vector<array>, std::vector<int>> Exp::vmap(
|
||||
return {{exp(inputs[0], stream())}, axes};
|
||||
}
|
||||
|
||||
std::vector<array> Expm1::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
return {multiply(
|
||||
cotangents[0],
|
||||
add(outputs[0], array(1.0f, outputs[0].dtype()), stream()),
|
||||
stream())};
|
||||
}
|
||||
|
||||
std::vector<array> Expm1::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
return {multiply(tangents[0], exp(primals[0], stream()), stream())};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Expm1::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(axes.size() == 1);
|
||||
return {{expm1(inputs[0], stream())}, axes};
|
||||
}
|
||||
|
||||
bool FFT::is_equivalent(const Primitive& other) const {
|
||||
const FFT& r_other = static_cast<const FFT&>(other);
|
||||
return axes_ == r_other.axes_ && inverse_ == r_other.inverse_ &&
|
||||
@@ -1267,7 +1295,7 @@ std::pair<std::vector<array>, std::vector<int>> FFT::vmap(
|
||||
{array(
|
||||
out_shape,
|
||||
real_ && inverse_ ? float32 : complex64,
|
||||
std::make_unique<FFT>(stream(), fft_axes, inverse_, real_),
|
||||
std::make_shared<FFT>(stream(), fft_axes, inverse_, real_),
|
||||
{in})},
|
||||
{ax}};
|
||||
}
|
||||
@@ -1377,7 +1405,7 @@ std::pair<std::vector<array>, std::vector<int>> Full::vmap(
|
||||
assert(axes.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
auto out =
|
||||
array(in.shape(), in.dtype(), std::make_unique<Full>(stream()), {in});
|
||||
array(in.shape(), in.dtype(), std::make_shared<Full>(stream()), {in});
|
||||
return {{out}, axes};
|
||||
}
|
||||
|
||||
@@ -1604,7 +1632,7 @@ std::pair<std::vector<array>, std::vector<int>> Log::vmap(
|
||||
{array(
|
||||
in.shape(),
|
||||
in.dtype(),
|
||||
std::make_unique<Log>(stream(), base_),
|
||||
std::make_shared<Log>(stream(), base_),
|
||||
{in})},
|
||||
axes};
|
||||
}
|
||||
@@ -2259,7 +2287,7 @@ std::pair<std::vector<array>, std::vector<int>> RandomBits::vmap(
|
||||
auto out = array(
|
||||
shape,
|
||||
get_dtype(),
|
||||
std::make_unique<RandomBits>(stream(), shape, width_),
|
||||
std::make_shared<RandomBits>(stream(), shape, width_),
|
||||
{key});
|
||||
return {{out}, {kax}};
|
||||
}
|
||||
@@ -2493,7 +2521,7 @@ std::pair<std::vector<array>, std::vector<int>> Scan::vmap(
|
||||
{array(
|
||||
in.shape(),
|
||||
out_dtype,
|
||||
std::make_unique<Scan>(
|
||||
std::make_shared<Scan>(
|
||||
stream(), reduce_type_, axis_ + axis_left, reverse_, inclusive_),
|
||||
{in})},
|
||||
axes};
|
||||
@@ -2975,7 +3003,7 @@ std::pair<std::vector<array>, std::vector<int>> Softmax::vmap(
|
||||
} else {
|
||||
softmax_axes.push_back(-2);
|
||||
}
|
||||
return {{softmax(inputs[0], softmax_axes, stream())}, axes};
|
||||
return {{softmax(inputs[0], softmax_axes, precise_, stream())}, axes};
|
||||
}
|
||||
|
||||
std::vector<array> Softmax::vjp(
|
||||
@@ -2998,13 +3026,18 @@ std::vector<array> Softmax::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(tangents.size() == 1);
|
||||
auto s = softmax(primals[0], std::vector<int>{-1}, stream());
|
||||
auto s = softmax(primals[0], std::vector<int>{-1}, precise_, stream());
|
||||
auto sv = multiply(s, tangents[0], stream());
|
||||
return {subtract(
|
||||
sv,
|
||||
multiply(s, sum(sv, std::vector<int>{-1}, true, stream()), stream()))};
|
||||
}
|
||||
|
||||
bool Softmax::is_equivalent(const Primitive& other) const {
|
||||
const Softmax& s_other = static_cast<const Softmax&>(other);
|
||||
return precise_ == s_other.precise_;
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Sort::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
@@ -3303,7 +3336,7 @@ std::pair<std::vector<array>, std::vector<int>> NumberOfElements::vmap(
|
||||
array out = array(
|
||||
std::vector<int>{},
|
||||
dtype_,
|
||||
std::make_unique<NumberOfElements>(stream(), new_axes, inverted_, dtype_),
|
||||
std::make_shared<NumberOfElements>(stream(), new_axes, inverted_, dtype_),
|
||||
inputs);
|
||||
|
||||
return {{out}, {-1}};
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user