Compare commits

...

51 Commits

Author SHA1 Message Date
Christopher Fleetwood
004c1d8ef2 Report number of missing parameters (#2264)
* chore: inform

* chore: format

---------

Co-authored-by: FL33TW00D <FL33TW00D@users.noreply.github.com>
2025-06-10 06:37:50 -07:00
Cheng
7ebb2e0193 CUDA backend: binary ops (#2259) 2025-06-10 06:37:40 -07:00
Awni Hannun
9ce77798b1 fix export to work with gather/scatter axis (#2263) 2025-06-09 20:37:27 -07:00
Cheng
f8bad60609 CUDA backend: unary ops (#2158) 2025-06-09 06:45:08 -07:00
Emmanuel Ferdman
5866b3857b Refactor the lu test (#2250)
Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
2025-06-07 06:12:08 -07:00
Awni Hannun
1ca616844b Fix unintuitive metal kernel caching (#2242)
* Fix unintuitive metal kernel caching

* alternative solution
2025-06-06 20:08:15 -07:00
Angelos Katharopoulos
2e8cf0b450 Change layernorms to two pass algorithm (#2246) 2025-06-06 13:34:56 -07:00
Cheng
24f89173d1 CUDA backend: matmul (#2241) 2025-06-06 12:24:04 -07:00
Awni Hannun
c6a20b427a Improve metal elementwise kernels (#2247)
* improve metal elementwise kernels

* compile and copy

* fix jit
2025-06-06 11:37:40 -07:00
Awni Hannun
a5ac9244c4 fix linux linking error (#2248) 2025-06-06 10:41:51 -07:00
Awni Hannun
c763fe1be0 default strict mode for module update and update_modules (#2239) 2025-06-05 15:27:02 -07:00
Cheng
52dc8c8cd5 Add profiler annotations in common primitives for CUDA backend (#2244) 2025-06-04 19:55:12 -07:00
Angelos Katharopoulos
aede70e81d Perf regression fix (#2243) 2025-06-03 17:55:12 -07:00
Cheng
85a8beb5e4 Avoid atomic updates across CPU/GPU in CUDA event (#2231) 2025-06-03 16:49:06 -07:00
Cheng
0bb89e9e5f Share more common code in Compiled (#2240)
* Share more common code in Compiled

* Remove build_lib_name
2025-06-03 16:48:50 -07:00
Cheng
5685ceb3c7 Avoid invoking allocator::malloc when creating CUDA event (#2232) 2025-06-03 16:48:40 -07:00
Suryash Malviya
0408ba0a76 Optimizing Complex Matrix Multiplication using Karatsuba’s Algorithm (#2220)
* Implementing Complex Matmul using Karatsuba Algorithm

* Implemented Karatsuba's Algorithm for complex matmul and pre-commit them

* fix

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-06-02 15:58:46 -07:00
Awni Hannun
cbad6c3093 version (#2237) 2025-06-02 15:58:33 -07:00
Cheng
1b021f6984 Fast primitives decide when to use the fallback (#2216) 2025-06-02 13:26:37 -07:00
Cheng
95b7551d65 Do not check event.is_signaled() in eval_impl (#2230) 2025-06-02 13:23:34 -07:00
Cheng
db5a7c6192 Add memory cache to CUDA backend (#2221)
* Move BufferCache out of allocator

* Add memory cache to cuda backend allocator

* Simplify BufferCache assuming buf can not be null
2025-05-30 12:12:54 -07:00
Awni Hannun
6ef2f67e7f 5bit quants (#2226)
* 5bit quants

* 5bit quants
2025-05-30 12:12:10 -07:00
Cheng
f76ee1ffd2 Move some dims utils to common (#2223) 2025-05-29 06:48:30 -07:00
Cheng
54a71f270a Remove unused defines (#2217) 2025-05-23 06:14:58 -07:00
Awni Hannun
55b4062dd8 copyright in docs (#2214) 2025-05-21 17:13:04 -07:00
Cheng
79071bfba4 Fix out-of-bounds default value in logsumexp/softmax (#2213) 2025-05-21 07:25:16 -07:00
Cheng
7774b87cbd Remove redundant simd_sum in logsumexp (#2210) 2025-05-21 07:25:03 -07:00
Cheng
35c87741cf Build for compute capability 70 instead of 75 (#2209) 2025-05-20 19:42:48 -07:00
Jack Wind
4cbe605214 Feat: Allow per-target Metal debug flags (#2201)
* feat: allow per-target Metal debug flags

* formatting fix
2025-05-20 10:22:26 -07:00
Clement Liaw
ab8883dd55 include mlx::core::version() symbols in the mlx static library (#2207) 2025-05-20 07:39:11 -07:00
Awni Hannun
eebe73001a fix large arg reduce (#2206) 2025-05-19 13:10:44 -07:00
Angelos Katharopoulos
0359bf02c9 Nearest upsample (#2202) 2025-05-19 11:23:38 -07:00
Cheng
237f9e58a8 Fix BEFORE keyword in target_include_directories (#2204) 2025-05-19 06:10:44 -07:00
Awni Hannun
8576e6fe36 fix conv2d bug + faster conv 1d (#2195)
* fix conv2d bug + faster conv 1d

* revert sort + flaky test
2025-05-18 06:05:11 -07:00
Angelos Katharopoulos
0654543dcc Add complex eigh (#2191) 2025-05-18 00:18:43 -07:00
Awni Hannun
48ef3e74e2 reduce vjp for all and any (#2193) 2025-05-16 08:38:49 -07:00
Cheng
7d4b378952 Include cuda_bf16.h for bfloat16 overloads (#2192)
* Include cuda_bf16.h for bfloat16 overloads

* Add NO_GPU_MULTI(Eig) in cuda backend
2025-05-16 06:44:42 -07:00
Jack Wind
7ff5c41e06 Add set_threadgroup_memory_length to CommandEncoder (#2183) 2025-05-16 00:28:03 -07:00
Awni Hannun
602f43e3d1 fix conv grad (#2187) 2025-05-15 19:20:36 -07:00
Awni Hannun
a2cadb8218 real and imag properties (#2189) 2025-05-15 18:17:50 -07:00
Awni Hannun
c1eb9d05d9 non-symmetric eig and eigh (#2188) 2025-05-15 13:01:44 -07:00
Angelos Katharopoulos
cf6c939e86 Fix some complex vjps (#2178) 2025-05-14 23:37:12 -07:00
Angelos Katharopoulos
130df35e1b Add random normal distribution for complex numbers (#2182) 2025-05-13 22:43:45 -07:00
Cheng
0751263dec Fix typo in row_reduce_small (#2179) 2025-05-13 20:19:54 -07:00
Cheng
eca2f3eb97 Add remove_index utility (#2173) 2025-05-13 17:09:56 -07:00
Angelos Katharopoulos
3aa9cf3f9e Fix put_along_axis for empty arrays (#2181) 2025-05-13 14:27:53 -07:00
Awni Hannun
8f3d208dce Close a couple edge case bugs: hadamard and addmm on empty inputs (#2177)
* handle hadamard and addmm on empty inputs

* fix
2025-05-12 10:48:57 -07:00
Ivan Fioravanti
caaa3f1f8c Small typos in mx.metal deprecations (#2176) 2025-05-11 06:03:47 -07:00
Awni Hannun
659a51919f patch bump (#2162) 2025-05-09 14:35:14 -07:00
Awni Hannun
6661387066 Fix fft for integer overflow (#2161) 2025-05-09 14:25:12 -07:00
ATurker
a7fae8a176 fix: conv_general differences between gpu, cpu (#2070)
* fix general_conv padding

* fix bugs

* add test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-05-09 10:26:52 -07:00
133 changed files with 6667 additions and 2688 deletions

View File

@@ -231,6 +231,9 @@ target_include_directories(
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>)
# Do not add mlx_EXPORTS define for shared library.
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git

View File

@@ -1,5 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from functools import partial
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
@@ -18,51 +20,63 @@ def layer_norm(x, w, b, eps):
return y
def time_layer_norm():
def time_layer_norm(N, dt):
L = 1024
f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0, 1, 2))
g2 = mx.grad(f2, argnums=(0, 1, 2))
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
w = mx.random.uniform(shape=(N,)).astype(dt)
b = mx.random.uniform(shape=(N,)).astype(dt)
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
mx.eval(x, w, b, y)
def layer_norm_loop(g, x, w, b):
def layer_norm_loop(f, x, w, b):
for _ in range(32):
x = f(x, w, b)
return x
time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b)
time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b)
def layer_norm_grad_loop(g, x, w, b):
gx, gw, gb = x, w, b
for _ in range(32):
gx, gw, gb = g(gx, gw, gb, y)
return gx, gw, gb
time_fn(layer_norm_loop, g1, x, w, b)
time_fn(layer_norm_loop, g2, x, w, b)
time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
time_fn(layer_norm_grad_loop, g1, x, w, b)
time_fn(layer_norm_grad_loop, g2, x, w, b)
time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b)
time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b)
f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0,))
g2 = mx.grad(f2, argnums=(0,))
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
w = mx.random.uniform(shape=(N,)).astype(dt)
b = mx.random.uniform(shape=(N,)).astype(dt)
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
mx.eval(x, w, b, y)
def layer_norm_loop(g, x):
def layer_norm_grad_x_loop(g, x):
gx = x
for _ in range(32):
gx = g(gx, y)
return gx
time_fn(layer_norm_loop, g1, x)
time_fn(layer_norm_loop, g2, x)
time_fn(layer_norm_loop, mx.compile(g1), x)
time_fn(layer_norm_loop, mx.compile(g2), x)
time_fn(layer_norm_grad_x_loop, g1, x)
time_fn(layer_norm_grad_x_loop, g2, x)
time_fn(layer_norm_grad_x_loop, mx.compile(g1), x)
time_fn(layer_norm_grad_x_loop, mx.compile(g2), x)
if __name__ == "__main__":
time_layer_norm()
for dt in [mx.float32, mx.float16, mx.bfloat16]:
for n in [1024, 2048, 4096, 8192, 8192 + 1024]:
print(dt, n)
time_layer_norm(n, dt)
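
The benchmark above exercises the layer norm change from #2246 ("Change layernorms to two pass algorithm") across several widths and dtypes. As a point of reference, below is a minimal Python sketch of what a generic two-pass layer norm computes (first pass: mean; second pass: variance and normalization). It is only an illustration of the general technique, not the MLX kernel, and the helper name layer_norm_two_pass is hypothetical.

import mlx.core as mx

def layer_norm_two_pass(x, w, b, eps=1e-5):
    # Pass 1: mean over the feature axis.
    mu = mx.mean(x, axis=-1, keepdims=True)
    # Pass 2: variance around that mean, then normalize.
    var = mx.mean((x - mu) ** 2, axis=-1, keepdims=True)
    y = (x - mu) * mx.rsqrt(var + eps)
    if w is not None:
        y = y * w
    if b is not None:
        y = y + b
    return y

x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,))
b = mx.random.uniform(shape=(4096,))
mx.eval(layer_norm_two_pass(x, w, b))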

View File

@@ -11,13 +11,14 @@ include(CMakeParseArguments)
# Args: TARGET: Custom target to be added for the metal library TITLE: Name of
# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
# files (like headers)
# files (like headers) DEBUG: Boolean, if true, enables debug compile options
# for this specific library. If not provided, uses global MLX_METAL_DEBUG.
#
# clang format on
macro(mlx_build_metallib)
# Parse args
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG)
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
@@ -26,6 +27,10 @@ macro(mlx_build_metallib)
# Collect compile options
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
if(MLX_METAL_DEBUG OR MTLLIB_DEBUG)
set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only
-frecord-sources)
endif()
# Prepare metallib build command
add_custom_command(

View File

@@ -10,7 +10,7 @@ import mlx.core as mx
# -- Project information -----------------------------------------------------
project = "MLX"
copyright = "2023, MLX Contributors"
copyright = "2023, Apple"
author = "MLX Contributors"
version = ".".join(mx.__version__.split(".")[:3])
release = version

View File

@@ -8,23 +8,26 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs.
Simple Example
--------------
.. currentmodule:: mlx.core
Let's write a custom kernel that computes ``exp`` elementwise:
.. code-block:: python
def exp_elementwise(a: mx.array):
source = """
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::exp(tmp);
"""
source = """
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp",
input_names=["inp"],
output_names=["out"],
source=source,
)
kernel = mx.fast.metal_kernel(
name="myexp",
input_names=["inp"],
output_names=["out"],
source=source,
)
def exp_elementwise(a: mx.array):
outputs = kernel(
inputs=[a],
template=[("T", mx.float32)],
@@ -39,8 +42,13 @@ Let's write a custom kernel that computes ``exp`` elementwise:
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
Every time you make a kernel, a new Metal library is created and possibly
JIT compiled. To reduce the overhead from that, build the kernel once with
:func:`fast.metal_kernel` and then use it many times.
.. note::
We are only required to pass the body of the Metal kernel in ``source``.
Only pass the body of the Metal kernel in ``source``. The function
signature is generated automatically.
The full function signature will be generated using:
@@ -78,44 +86,51 @@ Putting this all together, the generated function signature for ``myexp`` is as
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
``threadgroup`` size threadgroups. For optimal performance, each thread group
dimension should be less than or equal to the corresponding grid dimension.
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
Passing ``verbose=True`` to :func:`fast.metal_kernel.__call__` will print the
generated code for debugging purposes.
Using Shape/Strides
-------------------
``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
when indexing.
:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
is ``True`` by default. This will copy the array inputs if needed
before the kernel is launched to ensure that the memory layout is row
contiguous. Generally this makes writing the kernel easier, since we don't
have to worry about gaps or the ordering of the dims when indexing.
If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
input array ``a`` if any are present in ``source``.
We can then use MLX's built in indexing utils to fetch the right elements for each thread.
If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
present in ``source``. We can then use MLX's built in indexing utils to fetch
the right elements for each thread.
Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:
Let's convert ``myexp`` above to support arbitrarily strided arrays without
relying on a copy from ``ensure_row_contiguous``:
.. code-block:: python
source = """
uint elem = thread_position_in_grid.x;
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
T tmp = inp[loc];
// Output arrays are always row contiguous
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp_strided",
input_names=["inp"],
output_names=["out"],
source=source
)
def exp_elementwise(a: mx.array):
source = """
uint elem = thread_position_in_grid.x;
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
T tmp = inp[loc];
// Output arrays are always row contiguous
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp_strided",
input_names=["inp"],
output_names=["out"],
source=source
)
outputs = kernel(
inputs=[a],
template=[("T", mx.float32)],
@@ -142,137 +157,139 @@ We'll start with the following MLX implementation using standard ops:
.. code-block:: python
def grid_sample_ref(x, grid):
N, H_in, W_in, _ = x.shape
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
def grid_sample_ref(x, grid):
N, H_in, W_in, _ = x.shape
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
ix_nw = mx.floor(ix).astype(mx.int32)
iy_nw = mx.floor(iy).astype(mx.int32)
ix_nw = mx.floor(ix).astype(mx.int32)
iy_nw = mx.floor(iy).astype(mx.int32)
ix_ne = ix_nw + 1
iy_ne = iy_nw
ix_ne = ix_nw + 1
iy_ne = iy_nw
ix_sw = ix_nw
iy_sw = iy_nw + 1
ix_sw = ix_nw
iy_sw = iy_nw + 1
ix_se = ix_nw + 1
iy_se = iy_nw + 1
ix_se = ix_nw + 1
iy_se = iy_nw + 1
nw = (ix_se - ix) * (iy_se - iy)
ne = (ix - ix_sw) * (iy_sw - iy)
sw = (ix_ne - ix) * (iy - iy_ne)
se = (ix - ix_nw) * (iy - iy_nw)
nw = (ix_se - ix) * (iy_se - iy)
ne = (ix - ix_sw) * (iy_sw - iy)
sw = (ix_ne - ix) * (iy - iy_ne)
se = (ix - ix_nw) * (iy - iy_nw)
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
I_nw *= mask_nw[..., None]
I_ne *= mask_ne[..., None]
I_sw *= mask_sw[..., None]
I_se *= mask_se[..., None]
I_nw *= mask_nw[..., None]
I_ne *= mask_ne[..., None]
I_sw *= mask_sw[..., None]
I_se *= mask_se[..., None]
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
return output
return output
Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
to write a fast GPU kernel for both the forward and backward passes.
First we'll implement the forward pass as a fused kernel:
.. code-block:: python
@mx.custom_function
def grid_sample(x, grid):
source = """
uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
int gH = grid_shape[1];
int gW = grid_shape[2];
assert x.ndim == 4, "`x` must be 4D."
assert grid.ndim == 4, "`grid` must be 4D."
int w_stride = C;
int h_stride = W * w_stride;
int b_stride = H * h_stride;
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
out_shape = (B, gN, gM, C)
uint grid_idx = elem / C * 2;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
assert D == 2, "Last dim of `grid` must be size 2."
int ix_nw = floor(ix);
int iy_nw = floor(iy);
source = """
uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
int gH = grid_shape[1];
int gW = grid_shape[2];
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int w_stride = C;
int h_stride = W * w_stride;
int b_stride = H * h_stride;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
uint grid_idx = elem / C * 2;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
int ix_nw = floor(ix);
int iy_nw = floor(iy);
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int batch_idx = elem / C / gH / gW * b_stride;
int channel_idx = elem % C;
int base_idx = batch_idx + channel_idx;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
"""
int batch_idx = elem / C / gH / gW * b_stride;
int channel_idx = elem % C;
int base_idx = batch_idx + channel_idx;
kernel = mx.fast.metal_kernel(
name="grid_sample",
input_names=["x", "grid"],
output_names=["out"],
source=source,
)
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
@mx.custom_function
def grid_sample(x, grid):
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
assert x.ndim == 4, "`x` must be 4D."
assert grid.ndim == 4, "`grid` must be 4D."
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
"""
kernel = mx.fast.metal_kernel(
name="grid_sample",
input_names=["x", "grid"],
output_names=["out"],
source=source,
)
outputs = kernel(
inputs=[x, grid],
template=[("T", x.dtype)],
output_shapes=[out_shape],
output_dtypes=[x.dtype],
grid=(np.prod(out_shape), 1, 1),
threadgroup=(256, 1, 1),
)
return outputs[0]
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
out_shape = (B, gN, gM, C)
assert D == 2, "Last dim of `grid` must be size 2."
outputs = kernel(
inputs=[x, grid],
template=[("T", x.dtype)],
output_shapes=[out_shape],
output_dtypes=[x.dtype],
grid=(np.prod(out_shape), 1, 1),
threadgroup=(256, 1, 1),
)
return outputs[0]
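
To sanity-check the fused kernel against the reference implementation above, a simple harness like the following could be used. This is a hedged sketch: the helper name check_and_time and the shapes are illustrative, and timing is done with time.perf_counter rather than the benchmark utilities in the repo.

import time
import mlx.core as mx

def check_and_time(fn, x, grid, iters=10):
    # Warm up and force evaluation before timing.
    mx.eval(fn(x, grid))
    tic = time.perf_counter()
    for _ in range(iters):
        # Evaluate each iteration so every launch is actually measured.
        out = fn(x, grid)
        mx.eval(out)
    return out, (time.perf_counter() - tic) / iters

x = mx.random.uniform(shape=(4, 64, 64, 8))
grid = mx.random.uniform(low=-1, high=1, shape=(4, 16, 16, 2))

ref, t_ref = check_and_time(grid_sample_ref, x, grid)
fused, t_fused = check_and_time(grid_sample, x, grid)
assert mx.allclose(ref, fused, atol=1e-5)
print(f"ref: {t_ref * 1e3:.2f} ms, fused: {t_fused * 1e3:.2f} ms")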
For a reasonably sized input such as:
.. code-block:: python
x.shape = (8, 1024, 1024, 64)
grid.shape = (8, 256, 256, 2)
x.shape = (8, 1024, 1024, 64)
grid.shape = (8, 256, 256, 2)
On an M1 Max, we see a big performance improvement:
@@ -281,11 +298,11 @@ On an M1 Max, we see a big performance improvement:
Grid Sample VJP
---------------
Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
its custom vjp transform so MLX can differentiate it.
Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
define its custom vjp transform so MLX can differentiate it.
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
requires a few extra ``mx.fast.metal_kernel`` features:
requires a few extra :func:`fast.metal_kernel` features:
* ``init_value=0``
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
@@ -299,128 +316,129 @@ We can then implement the backwards pass as follows:
.. code-block:: python
@grid_sample.vjp
def grid_sample_vjp(primals, cotangent, _):
x, grid = primals
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
source = """
uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
// Pad C to the nearest larger simdgroup size multiple
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
assert D == 2, "Last dim of `grid` must be size 2."
int gH = grid_shape[1];
int gW = grid_shape[2];
source = """
uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
// Pad C to the nearest larger simdgroup size multiple
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
int w_stride = C;
int h_stride = W * w_stride;
int b_stride = H * h_stride;
int gH = grid_shape[1];
int gW = grid_shape[2];
uint grid_idx = elem / C_padded * 2;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int w_stride = C;
int h_stride = W * w_stride;
int b_stride = H * h_stride;
int ix_nw = floor(ix);
int iy_nw = floor(iy);
uint grid_idx = elem / C_padded * 2;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_nw = floor(ix);
int iy_nw = floor(iy);
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
int batch_idx = elem / C_padded / gH / gW * b_stride;
int channel_idx = elem % C_padded;
int base_idx = batch_idx + channel_idx;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
T gix = T(0);
T giy = T(0);
if (channel_idx < C) {
int cot_index = elem / C_padded * C + channel_idx;
T cot = cotangent[cot_index];
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
int batch_idx = elem / C_padded / gH / gW * b_stride;
int channel_idx = elem % C_padded;
int base_idx = batch_idx + channel_idx;
T I_nw = x[offset];
gix -= I_nw * (iy_se - iy) * cot;
giy -= I_nw * (ix_se - ix) * cot;
}
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
T gix = T(0);
T giy = T(0);
if (channel_idx < C) {
int cot_index = elem / C_padded * C + channel_idx;
T cot = cotangent[cot_index];
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
T I_ne = x[offset];
gix += I_ne * (iy_sw - iy) * cot;
giy -= I_ne * (ix - ix_sw) * cot;
}
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
T I_nw = x[offset];
gix -= I_nw * (iy_se - iy) * cot;
giy -= I_nw * (ix_se - ix) * cot;
}
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
T I_sw = x[offset];
gix -= I_sw * (iy - iy_ne) * cot;
giy += I_sw * (ix_ne - ix) * cot;
}
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
T I_ne = x[offset];
gix += I_ne * (iy_sw - iy) * cot;
giy -= I_ne * (ix - ix_sw) * cot;
}
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
T I_se = x[offset];
gix += I_se * (iy - iy_nw) * cot;
giy += I_se * (ix - ix_nw) * cot;
}
}
T I_sw = x[offset];
gix -= I_sw * (iy - iy_ne) * cot;
giy += I_sw * (ix_ne - ix) * cot;
}
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
T gix_mult = W / 2;
T giy_mult = H / 2;
T I_se = x[offset];
gix += I_se * (iy - iy_nw) * cot;
giy += I_se * (ix - ix_nw) * cot;
}
}
// Reduce across each simdgroup first.
// This is much faster than relying purely on atomics.
gix = simd_sum(gix);
giy = simd_sum(giy);
T gix_mult = W / 2;
T giy_mult = H / 2;
if (thread_index_in_simdgroup == 0) {
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
}
"""
kernel = mx.fast.metal_kernel(
name="grid_sample_grad",
input_names=["x", "grid", "cotangent"],
output_names=["x_grad", "grid_grad"],
source=source,
atomic_outputs=True,
)
// Reduce across each simdgroup first.
// This is much faster than relying purely on atomics.
gix = simd_sum(gix);
giy = simd_sum(giy);
@grid_sample.vjp
def grid_sample_vjp(primals, cotangent, _):
x, grid = primals
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
if (thread_index_in_simdgroup == 0) {
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
}
"""
kernel = mx.fast.metal_kernel(
name="grid_sample_grad",
input_names=["x", "grid", "cotangent"],
output_names=["x_grad", "grid_grad"],
source=source,
atomic_outputs=True,
)
# pad the output channels to simd group size
# so that our `simd_sum`s don't overlap.
simdgroup_size = 32
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
grid_size = B * gN * gM * C_padded
outputs = kernel(
inputs=[x, grid, cotangent],
template=[("T", x.dtype)],
output_shapes=[x.shape, grid.shape],
output_dtypes=[x.dtype, x.dtype],
grid=(grid_size, 1, 1),
threadgroup=(256, 1, 1),
init_value=0,
)
return outputs[0], outputs[1]
assert D == 2, "Last dim of `grid` must be size 2."
# pad the output channels to simd group size
# so that our `simd_sum`s don't overlap.
simdgroup_size = 32
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
grid_size = B * gN * gM * C_padded
outputs = kernel(
inputs=[x, grid, cotangent],
template=[("T", x.dtype)],
output_shapes=[x.shape, grid.shape],
output_dtypes=[x.dtype, x.dtype],
grid=(grid_size, 1, 1),
threadgroup=(256, 1, 1),
init_value=0,
)
return outputs[0], outputs[1]
There's an even larger speed up for the vjp:

View File

@@ -397,11 +397,11 @@ below.
std::ostringstream kname;
kname << "axpby_" << "general_" << type_to_name(out);
// Make sure the metal library is available
d.register_library("mlx_ext");
// Load the metal library
auto lib = d.get_library("mlx_ext");
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
auto kernel = d.get_kernel(kname.str(), lib);
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);

View File

@@ -19,6 +19,8 @@ Array
array.ndim
array.shape
array.size
array.real
array.imag
array.abs
array.all
array.any

View File

@@ -16,6 +16,8 @@ Linear Algebra
cross
qr
svd
eigvals
eig
eigvalsh
eigh
lu

View File

@@ -172,11 +172,11 @@ void Axpby::eval_gpu(
kname << (contiguous_kernel ? "contiguous_" : "general_");
kname << type_to_name(out);
// Make sure the metal library is available
d.register_library("mlx_ext");
// Load the metal library
auto lib = d.get_library("mlx_ext");
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
auto kernel = d.get_kernel(kname.str(), lib);
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);

View File

@@ -21,7 +21,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
# Define MLX_VERSION only in the version.cpp file.
add_library(mlx_version STATIC ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
add_library(mlx_version OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)

View File

@@ -224,6 +224,10 @@ class array {
// Not copyable
Data(const Data& d) = delete;
Data& operator=(const Data& d) = delete;
Data(Data&& o) : buffer(o.buffer), d(o.d) {
o.buffer = allocator::Buffer(nullptr);
o.d = [](allocator::Buffer) {};
}
~Data() {
d(buffer);
}

View File

@@ -0,0 +1,157 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cassert>
#include <functional>
#include <map>
namespace mlx::core {
template <typename T>
class BufferCache {
public:
BufferCache(
size_t page_size,
std::function<size_t(T*)> get_size,
std::function<void(T*)> free)
: page_size_(page_size),
get_size_(std::move(get_size)),
free_(std::move(free)) {}
~BufferCache() {
clear();
}
BufferCache(const BufferCache&) = delete;
BufferCache& operator=(const BufferCache&) = delete;
T* reuse_from_cache(size_t size) {
// Find the closest buffer in pool.
auto it = buffer_pool_.lower_bound(size);
if (it == buffer_pool_.end() ||
it->first >= std::min(2 * size, size + 2 * page_size_)) {
return nullptr;
}
// Collect from the cache.
T* buf = it->second->buf;
pool_size_ -= it->first;
// Remove from record.
remove_from_list(it->second);
buffer_pool_.erase(it);
return buf;
}
void recycle_to_cache(T* buf) {
assert(buf);
// Add to cache.
BufferHolder* bh = new BufferHolder(buf);
add_at_head(bh);
size_t size = get_size_(buf);
pool_size_ += size;
buffer_pool_.emplace(size, bh);
}
int release_cached_buffers(size_t min_bytes_to_free) {
if (min_bytes_to_free >= 0.9 * pool_size_) {
return clear();
} else {
int n_release = 0;
size_t total_bytes_freed = 0;
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
// Release buffer.
size_t size = get_size_(tail_->buf);
total_bytes_freed += size;
free_(tail_->buf);
n_release++;
// Remove from record.
auto its = buffer_pool_.equal_range(size);
auto it = std::find_if(its.first, its.second, [this](const auto& el) {
return el.second == tail_;
});
assert(it != buffer_pool_.end());
buffer_pool_.erase(it);
remove_from_list(tail_);
}
pool_size_ -= total_bytes_freed;
return n_release;
}
}
int clear() {
int n_release = 0;
for (auto& [size, holder] : buffer_pool_) {
free_(holder->buf);
n_release++;
delete holder;
}
buffer_pool_.clear();
pool_size_ = 0;
head_ = nullptr;
tail_ = nullptr;
return n_release;
}
size_t cache_size() const {
return pool_size_;
}
size_t page_size() const {
return page_size_;
}
private:
struct BufferHolder {
public:
explicit BufferHolder(T* buf_) : buf(buf_) {}
BufferHolder* prev{nullptr};
BufferHolder* next{nullptr};
T* buf;
};
void add_at_head(BufferHolder* to_add) {
if (!head_) {
head_ = to_add;
tail_ = to_add;
} else {
head_->prev = to_add;
to_add->next = head_;
head_ = to_add;
}
}
void remove_from_list(BufferHolder* to_remove) {
if (to_remove->prev && to_remove->next) { // if middle
to_remove->prev->next = to_remove->next;
to_remove->next->prev = to_remove->prev;
} else if (to_remove->prev && to_remove == tail_) { // if tail
tail_ = to_remove->prev;
tail_->next = nullptr;
} else if (to_remove == head_ && to_remove->next) { // if head
head_ = to_remove->next;
head_->prev = nullptr;
} else if (to_remove == head_ && to_remove == tail_) { // if only element
head_ = nullptr;
tail_ = nullptr;
}
delete to_remove;
}
std::multimap<size_t, BufferHolder*> buffer_pool_;
BufferHolder* head_{nullptr};
BufferHolder* tail_{nullptr};
size_t pool_size_{0};
const size_t page_size_;
std::function<size_t(T*)> get_size_;
std::function<void(T*)> free_;
};
} // namespace mlx::core
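
The size-matching policy in reuse_from_cache above is the key idea: take the smallest cached buffer at least as large as the request, but reject it if it is more than twice the request or more than two pages larger. A hedged Python sketch of just that policy, with a sorted list standing in for the std::multimap (function and variable names here are illustrative):

import bisect

def reuse_from_cache(pool, size, page_size):
    """pool: sorted list of (buf_size, buf) pairs; returns a buffer or None.

    Mirrors the C++ policy above: the closest buffer >= size is reused only
    if it is not wastefully large, i.e. under min(2 * size, size + 2 * page_size).
    """
    sizes = [s for s, _ in pool]
    i = bisect.bisect_left(sizes, size)        # lower_bound
    if i == len(pool) or pool[i][0] >= min(2 * size, size + 2 * page_size):
        return None
    return pool.pop(i)[1]

# Example with 4 KiB pages: a 10 KiB request can reuse a 12 KiB buffer,
# but a 20 KiB request cannot reuse a 64 KiB one.
pool = [(4096, "A"), (12288, "B"), (65536, "C")]
print(reuse_from_cache(pool, 10240, 4096))   # "B"
print(reuse_from_cache(pool, 20480, 4096))   # None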

View File

@@ -1,8 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/common/compiled.h"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"
#include "mlx/backend/common/utils.h"
#include "mlx/utils.h"
namespace mlx::core {
@@ -79,55 +78,6 @@ std::string get_type_string(Dtype d) {
}
}
std::string build_lib_name(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids) {
NodeNamer namer;
std::ostringstream os;
std::ostringstream constant_hasher;
// Fill the input names. This is not really necessary, I just like having A,
// B, C, ... as the inputs.
for (auto& x : inputs) {
namer.get_name(x);
}
// The primitives describing the tape. For unary and binary primitives this
// must be enough to describe the full computation.
for (auto& a : tape) {
// name and type of output
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
// computation performed
a.primitive().print(os);
// name of inputs to the function
for (auto& inp : a.inputs()) {
os << namer.get_name(inp);
}
}
os << "_";
for (auto& x : inputs) {
if (constant_ids.find(x.id()) != constant_ids.end()) {
os << "C";
print_constant(constant_hasher, x);
} else {
os << (is_scalar(x) ? "S" : "V");
}
}
os << "_";
for (auto& x : inputs) {
if (constant_ids.find(x.id()) != constant_ids.end()) {
continue;
}
os << kindof(x.dtype()) << x.itemsize();
}
os << "_" << std::hash<std::string>{}(constant_hasher.str());
return os.str();
}
bool compiled_check_contiguity(
const std::vector<array>& inputs,
const Shape& shape) {
@@ -159,8 +109,7 @@ bool compiled_check_contiguity(
void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::vector<array>& inputs_,
const std::unordered_set<uintptr_t>& constant_ids_,
const std::function<bool(size_t)>& is_constant,
bool contiguous) {
if (contiguous) {
int o = 0;
@@ -175,8 +124,7 @@ void compiled_allocate_outputs(
// - Donatable
// - Not a constant
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
in.is_donatable() && is_constant(i)) {
outputs[o++].copy_shared_buffer(in);
}
// Get representative input flags to properly set non-donated outputs
@@ -204,7 +152,7 @@ void compiled_allocate_outputs(
// - Not a constant
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
is_constant(i)) {
outputs[o].copy_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size());
o++;
@@ -216,4 +164,74 @@ void compiled_allocate_outputs(
}
}
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
const std::vector<array>& inputs,
const array& out,
const std::function<bool(size_t)>& is_constant) {
const Shape& shape = out.shape();
bool contiguous = compiled_check_contiguity(inputs, shape);
if (contiguous) {
return {true, shape, {}};
}
std::vector<Strides> strides_vec{out.strides()};
for (size_t i = 0; i < inputs.size(); ++i) {
// Skip constants.
if (is_constant(i)) {
continue;
}
// Skip scalar inputs.
const auto& x = inputs[i];
if (is_scalar(x)) {
continue;
}
// Broadcast the inputs to the output shape.
Strides xstrides;
size_t j = 0;
for (; j < shape.size() - x.ndim(); ++j) {
if (shape[j] == 1) {
xstrides.push_back(out.strides()[j]);
} else {
xstrides.push_back(0);
}
}
for (size_t i = 0; i < x.ndim(); ++i, ++j) {
if (x.shape(i) == 1) {
if (shape[j] == 1) {
xstrides.push_back(out.strides()[j]);
} else {
xstrides.push_back(0);
}
} else {
xstrides.push_back(x.strides()[i]);
}
}
strides_vec.push_back(std::move(xstrides));
}
auto tup = collapse_contiguous_dims(shape, strides_vec, INT32_MAX);
return {false, std::move(std::get<0>(tup)), std::move(std::get<1>(tup))};
}
bool compiled_use_large_index(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
bool contiguous) {
if (contiguous) {
size_t max_size = 0;
for (const auto& in : inputs) {
max_size = std::max(max_size, in.data_size());
}
return max_size > UINT32_MAX;
} else {
size_t max_size = 0;
for (const auto& o : outputs) {
max_size = std::max(max_size, o.size());
}
return max_size > UINT32_MAX;
}
}
} // namespace mlx::core
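
The broadcasting loop in compiled_collapse_contiguous_dims above maps each input onto the output shape by giving broadcast dimensions a stride of 0 (and reusing the output stride where the output dimension itself has size 1, so it collapses consistently). A hedged Python sketch of that stride-expansion step; the helper name broadcast_strides is illustrative:

def broadcast_strides(x_shape, x_strides, out_shape, out_strides):
    """Expand an input's strides to the output shape for the compiled kernel."""
    xstrides = []
    lead = len(out_shape) - len(x_shape)
    for j in range(lead):                       # dims the input doesn't have
        xstrides.append(out_strides[j] if out_shape[j] == 1 else 0)
    for i in range(len(x_shape)):
        j = lead + i
        if x_shape[i] == 1:
            xstrides.append(out_strides[j] if out_shape[j] == 1 else 0)
        else:
            xstrides.append(x_strides[i])
    return xstrides

# Example: broadcasting a (4,) vector against a (2, 3, 4) row-contiguous
# output with strides (12, 4, 1) gives strides (0, 0, 1).
print(broadcast_strides((4,), (1,), (2, 3, 4), (12, 4, 1)))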

View File

@@ -1,9 +1,8 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <functional>
#include <iomanip>
#include <sstream>
#include <unordered_set>
#include "mlx/array.h"
#include "mlx/primitives.h"
@@ -14,12 +13,6 @@ inline bool is_static_cast(const Primitive& p) {
return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
}
std::string build_lib_name(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids);
std::string get_type_string(Dtype d);
template <typename T>
@@ -60,8 +53,19 @@ bool compiled_check_contiguity(
void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::vector<array>& inputs_,
const std::unordered_set<uintptr_t>& constant_ids_,
const std::function<bool(size_t)>& is_constant,
bool contiguous);
// Collapse contiguous dims ignoring scalars and constants.
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
const std::vector<array>& inputs,
const array& out,
const std::function<bool(size_t)>& is_constant);
// Return whether the kernel should use large index.
bool compiled_use_large_index(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
bool contiguous);
} // namespace mlx::core

View File

@@ -2,7 +2,7 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
@@ -26,7 +26,7 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
if (ctype == CopyType::Vector) {
// If the input is donateable, we are doing a vector copy and the types
// have the same size, then the input buffer can hold the output.
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
return true;
} else {

View File

@@ -0,0 +1,78 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/common/utils.h"
#include "mlx/utils.h"
#include <sstream>
namespace mlx::core {
inline std::tuple<Shape, Strides, Strides> collapse_batches(
const array& a,
const array& b) {
// Get and check the shape for the batched dims
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
if (A_bshape != B_bshape) {
std::ostringstream msg;
msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A "
<< a.shape() << ", B " << b.shape() << ".";
throw std::runtime_error(msg.str());
}
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
auto [batch_shape, batch_strides] =
collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
auto a_batch_strides = batch_strides[0];
auto b_batch_strides = batch_strides[1];
if (batch_shape.empty()) {
batch_shape.push_back(1);
a_batch_strides.push_back(0);
b_batch_strides.push_back(0);
}
return std::make_tuple(batch_shape, a_batch_strides, b_batch_strides);
}
inline std::tuple<Shape, Strides, Strides, Strides>
collapse_batches(const array& a, const array& b, const array& c) {
// Get and check the shape for the batched dims
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
Shape C_bshape{c.shape().begin(), c.shape().end() - 2};
if (A_bshape != B_bshape || A_bshape != C_bshape) {
std::ostringstream msg;
msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A "
<< a.shape() << ", B " << b.shape() << ", B " << c.shape() << ".";
throw std::runtime_error(msg.str());
}
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
Strides C_bstride{c.strides().begin(), c.strides().end() - 2};
auto [batch_shape, batch_strides] = collapse_contiguous_dims(
A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
auto A_batch_stride = batch_strides[0];
auto B_batch_stride = batch_strides[1];
auto C_batch_stride = batch_strides[2];
if (batch_shape.empty()) {
batch_shape.push_back(1);
A_batch_stride.push_back(0);
B_batch_stride.push_back(0);
C_batch_stride.push_back(0);
}
return std::make_tuple(
batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
}
} // namespace mlx::core

View File

@@ -0,0 +1,26 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/allocator.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
inline void set_unary_output_data(const array& in, array& out) {
if (in.flags().contiguous) {
if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
}
} // namespace mlx::core

View File

@@ -1,9 +1,16 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
std::string get_primitive_string(Primitive* primitive) {
std::ostringstream op_t;
primitive->print(op_t);
return op_t.str();
}
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
const Shape& shape,
const std::vector<Strides>& strides,
@@ -101,4 +108,105 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
}
Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
int pows[3] = {0, 0, 0};
int sum = 0;
while (true) {
int presum = sum;
// Check all the pows
if (dim0 >= (1 << (pows[0] + 1))) {
pows[0]++;
sum++;
}
if (sum == 10) {
break;
}
if (dim1 >= (1 << (pows[1] + 1))) {
pows[1]++;
sum++;
}
if (sum == 10) {
break;
}
if (dim2 >= (1 << (pows[2] + 1))) {
pows[2]++;
sum++;
}
if (sum == presum || sum == pow2) {
break;
}
}
return std::make_tuple(1ul << pows[0], 1ul << pows[1], 1ul << pows[2]);
}
Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides) {
// Dims with strides of 0 are ignored as they
// correspond to broadcasted dimensions
size_t grid_x = 1;
size_t grid_y = 1;
for (int i = 0; i < shape.size(); ++i) {
if (strides[i] == 0) {
continue;
}
if (grid_x * shape[i] < UINT32_MAX) {
grid_x *= shape[i];
} else {
grid_y *= shape[i];
}
}
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
throw std::runtime_error("Unable to safely factor shape.");
}
if (grid_y > grid_x) {
std::swap(grid_x, grid_y);
}
return std::make_tuple(
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}
Dims get_2d_grid_dims_common(
const Shape& shape,
const Strides& strides,
size_t divisor) {
// Compute the 2d grid dimensions such that the total size of the grid is
// divided by divisor.
size_t grid_x = 1;
size_t grid_y = 1;
for (int i = 0; i < shape.size(); ++i) {
if (strides[i] == 0) {
continue;
}
// No need to add this shape we can just remove it from the divisor.
if (divisor % shape[i] == 0) {
divisor /= shape[i];
continue;
}
if (grid_x * shape[i] < UINT32_MAX) {
grid_x *= shape[i];
} else {
grid_y *= shape[i];
}
if (divisor > 1) {
if (grid_x % divisor == 0) {
grid_x /= divisor;
divisor = 1;
} else if (grid_y % divisor == 0) {
grid_y /= divisor;
divisor = 1;
}
}
}
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) {
throw std::runtime_error("Unable to safely factor shape.");
}
if (grid_y > grid_x) {
std::swap(grid_x, grid_y);
}
return std::make_tuple(
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}
} // namespace mlx::core
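
get_2d_grid_dims_common above factors the non-broadcast size of an array into two grid dimensions that each stay below UINT32_MAX, skipping stride-0 (broadcast) dimensions. A hedged Python sketch of the same factoring (without the divisor variant):

UINT32_MAX = 2**32 - 1

def get_2d_grid_dims(shape, strides):
    """Factor the non-broadcast size of `shape` into (grid_x, grid_y, 1)
    with each factor below UINT32_MAX, as in get_2d_grid_dims_common."""
    grid_x, grid_y = 1, 1
    for dim, stride in zip(shape, strides):
        if stride == 0:          # broadcast dimension, contributes no work
            continue
        if grid_x * dim < UINT32_MAX:
            grid_x *= dim
        else:
            grid_y *= dim
    if grid_x > UINT32_MAX or grid_y > UINT32_MAX:
        raise RuntimeError("Unable to safely factor shape.")
    if grid_y > grid_x:
        grid_x, grid_y = grid_y, grid_x
    return grid_x, grid_y, 1

# Example: a broadcasted (8, 1, 4096) array with strides (4096, 0, 1).
print(get_2d_grid_dims((8, 1, 4096), (4096, 0, 1)))   # (32768, 1, 1)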

View File

@@ -2,12 +2,15 @@
#pragma once
#include <tuple>
#include <vector>
#include "mlx/array.h"
namespace mlx::core {
std::string get_primitive_string(Primitive* primitive);
inline int64_t
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
int64_t loc = 0;
@@ -70,6 +73,28 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
const array& a,
int64_t size_cap = std::numeric_limits<int32_t>::max());
// Compute the thread block dimensions which fit the given
// input dimensions.
// - The thread block dimensions will be powers of two
// - The thread block size will be less than 2^pow2
using Dims = std::tuple<uint32_t, uint32_t, uint32_t>;
Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 = 10);
// Computes a 2D grid where each element is < UINT_MAX
// Assumes:
// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
// - shape and strides correspond to a contiguous (no holes) but
// possibly broadcasted array
Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides);
// Same as above but we do an implicit division with divisor.
// Basically, equivalent to factorizing
// Prod(s \forall s in shape if strides[s] > 0) / divisor.
Dims get_2d_grid_dims_common(
const Shape& shape,
const Strides& strides,
size_t divisor);
struct ContiguousIterator {
inline void step() {
int dims = shape_.size();
@@ -165,4 +190,11 @@ void shared_buffer_reshape(
const array& in,
const Strides& out_strides,
array& out);
template <typename T>
inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
vec.erase(std::next(vec.begin(), index));
return vec;
}
} // namespace mlx::core

View File

@@ -46,6 +46,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eig.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp

View File

@@ -14,10 +14,8 @@ template <typename InT, typename OpT>
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
auto axis_size = in.shape()[axis];
auto axis_stride = in.strides()[axis];
Strides strides = in.strides();
Shape shape = in.shape();
strides.erase(strides.begin() + axis);
shape.erase(shape.begin() + axis);
Strides strides = remove_index(in.strides(), axis);
Shape shape = remove_index(in.shape(), axis);
auto in_ptr = in.data<InT>();
auto out_ptr = out.data<uint32_t>();

View File

@@ -146,18 +146,9 @@ inline void build_kernel(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids,
const std::function<bool(size_t)>& is_constant,
bool contiguous,
int ndim) {
// All outputs should have the exact same shape and will be row contiguous
auto output_shape = outputs[0].shape();
auto output_strides = outputs[0].strides();
// Constants are scalars that are captured by value and cannot change
auto is_constant = [&constant_ids](const array& x) {
return constant_ids.find(x.id()) != constant_ids.end();
};
NodeNamer namer;
#ifdef _MSC_VER
@@ -170,14 +161,15 @@ inline void build_kernel(
// Add the input arguments
int cnt = 0;
for (auto& x : inputs) {
auto& xname = namer.get_name(x);
for (size_t i = 0; i < inputs.size(); ++i) {
// Skip constants from the input list
if (is_constant(x)) {
if (is_constant(i)) {
continue;
}
const auto& x = inputs[i];
auto& xname = namer.get_name(x);
auto tstr = get_type_string(x.dtype());
os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
<< "];" << std::endl;
@@ -211,10 +203,11 @@ inline void build_kernel(
}
// Read the inputs in tmps
for (auto& x : inputs) {
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
auto& xname = namer.get_name(x);
if (is_constant(x)) {
if (is_constant(i)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
print_constant(os, x);
os << ";" << std::endl;
@@ -264,8 +257,9 @@ inline void build_kernel(
} else {
for (int d = ndim - 1; d >= 0; --d) {
// Update pointers
for (auto& x : inputs) {
if (is_constant(x) || is_scalar(x)) {
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
if (is_constant(i) || is_scalar(x)) {
continue;
}
auto& xname = namer.get_name(x);
@@ -287,65 +281,37 @@ inline void build_kernel(
void Compiled::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
if (kernel_lib_.empty()) {
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
}
// Figure out which kernel we are using
auto& shape = outputs[0].shape();
auto contiguous = compiled_check_contiguity(inputs, shape);
auto& encoder = cpu::get_command_encoder(stream());
// Handle all broadcasting and collect function input arguments
// Collapse contiguous dims to route to a faster kernel if possible. Also
// handle all broadcasting.
auto [contiguous, shape, strides] =
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
// Collect function input arguments.
std::vector<void*> args;
std::vector<std::vector<size_t>> strides;
for (int i = 0; i < inputs.size(); i++) {
// Skip constants.
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
int strides_index = 1;
for (size_t i = 0; i < inputs.size(); ++i) {
if (is_constant_(i)) {
continue;
}
auto& x = inputs[i];
const auto& x = inputs[i];
encoder.set_input_array(x);
args.push_back((void*)x.data<void>());
if (contiguous || is_scalar(x)) {
continue;
if (!contiguous && !is_scalar(x)) {
args.push_back(strides[strides_index++].data());
}
// Broadcast the input to the output shape.
std::vector<size_t> xstrides;
int j = 0;
for (; j < shape.size() - x.ndim(); j++) {
if (shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
}
for (int i = 0; i < x.ndim(); i++, j++) {
if (x.shape(i) == 1) {
if (shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
} else {
xstrides.push_back(x.strides()[i]);
}
}
strides.push_back(std::move(xstrides));
args.push_back(strides.back().data());
}
// Get the kernel name from the lib
int ndim = shape.size();
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
if (!contiguous) {
kernel_name += std::to_string(shape.size());
kernel_name += std::to_string(ndim);
}
// Get the function
auto fn_ptr = compile(kernel_name, [&]() {
auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() {
std::ostringstream kernel;
kernel << get_kernel_preamble() << std::endl;
kernel << "extern \"C\" {" << std::endl;
@@ -355,7 +321,7 @@ void Compiled::eval_cpu(
inputs_,
outputs_,
tape_,
constant_ids_,
is_constant_,
contiguous,
ndim);
// Close extern "C"
@@ -363,26 +329,22 @@ void Compiled::eval_cpu(
return kernel.str();
});
compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous);
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
for (auto& x : outputs) {
args.push_back(x.data<void>());
encoder.set_output_array(x);
}
Shape out_shape;
if (!contiguous) {
out_shape = outputs[0].shape();
args.push_back((void*)out_shape.data());
args.push_back((void*)shape.data());
} else {
args.push_back((void*)outputs[0].data_size());
}
auto fun = (void (*)(void**))fn_ptr;
encoder.dispatch(
[fun,
args = std::move(args),
strides = std::move(strides),
out_shape = std::move(out_shape)]() mutable { fun(args.data()); });
encoder.dispatch([fun,
args = std::move(args),
strides = std::move(strides),
shape = std::move(shape)]() mutable { fun(args.data()); });
}
} // namespace mlx::core

View File

@@ -22,7 +22,8 @@ void slow_conv_1D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
@@ -60,7 +61,8 @@ void slow_conv_1D(
out_stride_O = out.strides()[2],
flip,
padding = padding[0],
padding_lo = padding_lo[0],
padding_hi = padding_hi[0],
wt_stride = wt_strides[0],
wt_dilation = wt_dilation[0],
in_dilation = in_dilation[0]]() mutable {
@@ -77,7 +79,7 @@ void slow_conv_1D(
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
int wh_flip = flip ? (wH - wh - 1) : wh;
int ih = oh * wt_stride - padding + wh_flip * wt_dilation;
int ih = oh * wt_stride - padding_lo + wh_flip * wt_dilation;
auto ih_div = std::div(ih, in_dilation);
@@ -109,7 +111,8 @@ void slow_conv_2D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
@@ -120,230 +123,235 @@ void slow_conv_2D(
encoder.set_input_array(wt);
encoder.set_output_array(out);
encoder.dispatch([st_wt_ptr = wt.data<T>(),
st_in_ptr = in.data<T>(),
st_out_ptr = out.data<T>(),
encoder.dispatch(
[st_wt_ptr = wt.data<T>(),
st_in_ptr = in.data<T>(),
st_out_ptr = out.data<T>(),
N = in.shape(
0), // Batch size, should be the same as out.shape(0)
iH = 1 +
in_dilation[0] * (in.shape(1) - 1), // Input spatial dim
iW = 1 +
in_dilation[1] * (in.shape(2) - 1), // Input spatial dim
C = in.shape(3), // In channels
oH = out.shape(1), // Output spatial dim
oW = out.shape(2), // Output spatial dim
O = wt.shape(0), // Out channels
wH = wt.shape(1), // Weight spatial dim
wW = wt.shape(2), // Weight spatial dim
N = in.shape(0), // Batch size, should be the same as out.shape(0)
iH = 1 + in_dilation[0] * (in.shape(1) - 1), // Input spatial dim
iW = 1 + in_dilation[1] * (in.shape(2) - 1), // Input spatial dim
C = in.shape(3), // In channels
oH = out.shape(1), // Output spatial dim
oW = out.shape(2), // Output spatial dim
O = wt.shape(0), // Out channels
wH = wt.shape(1), // Weight spatial dim
wW = wt.shape(2), // Weight spatial dim
groups = in.shape(3) / wt.shape(3),
C_per_group = wt.shape(3),
groups = in.shape(3) / wt.shape(3),
C_per_group = wt.shape(3),
in_stride_N = in.strides()[0],
in_stride_H = in.strides()[1],
in_stride_W = in.strides()[2],
in_stride_C = in.strides()[3],
in_stride_N = in.strides()[0],
in_stride_H = in.strides()[1],
in_stride_W = in.strides()[2],
in_stride_C = in.strides()[3],
wt_stride_O = wt.strides()[0],
wt_stride_H = wt.strides()[1],
wt_stride_W = wt.strides()[2],
wt_stride_C = wt.strides()[3],
wt_stride_O = wt.strides()[0],
wt_stride_H = wt.strides()[1],
wt_stride_W = wt.strides()[2],
wt_stride_C = wt.strides()[3],
out_stride_N = out.strides()[0],
out_stride_H = out.strides()[1],
out_stride_W = out.strides()[2],
out_stride_O = out.strides()[3],
out_stride_N = out.strides()[0],
out_stride_H = out.strides()[1],
out_stride_W = out.strides()[2],
out_stride_O = out.strides()[3],
padding,
wt_strides,
wt_dilation,
in_dilation,
flip]() mutable {
bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
flip]() mutable {
bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;
const int O_per_group = O / groups;
auto pt_conv_no_checks = [&](const T* in_ptr,
const T* wt_ptr,
T* out_ptr,
int oh,
int ow) {
out_ptr += oh * out_stride_H + ow * out_stride_W;
int ih_base = oh * wt_strides[0] - padding[0];
int iw_base = ow * wt_strides[1] - padding[1];
const int O_per_group = O / groups;
auto pt_conv_no_checks =
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
out_ptr += oh * out_stride_H + ow * out_stride_W;
int ih_base = oh * wt_strides[0] - padding_lo[0];
int iw_base = ow * wt_strides[1] - padding_lo[1];
for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
float r = 0.;
for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
float r = 0.;
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;
const T* wt_ptr_pt =
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
const T* in_ptr_pt =
in_ptr + ih * in_stride_H + iw * in_stride_W;
for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
static_cast<float>(
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
} // c
} // ww
} // wh
for (int c = g * C_per_group; c < (g + 1) * C_per_group;
++c) {
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
static_cast<float>(
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
} // c
} // ww
} // wh
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
} // g
};
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
} // g
};
int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
int jump_w = flip ? -wt_dilation[1] : wt_dilation[1];
int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
int jump_w = flip ? -wt_dilation[1] : wt_dilation[1];
int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0);
int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0);
int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0);
int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0);
int f_wgt_jump_h =
std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
int f_wgt_jump_w =
std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
int f_wgt_jump_h =
std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
int f_wgt_jump_w =
std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
int f_out_jump_h =
std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
int f_out_jump_w =
std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
std::vector<int> base_h(f_out_jump_h);
std::vector<int> base_w(f_out_jump_w);
std::vector<int> base_h(f_out_jump_h);
std::vector<int> base_w(f_out_jump_w);
for (int i = 0; i < f_out_jump_h; ++i) {
int ih_loop = i * wt_strides[0] - padding[0] + init_h;
for (int i = 0; i < f_out_jump_h; ++i) {
int ih_loop = i * wt_strides[0] - padding_lo[0] + init_h;
int wh_base = 0;
while (wh_base < wH && ih_loop % in_dilation[0] != 0) {
wh_base++;
ih_loop += jump_h;
}
int wh_base = 0;
while (wh_base < wH && ih_loop % in_dilation[0] != 0) {
wh_base++;
ih_loop += jump_h;
}
base_h[i] = wh_base;
}
base_h[i] = wh_base;
}
for (int j = 0; j < f_out_jump_w; ++j) {
int iw_loop = j * wt_strides[1] - padding[1] + init_w;
for (int j = 0; j < f_out_jump_w; ++j) {
int iw_loop = j * wt_strides[1] - padding_lo[1] + init_w;
int ww_base = 0;
while (ww_base < wW && iw_loop % in_dilation[1] != 0) {
ww_base++;
iw_loop += jump_w;
}
int ww_base = 0;
while (ww_base < wW && iw_loop % in_dilation[1] != 0) {
ww_base++;
iw_loop += jump_w;
}
base_w[j] = ww_base;
}
base_w[j] = ww_base;
}
auto pt_conv_all_checks =
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
out_ptr += oh * out_stride_H + ow * out_stride_W;
auto pt_conv_all_checks =
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
out_ptr += oh * out_stride_H + ow * out_stride_W;
int ih_base = oh * wt_strides[0] - padding[0];
int iw_base = ow * wt_strides[1] - padding[1];
int ih_base = oh * wt_strides[0] - padding_lo[0];
int iw_base = ow * wt_strides[1] - padding_lo[1];
int wh_base = base_h[oh % f_out_jump_h];
int ww_base = base_w[ow % f_out_jump_w];
int wh_base = base_h[oh % f_out_jump_h];
int ww_base = base_w[ow % f_out_jump_w];
for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
float r = 0.;
for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
float r = 0.;
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
const T* wt_ptr_pt =
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
const T* wt_ptr_pt =
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
const T* in_ptr_pt =
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;
const T* in_ptr_pt = in_ptr + ih_dil * in_stride_H +
iw_dil * in_stride_W;
for (int c = g * C_per_group; c < (g + 1) * C_per_group;
++c) {
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
static_cast<float>(
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
} // c
for (int c = g * C_per_group; c < (g + 1) * C_per_group;
++c) {
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
static_cast<float>(
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
} // c
} // ih, iw check
} // ww
} // wh
} // ih, iw check
} // ww
} // wh
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
} // g
};
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
} // g
};
int oH_border_0 = 0;
int oH_border_1 =
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oH;
int oH_border_2 = std::max(
oH_border_1, (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]);
int oH_border_3 = oH;
int oH_border_0 = 0;
int oH_border_1 = is_idil_one
? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])
: oH;
int oH_border_2 = std::max(
oH_border_1,
(iH + padding_lo[0] - wH * wt_dilation[0]) / wt_strides[0]);
int oH_border_3 = oH;
int oW_border_0 = 0;
int oW_border_1 =
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oW;
int oW_border_2 = std::max(
oW_border_1, (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]);
int oW_border_3 = oW;
int oW_border_0 = 0;
int oW_border_1 = is_idil_one
? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])
: oW;
int oW_border_2 = std::max(
oW_border_1,
(iW + padding_lo[1] - wW * wt_dilation[1]) / wt_strides[1]);
int oW_border_3 = oW;
for (int n = 0; n < N; ++n) {
// Case 1: oh might put us out of bounds
for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
} // oh
for (int n = 0; n < N; ++n) {
// Case 1: oh might put us out of bounds
for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
} // oh
// Case 2: oh in bounds
for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
// Case a: ow might put us out of bounds
for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
// Case 2: oh in bounds
for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
// Case a: ow might put us out of bounds
for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
// Case b: ow in bounds
for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
// Case b: ow in bounds
for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
// Case c: ow might put us out of bounds
for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
// Case c: ow might put us out of bounds
for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
} // oh
} // oh
// Case 3: oh might put us out of bounds
for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
} // oh
// Case 3: oh might put us out of bounds
for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
} // oh
st_in_ptr += in_stride_N;
st_out_ptr += out_stride_N;
st_in_ptr += in_stride_N;
st_out_ptr += out_stride_N;
} // n
});
} // n
});
}
template <typename T>
@@ -351,7 +359,8 @@ void slow_conv_3D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
@@ -400,7 +409,8 @@ void slow_conv_3D(
out_stride_H = out.strides()[2],
out_stride_W = out.strides()[3],
out_stride_O = out.strides()[4],
padding,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
@@ -415,9 +425,9 @@ void slow_conv_3D(
int oh,
int ow) {
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
int id_base = od * wt_strides[0] - padding[0];
int ih_base = oh * wt_strides[1] - padding[1];
int iw_base = ow * wt_strides[2] - padding[2];
int id_base = od * wt_strides[0] - padding_lo[0];
int ih_base = oh * wt_strides[1] - padding_lo[1];
int iw_base = ow * wt_strides[2] - padding_lo[2];
for (int o = 0; o < O; ++o) {
float r = 0.;
@@ -478,7 +488,7 @@ void slow_conv_3D(
std::vector<int> base_w(f_out_jump_w);
for (int i = 0; i < f_out_jump_d; ++i) {
int id_loop = i * wt_strides[0] - padding[0] + init_d;
int id_loop = i * wt_strides[0] - padding_lo[0] + init_d;
int wd_base = 0;
while (wd_base < wD && id_loop % in_dilation[0] != 0) {
@@ -490,7 +500,7 @@ void slow_conv_3D(
}
for (int i = 0; i < f_out_jump_h; ++i) {
int ih_loop = i * wt_strides[1] - padding[1] + init_h;
int ih_loop = i * wt_strides[1] - padding_lo[1] + init_h;
int wh_base = 0;
while (wh_base < wH && ih_loop % in_dilation[1] != 0) {
@@ -502,7 +512,7 @@ void slow_conv_3D(
}
for (int j = 0; j < f_out_jump_w; ++j) {
int iw_loop = j * wt_strides[2] - padding[2] + init_w;
int iw_loop = j * wt_strides[2] - padding_lo[2] + init_w;
int ww_base = 0;
while (ww_base < wW && iw_loop % in_dilation[2] != 0) {
@@ -521,9 +531,9 @@ void slow_conv_3D(
int ow) {
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
int id_base = od * wt_strides[0] - padding[0];
int ih_base = oh * wt_strides[1] - padding[1];
int iw_base = ow * wt_strides[2] - padding[2];
int id_base = od * wt_strides[0] - padding_lo[0];
int ih_base = oh * wt_strides[1] - padding_lo[1];
int iw_base = ow * wt_strides[2] - padding_lo[2];
int wd_base = base_d[od % f_out_jump_d];
int wh_base = base_h[oh % f_out_jump_h];
@@ -573,24 +583,30 @@ void slow_conv_3D(
};
int oD_border_0 = 0;
int oD_border_1 =
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD;
int oD_border_1 = is_idil_one
? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])
: oD;
int oD_border_2 = std::max(
oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]);
oD_border_1,
(iD + padding_lo[0] - wD * wt_dilation[0]) / wt_strides[0]);
int oD_border_3 = oD;
int oH_border_0 = 0;
int oH_border_1 =
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH;
int oH_border_1 = is_idil_one
? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])
: oH;
int oH_border_2 = std::max(
oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]);
oH_border_1,
(iH + padding_lo[1] - wH * wt_dilation[1]) / wt_strides[1]);
int oH_border_3 = oH;
int oW_border_0 = 0;
int oW_border_1 =
is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW;
int oW_border_1 = is_idil_one
? ((padding_lo[2] + wt_strides[2] - 1) / wt_strides[2])
: oW;
int oW_border_2 = std::max(
oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]);
oW_border_1,
(iW + padding_lo[2] - wW * wt_dilation[2]) / wt_strides[2]);
int oW_border_3 = oW;
for (int n = 0; n < N; ++n) {
@@ -658,7 +674,8 @@ void dispatch_slow_conv_1D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
@@ -669,7 +686,8 @@ void dispatch_slow_conv_1D(
in,
wt,
out,
padding,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
@@ -680,7 +698,8 @@ void dispatch_slow_conv_1D(
in,
wt,
out,
padding,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
@@ -691,7 +710,8 @@ void dispatch_slow_conv_1D(
in,
wt,
out,
padding,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
@@ -707,7 +727,8 @@ void dispatch_slow_conv_2D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
@@ -718,7 +739,8 @@ void dispatch_slow_conv_2D(
in,
wt,
out,
padding,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
@@ -729,7 +751,8 @@ void dispatch_slow_conv_2D(
in,
wt,
out,
padding,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
@@ -740,7 +763,8 @@ void dispatch_slow_conv_2D(
in,
wt,
out,
padding,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
@@ -756,7 +780,8 @@ void dispatch_slow_conv_3D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
@@ -767,7 +792,8 @@ void dispatch_slow_conv_3D(
in,
wt,
out,
padding,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
@@ -778,7 +804,8 @@ void dispatch_slow_conv_3D(
in,
wt,
out,
padding,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
@@ -789,7 +816,8 @@ void dispatch_slow_conv_3D(
in,
wt,
out,
padding,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
@@ -829,7 +857,8 @@ void explicit_gemm_conv_1D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
Stream stream) {
@@ -848,7 +877,7 @@ void explicit_gemm_conv_1D_cpu(
auto& encoder = cpu::get_command_encoder(stream);
// Pad input
Shape padded_shape = {N, iH + 2 * padding[0], C};
Shape padded_shape = {N, iH + padding_lo[0] + padding_hi[0], C};
array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros
@@ -857,7 +886,7 @@ void explicit_gemm_conv_1D_cpu(
copy(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded
size_t data_offset = padding[0] * in_padded.strides()[1];
size_t data_offset = padding_lo[0] * in_padded.strides()[1];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
@@ -971,7 +1000,8 @@ void explicit_gemm_conv_2D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
Stream stream) {
@@ -989,7 +1019,11 @@ void explicit_gemm_conv_2D_cpu(
auto& encoder = cpu::get_command_encoder(stream);
// Pad input
Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C};
Shape padded_shape = {
N,
iH + padding_lo[0] + padding_hi[0],
iW + padding_lo[1] + padding_hi[1],
C};
array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros
@@ -998,8 +1032,8 @@ void explicit_gemm_conv_2D_cpu(
copy(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded
size_t data_offset =
padding[0] * in_padded.strides()[1] + padding[1] * in_padded.strides()[2];
size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
padding_lo[1] * in_padded.strides()[2];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
@@ -1091,7 +1125,8 @@ void explicit_gemm_conv_ND_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const bool flip,
@@ -1114,7 +1149,7 @@ void explicit_gemm_conv_ND_cpu(
Shape padded_shape(in.shape().size());
padded_shape.front() = N;
for (size_t i = 0; i < iDim.size(); i++) {
padded_shape[i + 1] = iDim[i] + 2 * padding[i];
padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i];
}
padded_shape.back() = C;
array in_padded(padded_shape, conv_dtype, nullptr, {});
@@ -1125,9 +1160,10 @@ void explicit_gemm_conv_ND_cpu(
// Pick input slice from padded
size_t data_offset = 0;
for (size_t i = 0; i < padding.size(); i++) {
data_offset += padding[i] * in_padded.strides()[i + 1];
for (size_t i = 0; i < padding_lo.size(); i++) {
data_offset += padding_lo[i] * in_padded.strides()[i + 1];
}
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
@@ -1261,7 +1297,8 @@ void conv_1D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
@@ -1270,22 +1307,40 @@ void conv_1D_cpu(
const int groups = in.shape().back() / wt.shape().back();
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
return explicit_gemm_conv_1D_cpu(
in, wt, out, padding, wt_strides, wt_dilation, stream);
in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, stream);
}
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) {
return explicit_gemm_conv_ND_cpu(
in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
flip,
stream);
}
return dispatch_slow_conv_1D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
flip,
stream);
}
void conv_2D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
@@ -1295,18 +1350,35 @@ void conv_2D_cpu(
if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 &&
in_dilation[1] == 1 && groups == 1) {
return explicit_gemm_conv_ND_cpu(
in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
flip,
stream);
}
return dispatch_slow_conv_2D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
flip,
stream);
}
void conv_3D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
@@ -1317,11 +1389,28 @@ void conv_3D_cpu(
in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 &&
groups == 1) {
return explicit_gemm_conv_ND_cpu(
in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
flip,
stream);
}
return dispatch_slow_conv_3D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
flip,
stream);
}
} // namespace
@@ -1338,7 +1427,8 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
in,
wt,
out,
padding_,
padding_lo_,
padding_hi_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
@@ -1351,7 +1441,8 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
in,
wt,
out,
padding_,
padding_lo_,
padding_hi_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
@@ -1364,7 +1455,8 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
in,
wt,
out,
padding_,
padding_lo_,
padding_hi_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
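The convolution changes above replace the single symmetric padding vector with padding_lo/padding_hi, so the padded extent and the offset of the original data inside the padded buffer are computed from different values. A short sketch of that arithmetic with hypothetical helper names (not the MLX functions):

#include <cstddef>
#include <cstdint>
#include <vector>

// Padded spatial extent: low and high padding are added independently.
std::vector<int> padded_dims(
    const std::vector<int>& in_dims,
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi) {
  std::vector<int> out(in_dims.size());
  for (size_t i = 0; i < in_dims.size(); ++i) {
    out[i] = in_dims[i] + padding_lo[i] + padding_hi[i];
  }
  return out;
}

// Only the low padding shifts where the original input starts in the padded
// buffer; the high padding just grows the allocation.
size_t data_offset(
    const std::vector<int>& padding_lo,
    const std::vector<int64_t>& padded_strides) {
  size_t off = 0;
  for (size_t i = 0; i < padding_lo.size(); ++i) {
    off += padding_lo[i] * padded_strides[i];
  }
  return off;
}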

mlx/backend/cpu/eig.cpp

@@ -0,0 +1,174 @@
// Copyright © 2025 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename T>
void eig_impl(
array& a,
array& vectors,
array& values,
bool compute_eigenvectors,
Stream stream) {
using OT = std::complex<T>;
auto a_ptr = a.data<T>();
auto eig_ptr = values.data<OT>();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(values);
OT* vec_ptr = nullptr;
if (compute_eigenvectors) {
encoder.set_output_array(vectors);
vec_ptr = vectors.data<OT>();
}
encoder.dispatch([a_ptr,
vec_ptr,
eig_ptr,
compute_eigenvectors,
N = vectors.shape(-1),
size = vectors.size()]() mutable {
// Work query
char jobr = 'N';
char jobl = compute_eigenvectors ? 'V' : 'N';
int n_vecs_r = 1;
int n_vecs_l = compute_eigenvectors ? N : 1;
int lwork = -1;
int info;
{
T work;
int iwork;
geev<T>(
&jobl,
&jobr,
&N,
nullptr,
&N,
nullptr,
nullptr,
nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
&work,
&lwork,
&info);
lwork = static_cast<int>(work);
}
auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)};
auto vec_tmp_data =
array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)};
auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
for (size_t i = 0; i < size / (N * N); ++i) {
geev<T>(
&jobl,
&jobr,
&N,
a_ptr,
&N,
eig_tmp,
eig_tmp + N,
vec_tmp,
&n_vecs_l,
nullptr,
&n_vecs_r,
static_cast<T*>(work_buf.buffer.raw_ptr()),
&lwork,
&info);
for (int i = 0; i < N; ++i) {
eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]};
}
if (vec_ptr) {
for (int i = 0; i < N; ++i) {
if (eig_ptr[i].imag() != 0) {
// This vector and the next are a pair
for (int j = 0; j < N; ++j) {
vec_ptr[i * N + j] = {
vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]};
vec_ptr[(i + 1) * N + j] = {
vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]};
}
i += 1;
} else {
for (int j = 0; j < N; ++j) {
vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0};
}
}
}
vec_ptr += N * N;
}
a_ptr += N * N;
eig_ptr += N;
if (info != 0) {
std::stringstream msg;
msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
}
});
encoder.add_temporary(a);
}
} // namespace
void Eig::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
const auto& a = inputs[0];
auto& values = outputs[0];
auto vectors = compute_eigenvectors_
? outputs[1]
: array(a.shape(), complex64, nullptr, {});
auto a_copy = array(a.shape(), a.dtype(), nullptr, {});
copy(
a,
a_copy,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream());
values.set_data(allocator::malloc(values.nbytes()));
if (compute_eigenvectors_) {
// Set the strides and flags so the eigenvectors
// are in the columns of the output
auto flags = vectors.flags();
auto strides = vectors.strides();
auto ndim = a.ndim();
std::swap(strides[ndim - 1], strides[ndim - 2]);
if (a.size() > 1) {
flags.row_contiguous = false;
if (ndim > 2) {
flags.col_contiguous = false;
} else {
flags.col_contiguous = true;
}
}
vectors.set_data(
allocator::malloc(vectors.nbytes()), vectors.size(), strides, flags);
}
switch (a.dtype()) {
case float32:
eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
break;
default:
throw std::runtime_error("[Eig::eval_cpu] only supports float32.");
}
}
} // namespace mlx::core
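The repacking loop in eig_impl follows the convention of LAPACK's real geev: eigenvalues come back as separate real and imaginary arrays, and a complex-conjugate pair shares two consecutive rows of the eigenvector output. A standalone illustration of that repacking, assuming the same layout the loop above uses (this mirrors the code, not the Fortran documentation):

#include <complex>
#include <vector>

// wr/wi: real and imaginary parts of the eigenvalues (length N).
// vr: eigenvector data as produced above (N x N, row-major after transpose).
void repack_eig(
    const std::vector<float>& wr,
    const std::vector<float>& wi,
    const std::vector<float>& vr,
    std::vector<std::complex<float>>& values,
    std::vector<std::complex<float>>& vectors) {
  int N = static_cast<int>(wr.size());
  values.resize(N);
  vectors.resize(N * N);
  for (int i = 0; i < N; ++i) {
    values[i] = {wr[i], wi[i]};
  }
  for (int i = 0; i < N; ++i) {
    if (wi[i] != 0.0f) {
      // Conjugate pair: rows i and i + 1 hold the shared real part and the
      // imaginary part.
      for (int j = 0; j < N; ++j) {
        vectors[i * N + j] = {vr[i * N + j], -vr[(i + 1) * N + j]};
        vectors[(i + 1) * N + j] = {vr[i * N + j], vr[(i + 1) * N + j]};
      }
      ++i;
    } else {
      for (int j = 0; j < N; ++j) {
        vectors[i * N + j] = {vr[i * N + j], 0.0f};
      }
    }
  }
}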


@@ -12,6 +12,133 @@ namespace mlx::core {
namespace {
template <typename T, class Enable = void>
struct EighWork {};
template <typename T>
struct EighWork<
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
using R = T;
char jobz;
char uplo;
int N;
int lwork;
int liwork;
int info;
std::vector<array::Data> buffers;
EighWork(char jobz_, char uplo_, int N_)
: jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), liwork(-1) {
T work;
int iwork;
syevd<T>(
&jobz,
&uplo,
&N,
nullptr,
&N,
nullptr,
&work,
&lwork,
&iwork,
&liwork,
&info);
lwork = static_cast<int>(work);
liwork = iwork;
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
}
void run(T* vectors, T* values) {
syevd<T>(
&jobz,
&uplo,
&N,
vectors,
&N,
values,
static_cast<T*>(buffers[0].buffer.raw_ptr()),
&lwork,
static_cast<int*>(buffers[1].buffer.raw_ptr()),
&liwork,
&info);
}
};
template <>
struct EighWork<std::complex<float>> {
using T = std::complex<float>;
using R = float;
char jobz;
char uplo;
int N;
int lwork;
int lrwork;
int liwork;
int info;
std::vector<array::Data> buffers;
EighWork(char jobz_, char uplo_, int N_)
: jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), lrwork(-1), liwork(-1) {
T work;
R rwork;
int iwork;
heevd<T>(
&jobz,
&uplo,
&N,
nullptr,
&N,
nullptr,
&work,
&lwork,
&rwork,
&lrwork,
&iwork,
&liwork,
&info);
lwork = static_cast<int>(work.real());
lrwork = static_cast<int>(rwork);
liwork = iwork;
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
}
void run(T* vectors, R* values) {
heevd<T>(
&jobz,
&uplo,
&N,
vectors,
&N,
values,
static_cast<T*>(buffers[0].buffer.raw_ptr()),
&lwork,
static_cast<R*>(buffers[1].buffer.raw_ptr()),
&lrwork,
static_cast<int*>(buffers[2].buffer.raw_ptr()),
&liwork,
&info);
if (jobz == 'V') {
// We have pre-transposed the vectors but we also must conjugate them
// when they are complex.
//
// We could vectorize this but it is so fast in comparison to heevd that
// it doesn't really matter.
for (int i = 0; i < N; i++) {
for (int j = 0; j < N; j++) {
*vectors = std::conj(*vectors);
vectors++;
}
}
}
}
};
template <typename T>
void eigh_impl(
array& vectors,
@@ -19,8 +146,10 @@ void eigh_impl(
const std::string& uplo,
bool compute_eigenvectors,
Stream stream) {
using R = typename EighWork<T>::R;
auto vec_ptr = vectors.data<T>();
auto eig_ptr = values.data<T>();
auto eig_ptr = values.data<R>();
char jobz = compute_eigenvectors ? 'V' : 'N';
auto& encoder = cpu::get_command_encoder(stream);
@@ -33,49 +162,17 @@ void eigh_impl(
N = vectors.shape(-1),
size = vectors.size()]() mutable {
// Work query
int lwork = -1;
int liwork = -1;
int info;
{
T work;
int iwork;
syevd<T>(
&jobz,
&uplo,
&N,
nullptr,
&N,
nullptr,
&work,
&lwork,
&iwork,
&liwork,
&info);
lwork = static_cast<int>(work);
liwork = iwork;
}
EighWork<T> work(jobz, uplo, N);
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)};
// Work loop
for (size_t i = 0; i < size / (N * N); ++i) {
syevd<T>(
&jobz,
&uplo,
&N,
vec_ptr,
&N,
eig_ptr,
static_cast<T*>(work_buf.buffer.raw_ptr()),
&lwork,
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
&liwork,
&info);
work.run(vec_ptr, eig_ptr);
vec_ptr += N * N;
eig_ptr += N;
if (info != 0) {
if (work.info != 0) {
std::stringstream msg;
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
<< info;
<< work.info;
throw std::runtime_error(msg.str());
}
}
@@ -131,6 +228,10 @@ void Eigh::eval_cpu(
eigh_impl<double>(
vectors, values, uplo_, compute_eigenvectors_, stream());
break;
case complex64:
eigh_impl<std::complex<float>>(
vectors, values, uplo_, compute_eigenvectors_, stream());
break;
default:
throw std::runtime_error(
"[Eigh::eval_cpu] only supports float32 or float64.");

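EighWork packages the usual LAPACK two-call idiom: a first call with lwork = -1 only queries workspace sizes, and the second call does the decomposition with the allocated buffers. A hedged standalone sketch of that idiom against the raw ssyevd symbol (the exact linkage and symbol name depend on the LAPACK build; MLX goes through its syevd/heevd wrappers instead):

#include <vector>

extern "C" void ssyevd_(
    const char* jobz, const char* uplo, const int* n, float* a, const int* lda,
    float* w, float* work, const int* lwork, int* iwork, const int* liwork,
    int* info);

// a: N x N symmetric matrix (overwritten with eigenvectors), w: N eigenvalues.
void syevd_two_call(std::vector<float>& a, std::vector<float>& w, int N) {
  char jobz = 'V', uplo = 'L';
  int info = 0;
  // Workspace query.
  int lwork = -1, liwork = -1;
  float wkopt = 0;
  int iwkopt = 0;
  ssyevd_(&jobz, &uplo, &N, a.data(), &N, w.data(), &wkopt, &lwork, &iwkopt,
          &liwork, &info);
  lwork = static_cast<int>(wkopt);
  liwork = iwkopt;
  // Real call with the allocated workspaces; info == 0 on success.
  std::vector<float> work(lwork);
  std::vector<int> iwork(liwork);
  ssyevd_(&jobz, &uplo, &N, a.data(), &N, w.data(), work.data(), &lwork,
          iwork.data(), &liwork, &info);
}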

@@ -257,15 +257,11 @@ void gather_axis(
const array& ind,
array& out,
const int axis) {
auto strides = ind.strides();
strides.erase(strides.begin() + axis);
auto shape = ind.shape();
shape.erase(shape.begin() + axis);
ContiguousIterator ind_it(shape, strides, src.ndim() - 1);
strides = src.strides();
strides.erase(strides.begin() + axis);
ContiguousIterator src_it(shape, strides, src.ndim() - 1);
auto shape = remove_index(ind.shape(), axis);
ContiguousIterator ind_it(
shape, remove_index(ind.strides(), axis), src.ndim() - 1);
ContiguousIterator src_it(
shape, remove_index(src.strides(), axis), src.ndim() - 1);
auto ind_ptr = ind.data<IdxT>();
auto src_ptr = src.data<T>();
@@ -585,15 +581,11 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
template <typename T, typename IdxT, typename OpT>
void scatter_axis(array& out, const array idx, const array& upd, int axis) {
auto strides = idx.strides();
strides.erase(strides.begin() + axis);
auto shape = idx.shape();
shape.erase(shape.begin() + axis);
ContiguousIterator idx_it(shape, strides, upd.ndim() - 1);
strides = upd.strides();
strides.erase(strides.begin() + axis);
ContiguousIterator upd_it(shape, strides, upd.ndim() - 1);
auto shape = remove_index(idx.shape(), axis);
ContiguousIterator idx_it(
shape, remove_index(idx.strides(), axis), upd.ndim() - 1);
ContiguousIterator upd_it(
shape, remove_index(upd.strides(), axis), upd.ndim() - 1);
auto idx_ptr = idx.data<IdxT>();
auto upd_ptr = upd.data<T>();
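Both gather_axis and scatter_axis now build the iteration shape and strides through a remove_index helper instead of erasing in place. A plausible sketch of such a helper, shown only so the refactor reads on its own (hypothetical, not necessarily the MLX implementation):

#include <vector>

// Return a copy of v with the element at position index removed.
template <typename T>
std::vector<T> remove_index(std::vector<T> v, int index) {
  v.erase(v.begin() + index);
  return v;
}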


@@ -2,14 +2,14 @@
#pragma once
// Required for Visual Studio.
// https://github.com/OpenMathLib/OpenBLAS/blob/develop/docs/install.md
#ifdef _MSC_VER
#include <complex>
#define LAPACK_COMPLEX_CUSTOM
#define lapack_complex_float std::complex<float>
#define lapack_complex_double std::complex<double>
#endif
#define lapack_complex_float_real(z) ((z).real())
#define lapack_complex_float_imag(z) ((z).imag())
#define lapack_complex_double_real(z) ((z).real())
#define lapack_complex_double_imag(z) ((z).imag())
#ifdef MLX_USE_ACCELERATE
#include <Accelerate/Accelerate.h>
@@ -32,7 +32,7 @@
#endif
#define INSTANTIATE_LAPACK_TYPES(FUNC) \
#define INSTANTIATE_LAPACK_REAL(FUNC) \
template <typename T, typename... Args> \
void FUNC(Args... args) { \
if constexpr (std::is_same_v<T, float>) { \
@@ -42,11 +42,24 @@
} \
}
INSTANTIATE_LAPACK_TYPES(geqrf)
INSTANTIATE_LAPACK_TYPES(orgqr)
INSTANTIATE_LAPACK_TYPES(syevd)
INSTANTIATE_LAPACK_TYPES(potrf)
INSTANTIATE_LAPACK_TYPES(gesvdx)
INSTANTIATE_LAPACK_TYPES(getrf)
INSTANTIATE_LAPACK_TYPES(getri)
INSTANTIATE_LAPACK_TYPES(trtri)
INSTANTIATE_LAPACK_REAL(geqrf)
INSTANTIATE_LAPACK_REAL(orgqr)
INSTANTIATE_LAPACK_REAL(syevd)
INSTANTIATE_LAPACK_REAL(geev)
INSTANTIATE_LAPACK_REAL(potrf)
INSTANTIATE_LAPACK_REAL(gesvdx)
INSTANTIATE_LAPACK_REAL(getrf)
INSTANTIATE_LAPACK_REAL(getri)
INSTANTIATE_LAPACK_REAL(trtri)
#define INSTANTIATE_LAPACK_COMPLEX(FUNC) \
template <typename T, typename... Args> \
void FUNC(Args... args) { \
if constexpr (std::is_same_v<T, std::complex<float>>) { \
MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...); \
} else if constexpr (std::is_same_v<T, std::complex<double>>) { \
MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...); \
} \
}
INSTANTIATE_LAPACK_COMPLEX(heevd)


@@ -132,6 +132,10 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error(
"[AddMM::eval_cpu] Currently only supports float32.");
}
if (out.size() == 0) {
out.set_data(allocator::malloc(out.nbytes()));
return;
}
// Fill output with C
auto& c = inputs[2];
@@ -139,7 +143,9 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
? CopyType::Scalar
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy(c, out, ctype, stream());
if (inputs[0].shape(-1) == 0) {
return;
}
matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
}


@@ -13,9 +13,18 @@ namespace mlx::core {
namespace {
inline constexpr short get_pack_factor(int bits, int wsize = 8) {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
inline constexpr short get_bytes_per_pack(int bits, int wsize = 8) {
auto power_of_2_bits = (bits & (bits - 1)) == 0;
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}
template <typename T, int bits>
void extract_bits(const uint8_t* w_in, T* w_out) {
assert(bits == 3 || bits == 6);
static_assert(bits == 3 || bits == 5 || bits == 6);
if (bits == 3) {
w_out[0] = static_cast<T>(w_in[0] & 0x7);
w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
@@ -25,6 +34,16 @@ void extract_bits(const uint8_t* w_in, T* w_out) {
w_out[5] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1));
w_out[6] = static_cast<T>((w_in[2] & 0x1c) >> 2);
w_out[7] = static_cast<T>((w_in[2] & 0xe0) >> 5);
} else if (bits == 5) {
w_out[0] = static_cast<T>(w_in[0] & 0x1f);
w_out[1] = static_cast<T>(((w_in[0] & 0xe0) >> 5) + ((w_in[1] & 0x3) << 3));
w_out[2] = static_cast<T>((w_in[1] & 0x7c) >> 2);
w_out[3] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0xf) << 1));
w_out[4] = static_cast<T>(((w_in[2] & 0xf0) >> 4) + ((w_in[3] & 0x1) << 4));
w_out[5] = static_cast<T>((w_in[3] & 0x3e) >> 1);
w_out[6] = static_cast<T>(((w_in[3] & 0xc0) >> 6) + ((w_in[4] & 0x7) << 2));
w_out[7] = static_cast<T>((w_in[4] & 0xf8) >> 3);
} else if (bits == 6) {
w_out[0] = static_cast<T>(w_in[0] & 0x3f);
w_out[1] =
@@ -46,8 +65,8 @@ void _qmm(
int N,
int K) {
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
constexpr int pack_factor = get_pack_factor(bits, 8);
constexpr int bytes_per_pack = get_bytes_per_pack(bits);
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
@@ -65,7 +84,7 @@ void _qmm(
T scale = *scales_local++;
T bias = *biases_local++;
for (int ng = 0; ng < packs_in_group; ng++) {
if (bits == 3 || bits == 6) {
if constexpr (bits == 3 || bits == 5 || bits == 6) {
T wl[pack_factor];
extract_bits<T, bits>(w_local, wl);
#pragma clang loop unroll(full)
@@ -104,8 +123,9 @@ void _qmm_t(
int N,
int K) {
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
constexpr int pack_factor = get_pack_factor(bits, 8);
constexpr int bytes_per_pack = get_bytes_per_pack(bits);
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
@@ -121,7 +141,7 @@ void _qmm_t(
T bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw++) {
if (bits == 3 || bits == 6) {
if constexpr (bits == 3 || bits == 5 || bits == 6) {
T wl[pack_factor];
extract_bits<T, bits>(w_local, wl);
#pragma clang loop unroll(full)
@@ -304,6 +324,10 @@ void _qmm_dispatch_typed(
_qmm_dispatch_group<T, 4>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
case 5:
_qmm_dispatch_group<T, 5>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
case 6:
_qmm_dispatch_group<T, 6>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
@@ -613,9 +637,8 @@ void quantize(
float eps = 1e-7;
bool power_of_2_bits = is_power_of_2(bits);
int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
// For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
int bytes_per_pack = power_of_2_bits ? 1 : 3;
int el_per_int = get_pack_factor(bits, 32);
int bytes_per_pack = get_bytes_per_pack(bits);
int int_per_group = group_size * bytes_per_pack / el_per_int;
size_t n_groups = w_size / group_size;
@@ -640,15 +663,21 @@ void quantize(
}
size_t out_idx = i * int_per_group;
for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
uint32_t out_el = 0;
uint64_t out_el = 0;
for (int k = 0; k < el_per_int; ++k) {
float w_el = w[w_idx + j * el_per_int + k];
w_el = std::rint((w_el - bias) / scale);
w_el = std::min(std::max(w_el, 0.0f), n_bins);
out_el |= static_cast<uint32_t>(w_el) << (k * bits);
out_el |= static_cast<uint64_t>(w_el) << (k * bits);
}
if (power_of_2_bits) {
out[out_idx + j] = out_el;
} else if (bits == 5) {
out[out_idx + bytes_per_pack * j] = out_el & 0xff;
out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
out[out_idx + bytes_per_pack * j + 3] = (out_el & 0xff000000) >> 24;
out[out_idx + bytes_per_pack * j + 4] = (out_el & 0xff00000000) >> 32;
} else {
out[out_idx + bytes_per_pack * j] = out_el & 0xff;
out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
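For bits == 5 the code packs 8 quantized values (40 bits) into 5 bytes, which is why get_pack_factor returns 8 and get_bytes_per_pack returns 5 in that case. A self-contained round trip of that layout (just the bit arithmetic, not the MLX kernels):

#include <cstdint>
#include <cstdio>

int main() {
  uint8_t vals[8] = {31, 0, 17, 5, 9, 30, 1, 22}; // each value < 32
  // Pack 8 x 5 bits into a 64-bit accumulator, then spill 5 bytes.
  uint64_t acc = 0;
  for (int k = 0; k < 8; ++k) {
    acc |= static_cast<uint64_t>(vals[k]) << (k * 5);
  }
  uint8_t packed[5];
  for (int b = 0; b < 5; ++b) {
    packed[b] = (acc >> (8 * b)) & 0xff;
  }
  // Unpack and verify.
  uint64_t acc2 = 0;
  for (int b = 0; b < 5; ++b) {
    acc2 |= static_cast<uint64_t>(packed[b]) << (8 * b);
  }
  for (int k = 0; k < 8; ++k) {
    int v = (acc2 >> (k * 5)) & 0x1f;
    std::printf("%d %s\n", v, v == vals[k] ? "ok" : "mismatch");
  }
  return 0;
}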


@@ -2,32 +2,13 @@
#pragma once
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/common/unary.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/utils.h"
namespace mlx::core {
void set_unary_output_data(const array& in, array& out) {
if (in.flags().contiguous) {
if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
} else {
auto size = in.data_size();
out.set_data(
allocator::malloc(size * out.itemsize()),
size,
in.strides(),
in.flags());
}
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
}
template <typename T, typename U = T, typename Op>
void unary_op(const T* a, U* out, size_t shape, size_t stride) {
for (size_t i = 0; i < shape; i += 1) {


@@ -6,17 +6,21 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
${CMAKE_CURRENT_SOURCE_DIR}/fence.cu
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
target_compile_definitions(mlx PUBLIC MLX_USE_CUDA)
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
# Enable defining device lambda functions.
target_compile_options(mlx
@@ -25,7 +29,7 @@ target_compile_options(mlx
# Compute capability 7 is required for synchronization between CPU/GPU with
# managed memory. TODO: Add more architectures for potential performance gain.
set(MLX_CUDA_ARCHITECTURES
"75;80"
"70;80"
CACHE STRING "CUDA architectures")
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
@@ -36,7 +40,7 @@ FetchContent_Declare(
cccl
URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip")
FetchContent_MakeAvailable(cccl)
target_include_directories(mlx PRIVATE BEFORE "${cccl_SOURCE_DIR}/include")
target_include_directories(mlx BEFORE PRIVATE "${cccl_SOURCE_DIR}/include")
# Use fixed version of NVTX.
FetchContent_Declare(
@@ -52,6 +56,9 @@ target_link_libraries(mlx PUBLIC $<BUILD_INTERFACE:nvtx3-cpp>)
find_package(CUDAToolkit REQUIRED)
target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
# Use cublasLt.
target_link_libraries(mlx PRIVATE CUDA::cublasLt)
# Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>)


@@ -6,6 +6,7 @@
#include <cuda_runtime.h>
#include <fmt/format.h>
#include <unistd.h>
#include <cassert>
@@ -13,24 +14,50 @@ namespace mlx::core {
namespace cu {
CudaAllocator::CudaAllocator() {
CudaAllocator::CudaAllocator()
: buffer_cache_(
getpagesize(),
[](CudaBuffer* buf) { return buf->size; },
[this](CudaBuffer* buf) {
cuda_free(buf->data);
delete buf;
}) {
// TODO: Set memory limit for multi-device.
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.8;
max_pool_size_ = memory_limit_;
}
Buffer CudaAllocator::malloc(size_t size) {
// TODO: Check memory limit.
auto* buf = new CudaBuffer{nullptr, size};
cudaError_t err = cudaMallocManaged(&buf->data, size);
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(
fmt::format("cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
// Find available buffer from cache.
std::unique_lock lock(mutex_);
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
if (!buf) {
// If we have a lot of memory pressure or are over the maximum cache size,
// try to reclaim memory from the cache.
size_t mem_required = get_active_memory() + get_cache_memory() + size;
if (mem_required >= memory_limit_) {
buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
}
lock.unlock();
buf = new CudaBuffer{nullptr, size};
cudaError_t err = cudaMallocManaged(&buf->data, size);
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
}
lock.lock();
}
std::lock_guard lock(mutex_);
active_memory_ += size;
peak_memory_ = std::max(active_memory_, peak_memory_);
// Maintain the cache below the requested limit.
if (get_cache_memory() > max_pool_size_) {
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
}
return Buffer{buf};
}
@@ -40,26 +67,15 @@ void CudaAllocator::free(Buffer buffer) {
return;
}
// If free() is called from a unregistered thread, reschedule the call to
// worker.
{
std::lock_guard lock(worker_mutex_);
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
if (!worker_) {
worker_.reset(new Worker);
}
worker_->add_task([buffer]() { allocator().free(buffer); });
worker_->end_batch();
worker_->commit();
return;
}
std::unique_lock lock(mutex_);
active_memory_ -= buf->size;
if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf);
} else {
lock.unlock();
cuda_free(buf->data);
delete buf;
}
size_t size = buf->size;
cudaFree(buf->data);
delete buf;
std::lock_guard lock(mutex_);
active_memory_ -= size;
}
size_t CudaAllocator::size(Buffer buffer) const {
@@ -75,6 +91,25 @@ void CudaAllocator::register_this_thread() {
allowed_threads_.insert(std::this_thread::get_id());
}
void CudaAllocator::cuda_free(void* buf) {
// If cuda_free() is called from an unregistered thread, reschedule the call
// to the worker.
{
std::lock_guard lock(worker_mutex_);
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
if (!worker_) {
worker_.reset(new Worker);
}
worker_->add_task([this, buf]() { this->cuda_free(buf); });
worker_->end_batch();
worker_->commit();
return;
}
}
cudaFree(buf);
}
size_t CudaAllocator::get_active_memory() const {
return active_memory_;
}
@@ -98,6 +133,21 @@ size_t CudaAllocator::set_memory_limit(size_t limit) {
return limit;
}
size_t CudaAllocator::get_cache_memory() const {
return buffer_cache_.cache_size();
}
size_t CudaAllocator::set_cache_limit(size_t limit) {
std::lock_guard lk(mutex_);
std::swap(limit, max_pool_size_);
return limit;
}
void CudaAllocator::clear_cache() {
std::lock_guard lk(mutex_);
buffer_cache_.clear();
}
CudaAllocator& allocator() {
// By creating the |allocator_| on heap, the destructor of CudaAllocator
// will not be called on exit and buffers in the cache will be leaked. This
@@ -138,17 +188,19 @@ size_t set_memory_limit(size_t limit) {
size_t get_memory_limit() {
return cu::allocator().get_memory_limit();
}
// TODO: Implement buffer cache.
size_t get_cache_memory() {
return 0;
return cu::allocator().get_cache_memory();
}
size_t set_cache_limit(size_t) {
return 0;
size_t set_cache_limit(size_t limit) {
return cu::allocator().set_cache_limit(limit);
}
void clear_cache() {
cu::allocator().clear_cache();
}
// Not supported in CUDA.
size_t set_wired_limit(size_t) {
return 0;
}
void clear_cache() {}
} // namespace mlx::core
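The new allocator follows the familiar caching-allocator pattern: malloc first tries to reuse a cached buffer, reclaims cached memory when the projected usage crosses the memory limit, and free recycles buffers into the cache while it is below the cache limit. A simplified, single-threaded sketch of that control flow with host malloc/free standing in for cudaMallocManaged/cudaFree (not the MLX BufferCache):

#include <cstdlib>
#include <map>

struct CachingAllocator {
  size_t active = 0;
  size_t cached = 0;
  size_t memory_limit;
  size_t cache_limit;
  std::multimap<size_t, void*> cache; // size -> free buffer

  CachingAllocator(size_t mem, size_t cach)
      : memory_limit(mem), cache_limit(cach) {}

  void* alloc(size_t size) {
    // Reuse the smallest cached buffer that is large enough.
    auto it = cache.lower_bound(size);
    void* buf = nullptr;
    if (it != cache.end()) {
      buf = it->second;
      size = it->first;
      cached -= it->first;
      cache.erase(it);
    } else {
      if (active + cached + size >= memory_limit) {
        release(active + cached + size - memory_limit);
      }
      buf = std::malloc(size);
    }
    active += size;
    return buf;
  }

  void dealloc(void* buf, size_t size) {
    active -= size;
    if (cached < cache_limit) {
      cache.emplace(size, buf); // recycle instead of freeing
      cached += size;
    } else {
      std::free(buf);
    }
  }

  void release(size_t bytes) {
    // Drop cached buffers until roughly `bytes` have been returned.
    while (bytes > 0 && !cache.empty()) {
      auto it = cache.begin();
      cached -= it->first;
      bytes = bytes > it->first ? bytes - it->first : 0;
      std::free(it->second);
      cache.erase(it);
    }
  }
};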


@@ -3,6 +3,7 @@
#pragma once
#include "mlx/allocator.h"
#include "mlx/backend/common/buffer_cache.h"
#include <mutex>
#include <set>
@@ -33,11 +34,17 @@ class CudaAllocator : public allocator::Allocator {
// buffers there would result in dead lock.
void register_this_thread();
// Call cudaFree in the safe thread.
void cuda_free(void* buf);
size_t get_active_memory() const;
size_t get_peak_memory() const;
void reset_peak_memory();
size_t get_memory_limit();
size_t set_memory_limit(size_t limit);
size_t get_cache_memory() const;
size_t set_cache_limit(size_t limit);
void clear_cache();
private:
CudaAllocator();
@@ -49,6 +56,8 @@ class CudaAllocator : public allocator::Allocator {
std::mutex mutex_;
size_t memory_limit_;
size_t max_pool_size_;
BufferCache<CudaBuffer> buffer_cache_;
size_t active_memory_{0};
size_t peak_memory_{0};
};

mlx/backend/cuda/binary.cu

@@ -0,0 +1,305 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/binary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/binary_ops.cuh"
#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = Op{}(a[0], b[0]);
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = Op{}(a[0], b[index]);
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = Op{}(a[index], b[0]);
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = Op{}(a[index], b[index]);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
__global__ void binary_g_nd(
const In* a,
const In* b,
Out* out,
IdxT size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index, shape.data(), a_strides.data(), b_strides.data());
out[index] = Op{}(a[a_idx], b[b_idx]);
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_g(
const In* a,
const In* b,
Out* out,
IdxT size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides a_strides,
const __grid_constant__ Strides b_strides,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_4d(
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
out[index] = Op{}(a[a_idx], b[b_idx]);
}
}
template <typename Op, typename In, typename Out>
constexpr bool supports_binary_op() {
if (std::is_same_v<Op, Add> || std::is_same_v<Op, Divide> ||
std::is_same_v<Op, Maximum> || std::is_same_v<Op, Minimum> ||
std::is_same_v<Op, Multiply> || std::is_same_v<Op, Subtract> ||
std::is_same_v<Op, Power> || std::is_same_v<Op, Remainder>) {
return std::is_same_v<In, Out>;
}
if (std::is_same_v<Op, Equal> || std::is_same_v<Op, Greater> ||
std::is_same_v<Op, GreaterEqual> || std::is_same_v<Op, Less> ||
std::is_same_v<Op, LessEqual> || std::is_same_v<Op, NotEqual>) {
return std::is_same_v<Out, bool>;
}
if (std::is_same_v<Op, LogicalAnd> || std::is_same_v<Op, LogicalOr>) {
return std::is_same_v<Out, bool> && std::is_same_v<In, bool>;
}
if (std::is_same_v<Op, NaNEqual>) {
return std::is_same_v<Out, bool> &&
(is_floating_v<In> || std::is_same_v<In, complex64_t>);
}
if (std::is_same_v<Op, LogAddExp> || std::is_same_v<Op, ArcTan2>) {
return std::is_same_v<In, Out> && is_floating_v<In>;
}
if (std::is_same_v<Op, BitwiseAnd> || std::is_same_v<Op, BitwiseOr> ||
std::is_same_v<Op, BitwiseXor>) {
return std::is_same_v<In, Out> && std::is_integral_v<In>;
}
if (std::is_same_v<Op, LeftShift> || std::is_same_v<Op, RightShift>) {
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
!std::is_same_v<In, bool>;
}
return false;
}
} // namespace cu
template <typename Op>
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
const Stream& s) {
assert(inputs.size() > 1);
const auto& a = inputs[0];
const auto& b = inputs[1];
auto& out = outputs[0];
if (out.size() == 0) {
return;
}
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;
auto bopt = get_binary_op_type(a, b);
if (bopt == BinaryOpType::General) {
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& a_strides = strides[0];
auto& b_strides = strides[1];
bool large = a.data_size() > UINT32_MAX ||
b.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel =
&cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.data_size(),
const_param<NDIM>(shape),
const_param<NDIM>(a_strides),
const_param<NDIM>(b_strides));
});
} else {
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.data_size(),
const_param(shape),
const_param(a_strides),
const_param(b_strides),
ndim);
}
});
} else {
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
} else if (bopt == BinaryOpType::VectorScalar) {
kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
} else if (bopt == BinaryOpType::VectorVector) {
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
}
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, LARGE);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.data_size());
});
}
} else {
throw std::runtime_error(fmt::format(
"Can not do binary op {} on inputs of {} with result of {}.",
op,
dtype_to_string(a.dtype()),
dtype_to_string(out.dtype())));
}
});
});
});
}
template <typename Op>
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt);
set_binary_op_output_data(a, b, outputs[1], bopt);
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
}
template <typename Op>
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
std::string_view op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
std::vector<array> outputs{out};
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
}
#define BINARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = out.primitive().stream(); \
binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
}
#define BINARY_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = outputs[0].primitive().stream(); \
binary_op_gpu<cu::func>(inputs, outputs, get_primitive_string(this), s); \
}
BINARY_GPU(Add)
BINARY_GPU(ArcTan2)
BINARY_GPU(Divide)
BINARY_GPU(Remainder)
BINARY_GPU(Equal)
BINARY_GPU(Greater)
BINARY_GPU(GreaterEqual)
BINARY_GPU(Less)
BINARY_GPU(LessEqual)
BINARY_GPU(LogicalAnd)
BINARY_GPU(LogicalOr)
BINARY_GPU(LogAddExp)
BINARY_GPU(Maximum)
BINARY_GPU(Minimum)
BINARY_GPU(Multiply)
BINARY_GPU(NotEqual)
BINARY_GPU(Power)
BINARY_GPU(Subtract)
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu<cu::BitwiseAnd>(inputs, out, op, s);
break;
case BitwiseBinary::Or:
binary_op_gpu<cu::BitwiseOr>(inputs, out, op, s);
break;
case BitwiseBinary::Xor:
binary_op_gpu<cu::BitwiseXor>(inputs, out, op, s);
break;
case BitwiseBinary::LeftShift:
binary_op_gpu<cu::LeftShift>(inputs, out, op, s);
break;
case BitwiseBinary::RightShift:
binary_op_gpu<cu::RightShift>(inputs, out, op, s);
break;
}
}
} // namespace mlx::core
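Each contiguous specialization above launches one thread per output element and indexes the inputs either with the thread rank or with a fixed scalar. A minimal standalone CUDA sketch of the vector-vector case, with a plain Add in place of the Op functor and raw managed allocations instead of MLX arrays:

#include <cstdio>
#include <cuda_runtime.h>

__global__ void add_vv(const float* a, const float* b, float* out, int size) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < size) {
    out[i] = a[i] + b[i];
  }
}

int main() {
  const int n = 1024;
  float *a, *b, *out;
  cudaMallocManaged(&a, n * sizeof(float));
  cudaMallocManaged(&b, n * sizeof(float));
  cudaMallocManaged(&out, n * sizeof(float));
  for (int i = 0; i < n; ++i) {
    a[i] = 1.0f;
    b[i] = static_cast<float>(i);
  }
  int block = 256;
  add_vv<<<(n + block - 1) / block, block>>>(a, b, out, n);
  cudaDeviceSynchronize();
  std::printf("out[10] = %f\n", out[10]); // expect 11.0
  cudaFree(a);
  cudaFree(b);
  cudaFree(out);
  return 0;
}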


@@ -34,14 +34,26 @@ CommandEncoder& DeviceStream::get_encoder() {
}
Device::Device(int device) : device_(device) {
CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
&compute_capability_major_, cudaDevAttrComputeCapabilityMajor, device_));
CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
&compute_capability_minor_, cudaDevAttrComputeCapabilityMinor, device_));
// Validate the requirements of device.
int attr = 0;
cudaDeviceGetAttribute(&attr, cudaDevAttrConcurrentManagedAccess, device_);
CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
&attr, cudaDevAttrConcurrentManagedAccess, device_));
if (attr != 1) {
throw std::runtime_error(fmt::format(
"Device {} does not support synchronization in managed memory.",
device_));
}
// The cublasLt handle is used by matmul.
make_current();
cublasLtCreate(&lt_);
}
Device::~Device() {
cublasLtDestroy(lt_);
}
void Device::make_current() {


@@ -6,6 +6,7 @@
#include "mlx/backend/cuda/worker.h"
#include "mlx/stream.h"
#include <cublasLt.h>
#include <thrust/execution_policy.h>
#include <unordered_map>
@@ -46,6 +47,7 @@ class DeviceStream {
class Device {
public:
explicit Device(int device);
~Device();
Device(const Device&) = delete;
Device& operator=(const Device&) = delete;
@@ -58,9 +60,21 @@ class Device {
int cuda_device() const {
return device_;
}
int compute_capability_major() const {
return compute_capability_major_;
}
int compute_capability_minor() const {
return compute_capability_minor_;
}
cublasLtHandle_t lt_handle() const {
return lt_;
}
private:
int device_;
int compute_capability_major_;
int compute_capability_minor_;
cublasLtHandle_t lt_;
std::unordered_map<int, DeviceStream> streams_;
};
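The two compute-capability accessors added here are consumed later in this diff by the cublasLt matmul setup to pick a workspace size. A minimal sketch of that kind of use, assuming cu::Device as declared above (pick_workspace_bytes is a hypothetical helper, not patch code):

// Hopper (sm_90) and newer get the larger recommended cublas workspace.
size_t pick_workspace_bytes(cu::Device& d) {
  const size_t MiB = 1024 * 1024;
  return d.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
}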


@@ -1,35 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cuComplex.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
namespace mlx::core {
// Maps CPU types to CUDA types.
template <typename T>
struct CTypeToCudaType {
using type = T;
};
template <>
struct CTypeToCudaType<float16_t> {
using type = __half;
};
template <>
struct CTypeToCudaType<bfloat16_t> {
using type = __nv_bfloat16;
};
template <>
struct CTypeToCudaType<complex64_t> {
using type = cuComplex;
};
template <typename T>
using cuda_type_t = typename CTypeToCudaType<T>::type;
} // namespace mlx::core


@@ -1,5 +1,6 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/event.h"
#include "mlx/backend/cuda/utils.h"
@@ -111,12 +112,12 @@ __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
SharedEvent::SharedEvent() {
// Allocate cuda::atomic on managed memory.
allocator::Buffer buffer = allocator::malloc(sizeof(Atomic));
Atomic* ac = static_cast<Atomic*>(buffer.raw_ptr());
Atomic* ac;
CHECK_CUDA_ERROR(cudaMallocManaged(&ac, sizeof(Atomic)));
new (ac) Atomic(0);
ac_ = std::shared_ptr<Atomic>(ac, [buffer](Atomic* ptr) {
ac_ = std::shared_ptr<Atomic>(ac, [](Atomic* ptr) {
ptr->~Atomic();
allocator::free(buffer);
allocator().cuda_free(ptr);
});
}
@@ -155,7 +156,10 @@ void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
void SharedEvent::signal(Stream s, uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this, value]() mutable { signal(value); });
// Signal through a GPU stream so the atomic is updated on the GPU; updating
// the atomic from the CPU sometimes fails to notify the GPU.
static CudaStream stream(device(mlx::core::Device::gpu));
scheduler::enqueue(s, [*this, value]() mutable { signal(stream, value); });
} else {
auto& encoder = get_command_encoder(s);
encoder.launch_kernel(


@@ -0,0 +1,29 @@
// Copyright © 2025 Apple Inc.
#include "mlx/fence.h"
#include "mlx/backend/cuda/event.h"
namespace mlx::core {
struct FenceImpl {
uint32_t count;
cu::SharedEvent event;
};
Fence::Fence(Stream s) {
fence_ = std::shared_ptr<void>(
new FenceImpl{0}, [](void* ptr) { delete static_cast<FenceImpl*>(ptr); });
}
void Fence::wait(Stream s, const array&) {
auto* fence = static_cast<FenceImpl*>(fence_.get());
fence->event.wait(fence->count);
}
void Fence::update(Stream s, const array&) {
auto* fence = static_cast<FenceImpl*>(fence_.get());
fence->count++;
fence->event.signal(s, fence->count);
}
} // namespace mlx::core
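A minimal usage sketch for the Fence API above, assuming one stream publishes an array and another reads it (producer_consumer is a hypothetical function, not patch code):

void producer_consumer(Stream producer, Stream consumer, const array& data) {
  Fence fence(producer);
  fence.update(producer, data);  // bumps the count and signals the shared event
  fence.wait(consumer, data);    // returns once the signaled count is reached
}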


@@ -1,70 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/event.h"
#include "mlx/fence.h"
#include "mlx/scheduler.h"
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace {
__host__ __device__ void busy_wait(cuda::atomic<uint64_t>* ac, uint64_t value) {
while (true) {
// In theory the atomic_thread_fence is not needed, but on CUDA 11 without
// it the load() may never return the new value.
cuda::atomic_thread_fence(cuda::memory_order_seq_cst);
uint64_t current = ac->load();
if (current >= value) {
break;
}
}
}
__global__ void busy_wait_kernel(cuda::atomic<uint64_t>* ac, uint64_t value) {
busy_wait(ac, value);
}
} // namespace
struct FenceImpl {
uint32_t count;
cu::SharedEvent event;
};
Fence::Fence(Stream s) {
fence_ = std::shared_ptr<void>(
new FenceImpl{0}, [](void* ptr) { delete static_cast<FenceImpl*>(ptr); });
}
void Fence::wait(Stream s, const array&) {
auto* fence = static_cast<FenceImpl*>(fence_.get());
// We can't use SharedEvent::wait because it could hang in CUDA 11, see also:
// https://github.com/ml-explore/mlx/issues/2137
const auto& ac = fence->event.atomic();
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [ac, count = fence->count]() {
nvtx3::scoped_range r("Fence::wait()");
busy_wait(ac.get(), count);
});
} else {
nvtx3::scoped_range r("Fence::wait(s)");
auto& encoder = cu::get_command_encoder(s);
encoder.launch_kernel(
encoder.stream().last_cuda_stream(), [&](cudaStream_t stream) {
busy_wait_kernel<<<1, 1, 0>>>(ac.get(), fence->count);
});
encoder.add_completed_handler([ac]() {});
encoder.end_encoding();
}
}
void Fence::update(Stream s, const array&) {
auto* fence = static_cast<FenceImpl*>(fence_.get());
fence->count++;
fence->event.signal(s, fence->count);
}
} // namespace mlx::core


@@ -0,0 +1,121 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <thrust/iterator/iterator_adaptor.h>
#include <cuda/std/utility>
#include "mlx/backend/cuda/kernel_utils.cuh"
namespace mlx::core::cu {
// Iterating non-contiguous array.
template <typename Iterator, typename IdxT = int64_t>
class general_iterator
: public thrust::
iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator> {
public:
using super_t =
thrust::iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator>;
using reference = typename super_t::reference;
using difference_type = typename super_t::difference_type;
__host__ __device__ general_iterator(
Iterator it,
IdxT index,
int ndim,
Shape shape,
Strides strides)
: super_t(it),
index_(index),
ndim_(ndim),
shape_(cuda::std::move(shape)),
strides_(cuda::std::move(strides)) {}
__host__ __device__ IdxT index() const {
return index_;
}
__host__ __device__ const Shape& shape() const {
return shape_;
}
__host__ __device__ const Strides& strides() const {
return strides_;
}
private:
friend class thrust::iterator_core_access;
__host__ __device__ bool equal(const general_iterator& other) const {
return this->base() == other.base() && this->index() == other.index();
}
__host__ __device__ void advance(difference_type n) {
this->index_ += n;
}
__host__ __device__ void increment() {
this->index_ += 1;
}
__host__ __device__ void decrement() {
this->index_ -= 1;
}
__host__ __device__ difference_type
distance_to(const general_iterator& other) const {
_CCCL_ASSERT(
this->base() == other.base(),
"Underlying iterator must point to same base iterator");
return other.index() - this->index();
}
// The dereference is device-only to avoid accidentally running it on the host.
__device__ typename super_t::reference dereference() const {
IdxT offset = elem_to_loc(index_, shape_.data(), strides_.data(), ndim_);
return *(this->base() + offset);
}
IdxT index_;
int ndim_;
Shape shape_;
Strides strides_;
};
template <typename IdxT, typename Iterator>
__host__ __device__ auto make_general_iterator(
Iterator it,
IdxT index,
int ndim,
Shape shape,
Strides strides) {
return general_iterator<Iterator, IdxT>(
it, index, ndim, cuda::std::move(shape), cuda::std::move(strides));
}
template <typename IdxT, typename Iterator>
auto make_general_iterator(
Iterator it,
const std::vector<int32_t>& shape,
const std::vector<int64_t>& strides) {
return make_general_iterator<IdxT>(
it, 0, shape.size(), const_param(shape), const_param(strides));
}
template <typename IdxT, typename Iterator>
auto make_general_iterators(
Iterator it,
IdxT size,
const std::vector<int32_t>& shape,
const std::vector<int64_t>& strides) {
auto ndim = shape.size();
auto shape_arg = const_param(shape);
auto strides_arg = const_param(strides);
return std::make_pair(
make_general_iterator<IdxT>(it, 0, ndim, shape_arg, strides_arg),
make_general_iterator<IdxT>(it, size, ndim, shape_arg, strides_arg));
}
} // namespace mlx::core::cu
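A sketch of the intended use of these iterators, mirroring the pattern the unary kernels adopt later in this diff: wrap a raw device pointer so thrust::transform visits a non-contiguous input in logical order. transform_noncontiguous is a hypothetical helper and assumes the thrust transform/execution-policy headers are included:

template <typename T, typename Op>
void transform_noncontiguous(
    cudaStream_t stream,
    const T* in,
    T* out,
    int64_t size,
    const std::vector<int32_t>& shape,
    const std::vector<int64_t>& strides,
    Op op) {
  auto in_ptr = thrust::device_pointer_cast(in);
  // Pair of iterators spanning [0, size) in logical (row-major) order.
  auto [begin, end] = make_general_iterators<int64_t>(in_ptr, size, shape, strides);
  thrust::transform(thrust::cuda::par.on(stream), begin, end,
                    thrust::device_pointer_cast(out), op);
}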


@@ -0,0 +1,26 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
namespace mlx::core {
dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2) {
Dims dims = get_block_dims_common(dim0, dim1, dim2, pow2);
return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
}
dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides) {
Dims dims = get_2d_grid_dims_common(shape, strides);
return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
}
dim3 get_2d_grid_dims(
const Shape& shape,
const Strides& strides,
size_t divisor) {
Dims dims = get_2d_grid_dims_common(shape, strides, divisor);
return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
}
} // namespace mlx::core


@@ -0,0 +1,131 @@
// Copyright © 2025 Apple Inc.
// This file includes host-only utilities for writing CUDA kernels; the
// difference from backend/cuda/kernels/utils.cuh is that the latter includes
// only device-only code.
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/kernels/utils.cuh"
#include <cuComplex.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <fmt/format.h>
#include <cuda/cmath>
namespace mlx::core {
// Convert a runtime value in the range 1-3 to a constexpr NDIM.
#define MLX_SWITCH_1_2_3(N, NDIM, ...) \
switch (N) { \
case 1: { \
constexpr int NDIM = 1; \
__VA_ARGS__; \
break; \
} \
case 2: { \
constexpr int NDIM = 2; \
__VA_ARGS__; \
break; \
} \
case 3: { \
constexpr int NDIM = 3; \
__VA_ARGS__; \
break; \
} \
}
// Like MLX_SWITCH_ALL_TYPES but for booleans.
#define MLX_SWITCH_BOOL(BOOL, BOOL_ALIAS, ...) \
if (BOOL) { \
constexpr bool BOOL_ALIAS = true; \
__VA_ARGS__; \
} else { \
constexpr bool BOOL_ALIAS = false; \
__VA_ARGS__; \
}
// Maps CPU types to CUDA types.
template <typename T>
struct CTypeToCudaType {
using type = T;
};
template <>
struct CTypeToCudaType<float16_t> {
using type = __half;
};
template <>
struct CTypeToCudaType<bfloat16_t> {
using type = __nv_bfloat16;
};
template <>
struct CTypeToCudaType<complex64_t> {
using type = cuComplex;
};
template <typename T>
using cuda_type_t = typename CTypeToCudaType<T>::type;
// Type traits for detecting floating numbers.
template <typename T>
inline constexpr bool is_floating_v =
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
// Utility to copy data from a std::vector into a cuda::std::array on the host.
template <int NDIM = MAX_NDIM, typename T = int32_t>
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
if (vec.size() > NDIM) {
throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", NDIM));
}
cuda::std::array<T, NDIM> result;
std::copy_n(vec.begin(), vec.size(), result.begin());
return result;
}
// Compute the grid and block dimensions, check backend/common/utils.h for docs.
dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides);
dim3 get_2d_grid_dims(
const Shape& shape,
const Strides& strides,
size_t divisor);
// Return a block size that achieves maximum potential occupancy for kernel.
template <typename T>
inline uint max_occupancy_block_dim(T kernel) {
int _, block_dim;
CHECK_CUDA_ERROR(cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
return block_dim;
}
// Get the num_blocks and block_dims that maximize occupancy for |kernel|,
// assuming each thread handles |work_per_thread| elements of |arr|.
template <typename T>
inline std::tuple<dim3, uint> get_launch_args(
T kernel,
const array& arr,
bool large,
int work_per_thread = 1) {
size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread);
uint block_dim = max_occupancy_block_dim(kernel);
if (block_dim > nthreads) {
block_dim = nthreads;
}
dim3 num_blocks;
if (large) {
num_blocks = get_2d_grid_dims(arr.shape(), arr.strides(), work_per_thread);
num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
} else {
num_blocks.x = cuda::ceil_div(nthreads, block_dim);
}
return std::make_tuple(num_blocks, block_dim);
}
} // namespace mlx::core
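Putting these helpers together, a launch typically selects an index type with MLX_SWITCH_BOOL, packs shape/strides with const_param, and sizes the grid with get_launch_args; this mirrors the binary.cu launch code earlier in the diff. launch_example and my_kernel are hypothetical names used only for illustration:

void launch_example(const array& out, cudaStream_t stream,
                    const std::vector<int32_t>& shape,
                    const std::vector<int64_t>& strides) {
  bool large = out.data_size() > UINT32_MAX;
  MLX_SWITCH_BOOL(large, LARGE, {
    using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
    auto kernel = my_kernel<IdxT>;  // hypothetical kernel templated on IdxT
    auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
    kernel<<<num_blocks, block_dims, 0, stream>>>(
        out.data<float>(),
        out.data_size(),
        const_param(shape),
        const_param(strides),
        static_cast<int>(shape.size()));
  });
}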


@@ -0,0 +1,278 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
#include <cuComplex.h>
#include <cuda/std/array>
namespace mlx::core::cu {
struct Add {
template <typename T>
__device__ T operator()(T x, T y) {
return x + y;
}
};
struct FloorDivide {
template <typename T>
__device__ T operator()(T x, T y) {
if constexpr (cuda::std::is_integral_v<T>) {
return x / y;
} else {
return trunc(x / y);
}
}
};
struct Divide {
template <typename T>
__device__ T operator()(T x, T y) {
return x / y;
}
};
struct Remainder {
template <typename T>
__device__ T operator()(T x, T y) {
if constexpr (cuda::std::is_integral_v<T>) {
if constexpr (cuda::std::is_signed_v<T>) {
auto r = x % y;
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
} else {
return x % y;
}
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return x % y;
} else {
T r = fmod(x, y);
if (r != 0 && (r < 0 != y < 0)) {
r = r + y;
}
return r;
}
}
};
struct Equal {
template <typename T>
__device__ bool operator()(T x, T y) {
return x == y;
}
};
struct NaNEqual {
template <typename T>
__device__ bool operator()(T x, T y) {
if constexpr (std::is_same_v<T, cuComplex>) {
return x == y ||
(isnan(cuCrealf(x)) && isnan(cuCrealf(y)) && isnan(cuCimagf(x)) &&
isnan(cuCimagf(y))) ||
(cuCrealf(x) == cuCrealf(y) && isnan(cuCimagf(x)) &&
isnan(cuCimagf(y))) ||
(isnan(cuCrealf(x)) && isnan(cuCrealf(y)) &&
cuCimagf(x) == cuCimagf(y));
} else {
return x == y || (isnan(x) && isnan(y));
}
}
};
struct Greater {
template <typename T>
__device__ bool operator()(T x, T y) {
return x > y;
}
};
struct GreaterEqual {
template <typename T>
__device__ bool operator()(T x, T y) {
return x >= y;
}
};
struct Less {
template <typename T>
__device__ bool operator()(T x, T y) {
return x < y;
}
};
struct LessEqual {
template <typename T>
__device__ bool operator()(T x, T y) {
return x <= y;
}
};
struct LogAddExp {
template <typename T>
__device__ T operator()(T x, T y) {
if (isnan(x) || isnan(y)) {
return cuda::std::numeric_limits<T>::quiet_NaN();
}
T maxval = max(x, y);
T minval = min(x, y);
return (minval == -cuda::std::numeric_limits<T>::infinity() ||
maxval == cuda::std::numeric_limits<T>::infinity())
? maxval
: T(float(maxval) + log1p(expf(minval - maxval)));
};
};
struct Maximum {
template <typename T>
__device__ T operator()(T x, T y) {
if constexpr (cuda::std::is_integral_v<T>) {
return max(x, y);
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
return x;
}
return x > y ? x : y;
} else {
if (isnan(x)) {
return x;
}
return x > y ? x : y;
}
}
};
struct Minimum {
template <typename T>
__device__ T operator()(T x, T y) {
if constexpr (cuda::std::is_integral_v<T>) {
return min(x, y);
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
return x;
}
return x < y ? x : y;
} else {
if (isnan(x)) {
return x;
}
return x < y ? x : y;
}
}
};
struct Multiply {
template <typename T>
__device__ T operator()(T x, T y) {
return x * y;
}
};
struct NotEqual {
template <typename T>
__device__ bool operator()(T x, T y) {
if constexpr (std::is_same_v<T, cuComplex>) {
return cuCrealf(x) != cuCrealf(y) || cuCimagf(x) != cuCimagf(y);
} else {
return x != y;
}
}
};
struct Power {
template <typename T>
__device__ T operator()(T base, T exp) {
if constexpr (cuda::std::is_integral_v<T>) {
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
auto x_theta = atan2f(base.y, base.x);
auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y);
auto mag = expf(exp.x * x_ln_r - exp.y * x_theta);
auto phase = exp.y * x_ln_r + exp.x * x_theta;
return make_cuFloatComplex(mag * cosf(phase), mag * sinf(phase));
} else {
return powf(base, exp);
}
}
};
struct Subtract {
template <typename T>
__device__ T operator()(T x, T y) {
return x - y;
}
};
struct LogicalAnd {
template <typename T>
__device__ T operator()(T x, T y) {
return x && y;
};
};
struct LogicalOr {
template <typename T>
__device__ T operator()(T x, T y) {
return x || y;
};
};
struct BitwiseAnd {
template <typename T>
__device__ T operator()(T x, T y) {
return x & y;
};
};
struct BitwiseOr {
template <typename T>
__device__ T operator()(T x, T y) {
return x | y;
};
};
struct BitwiseXor {
template <typename T>
__device__ T operator()(T x, T y) {
return x ^ y;
};
};
struct LeftShift {
template <typename T>
__device__ T operator()(T x, T y) {
return x << y;
};
};
struct RightShift {
template <typename T>
__device__ T operator()(T x, T y) {
return x >> y;
};
};
struct ArcTan2 {
template <typename T>
__device__ T operator()(T y, T x) {
return atan2f(y, x);
}
};
struct DivMod {
template <typename T>
__device__ cuda::std::array<T, 2> operator()(T x, T y) {
return {FloorDivide{}(x, y), Remainder{}(x, y)};
};
};
} // namespace mlx::core::cu
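The Remainder functor above implements Python-style modulo: a nonzero result takes the sign of the divisor rather than the dividend. A small host-side check of that convention (a sketch using the same branch as the signed-integer path, not patch code):

#include <cassert>
int py_mod(int x, int y) {
  int r = x % y;
  if (r != 0 && ((r < 0) != (y < 0))) {
    r += y;  // shift into the divisor's sign, matching Python's %
  }
  return r;
}
int main() {
  assert(py_mod( 7,  3) ==  1);
  assert(py_mod(-7,  3) ==  2);   // plain C++ % gives -1
  assert(py_mod( 7, -3) == -2);   // plain C++ % gives  1
  assert(py_mod(-7, -3) == -1);
  return 0;
}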


@@ -0,0 +1,240 @@
// Copyright © 2025 Apple Inc.
// Copyright © 2017-2024 The Simons Foundation, Inc.
//
// FINUFFT is licensed under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance with the
// License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Forked from
// https://github.com/flatironinstitute/finufft/blob/main/include/cufinufft/contrib/helper_math.h
#pragma once
#include <cuComplex.h>
// This header provides some helper functions for cuComplex types.
// It mainly wraps existing CUDA implementations to provide operator overloads
// e.g. cuAdd, cuSub, cuMul, cuDiv, cuCreal, cuCimag, cuCabs, cuCarg, cuConj are
// all provided by CUDA
__forceinline__ __host__ __device__ cuDoubleComplex
operator+(const cuDoubleComplex& a, const cuDoubleComplex& b) {
return cuCadd(a, b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator-(const cuDoubleComplex& a, const cuDoubleComplex& b) {
return cuCsub(a, b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator*(const cuDoubleComplex& a, const cuDoubleComplex& b) {
return cuCmul(a, b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator/(const cuDoubleComplex& a, const cuDoubleComplex& b) {
return cuCdiv(a, b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator%(const cuDoubleComplex& a, const cuDoubleComplex& b) {
double r = cuCreal(a) - (floor(cuCreal(a) / cuCreal(b)) * cuCreal(b));
double i = cuCimag(a) - (floor(cuCimag(a) / cuCimag(b)) * cuCimag(b));
return make_cuDoubleComplex(r, i);
}
__forceinline__ __host__ __device__ bool operator==(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
return cuCreal(a) == cuCreal(b) && cuCimag(a) == cuCimag(b);
}
__forceinline__ __host__ __device__ bool operator!=(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
return !(a == b);
}
__forceinline__ __host__ __device__ bool operator>(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
double mag_a = sqrt(cuCreal(a) * cuCreal(a) + cuCimag(a) * cuCimag(a));
double mag_b = sqrt(cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b));
return mag_a > mag_b;
}
__forceinline__ __host__ __device__ bool operator>=(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
return a > b || a == b;
}
__forceinline__ __host__ __device__ bool operator<(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
return b > a;
}
__forceinline__ __host__ __device__ bool operator<=(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
return b > a || a == b;
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator+(const cuDoubleComplex& a, double b) {
return make_cuDoubleComplex(cuCreal(a) + b, cuCimag(a));
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator+(double a, const cuDoubleComplex& b) {
return make_cuDoubleComplex(a + cuCreal(b), cuCimag(b));
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator-(const cuDoubleComplex& a, double b) {
return make_cuDoubleComplex(cuCreal(a) - b, cuCimag(a));
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator-(double a, const cuDoubleComplex& b) {
return make_cuDoubleComplex(a - cuCreal(b), -cuCimag(b));
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator*(const cuDoubleComplex& a, double b) {
return make_cuDoubleComplex(cuCreal(a) * b, cuCimag(a) * b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator*(double a, const cuDoubleComplex& b) {
return make_cuDoubleComplex(a * cuCreal(b), a * cuCimag(b));
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator/(const cuDoubleComplex& a, double b) {
return make_cuDoubleComplex(cuCreal(a) / b, cuCimag(a) / b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator/(double a, const cuDoubleComplex& b) {
double denom = cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b);
return make_cuDoubleComplex(
(a * cuCreal(b)) / denom, (-a * cuCimag(b)) / denom);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator+(const cuFloatComplex& a, const cuFloatComplex& b) {
return cuCaddf(a, b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator-(const cuFloatComplex& a, const cuFloatComplex& b) {
return cuCsubf(a, b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator*(const cuFloatComplex& a, const cuFloatComplex& b) {
return cuCmulf(a, b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator/(const cuFloatComplex& a, const cuFloatComplex& b) {
return cuCdivf(a, b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator%(const cuFloatComplex& a, const cuFloatComplex& b) {
float r = cuCrealf(a) - (floorf(cuCrealf(a) / cuCrealf(b)) * cuCrealf(b));
float i = cuCimagf(a) - (floorf(cuCimagf(a) / cuCimagf(b)) * cuCimagf(b));
return make_cuFloatComplex(r, i);
}
__forceinline__ __host__ __device__ bool operator==(
const cuFloatComplex& a,
const cuFloatComplex& b) {
return cuCrealf(a) == cuCrealf(b) && cuCimagf(a) == cuCimagf(b);
}
__forceinline__ __host__ __device__ bool operator!=(
const cuFloatComplex& a,
const cuFloatComplex& b) {
return !(a == b);
}
__forceinline__ __host__ __device__ bool operator>(
const cuFloatComplex& a,
const cuFloatComplex& b) {
float mag_a = sqrt(cuCrealf(a) * cuCrealf(a) + cuCimagf(a) * cuCimagf(a));
float mag_b = sqrt(cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b));
return mag_a > mag_b;
}
__forceinline__ __host__ __device__ bool operator>=(
const cuFloatComplex& a,
const cuFloatComplex& b) {
return a > b || a == b;
}
__forceinline__ __host__ __device__ bool operator<(
const cuFloatComplex& a,
const cuFloatComplex& b) {
return b > a;
}
__forceinline__ __host__ __device__ bool operator<=(
const cuFloatComplex& a,
const cuFloatComplex& b) {
return b > a || a == b;
}
__forceinline__ __host__ __device__ cuFloatComplex
operator+(const cuFloatComplex& a, float b) {
return make_cuFloatComplex(cuCrealf(a) + b, cuCimagf(a));
}
__forceinline__ __host__ __device__ cuFloatComplex
operator+(float a, const cuFloatComplex& b) {
return make_cuFloatComplex(a + cuCrealf(b), cuCimagf(b));
}
__forceinline__ __host__ __device__ cuFloatComplex
operator-(const cuFloatComplex& a, float b) {
return make_cuFloatComplex(cuCrealf(a) - b, cuCimagf(a));
}
__forceinline__ __host__ __device__ cuFloatComplex
operator-(float a, const cuFloatComplex& b) {
return make_cuFloatComplex(a - cuCrealf(b), -cuCimagf(b));
}
__forceinline__ __host__ __device__ cuFloatComplex
operator*(const cuFloatComplex& a, float b) {
return make_cuFloatComplex(cuCrealf(a) * b, cuCimagf(a) * b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator*(float a, const cuFloatComplex& b) {
return make_cuFloatComplex(a * cuCrealf(b), a * cuCimagf(b));
}
__forceinline__ __host__ __device__ cuFloatComplex
operator/(const cuFloatComplex& a, float b) {
return make_cuFloatComplex(cuCrealf(a) / b, cuCimagf(a) / b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator/(float a, const cuFloatComplex& b) {
float denom = cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b);
return make_cuFloatComplex(
(a * cuCrealf(b)) / denom, (-a * cuCimagf(b)) / denom);
}
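These overloads exist so that the generic binary functors earlier in the diff (Multiply, Divide, and so on) instantiate cleanly for cuComplex without complex-specific branches. A sketch, assuming both this header and the binary-op header are included (complex_mul_kernel is a hypothetical kernel, not patch code):

__global__ void complex_mul_kernel(const cuComplex* a, const cuComplex* b,
                                   cuComplex* out, size_t n) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = mlx::core::cu::Multiply{}(a[i], b[i]);  // resolves to operator* -> cuCmulf
  }
}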


@@ -2,6 +2,7 @@
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda/std/limits>
#include <cuda/std/type_traits>
@@ -9,36 +10,122 @@
namespace mlx::core::cu {
///////////////////////////////////////////////////////////////////////////////
// Missing C++ operator overrides for CUDA 7.
// Unary ops for half types.
///////////////////////////////////////////////////////////////////////////////
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
#define MLX_DEFINE_UNARY_OP(NAME, HALF_OP) \
template <typename T> \
__forceinline__ __device__ auto NAME(T x) { \
if constexpr (cuda::std::is_same_v<T, __half>) { \
return HALF_OP(x); \
} else { \
return ::NAME(x); \
} \
}
#else
#define MLX_DEFINE_UNARY_OP(NAME, HALF_OP) \
template <typename T> \
__forceinline__ __device__ auto NAME(T x) { \
if constexpr (cuda::std::is_same_v<T, __half>) { \
return HALF_OP(x); \
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) { \
return HALF_OP(x); \
} else { \
return ::NAME(x); \
} \
}
#endif
#define MLX_DEFINE_BF16_OP(OP) \
__forceinline__ __device__ __nv_bfloat16 operator OP( \
__nv_bfloat16 x, __nv_bfloat16 y) { \
return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \
#define MLX_DEFINE_UNARY_OP_FALLBCK(NAME) \
template <typename T> \
__forceinline__ __device__ auto NAME(T x) { \
if constexpr (cuda::std::is_same_v<T, __half>) { \
return ::NAME(__half2float(x)); \
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) { \
return ::NAME(__bfloat162float(x)); \
} else { \
return ::NAME(x); \
} \
}
#define MLX_DEFINE_BF16_CMP(OP) \
__forceinline__ __device__ bool operator OP( \
__nv_bfloat16 x, __nv_bfloat16 y) { \
return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \
MLX_DEFINE_UNARY_OP(abs, __habs)
MLX_DEFINE_UNARY_OP(ceil, hceil)
MLX_DEFINE_UNARY_OP(cos, hcos)
MLX_DEFINE_UNARY_OP(exp, hexp)
MLX_DEFINE_UNARY_OP(floor, hfloor)
MLX_DEFINE_UNARY_OP(isnan, __hisnan)
MLX_DEFINE_UNARY_OP(log, hlog)
MLX_DEFINE_UNARY_OP(log2, hlog2)
MLX_DEFINE_UNARY_OP(log10, hlog10)
MLX_DEFINE_UNARY_OP(rint, hrint)
MLX_DEFINE_UNARY_OP(rsqrt, hrsqrt)
MLX_DEFINE_UNARY_OP(sin, hsin)
MLX_DEFINE_UNARY_OP(sqrt, hsqrt)
MLX_DEFINE_UNARY_OP_FALLBCK(acos)
MLX_DEFINE_UNARY_OP_FALLBCK(acosh)
MLX_DEFINE_UNARY_OP_FALLBCK(asin)
MLX_DEFINE_UNARY_OP_FALLBCK(asinh)
MLX_DEFINE_UNARY_OP_FALLBCK(atan)
MLX_DEFINE_UNARY_OP_FALLBCK(atanh)
MLX_DEFINE_UNARY_OP_FALLBCK(cosh)
MLX_DEFINE_UNARY_OP_FALLBCK(log1p)
MLX_DEFINE_UNARY_OP_FALLBCK(sinh)
MLX_DEFINE_UNARY_OP_FALLBCK(tan)
#if __CUDA_ARCH__ >= 1280
MLX_DEFINE_UNARY_OP(tanh, htanh)
#else
MLX_DEFINE_UNARY_OP_FALLBCK(tanh)
#endif
#undef MLX_DEFINE_UNARY_OP
#undef MLX_DEFINE_UNARY_OP_FALLBCK
///////////////////////////////////////////////////////////////////////////////
// Binary ops for half types.
///////////////////////////////////////////////////////////////////////////////
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
#define MLX_DEFINE_BINARY_OP(NAME, HALF_OP) \
template <typename T> \
__forceinline__ __device__ auto NAME(T x, T y) { \
if constexpr (cuda::std::is_same_v<T, __half>) { \
return HALF_OP(x, y); \
} else { \
return ::NAME(x, y); \
} \
}
#else
#define MLX_DEFINE_BINARY_OP(NAME, HALF_OP) \
template <typename T> \
__forceinline__ __device__ auto NAME(T x, T y) { \
if constexpr (cuda::std::is_same_v<T, __half>) { \
return HALF_OP(x, y); \
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) { \
return HALF_OP(x, y); \
} else { \
return ::NAME(x, y); \
} \
}
#endif
MLX_DEFINE_BF16_OP(+)
MLX_DEFINE_BF16_OP(-)
MLX_DEFINE_BF16_OP(*)
MLX_DEFINE_BF16_OP(/)
MLX_DEFINE_BF16_CMP(>)
MLX_DEFINE_BF16_CMP(<)
MLX_DEFINE_BF16_CMP(>=)
MLX_DEFINE_BF16_CMP(<=)
MLX_DEFINE_BINARY_OP(max, __hmax)
MLX_DEFINE_BINARY_OP(min, __hmin)
#undef MLX_DEFINE_BF16_OP
#undef MLX_DEFINE_BF16_CMP
#undef MLX_DEFINE_BINARY_OP
#endif // CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
template <typename T>
__forceinline__ __device__ T fmod(T x, T y) {
if constexpr (cuda::std::is_same_v<T, __half>) {
return __float2half(::fmod(__half2float(x), __half2float(y)));
#if CUDART_VERSION >= 12000 || __CUDA_ARCH__ >= 800
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
return __float2bfloat16(::fmod(__bfloat162float(x), __bfloat162float(y)));
#endif
} else {
return ::fmod(x, y);
}
}
///////////////////////////////////////////////////////////////////////////////
// Additional C++ operator overrides between half types and native types.


@@ -0,0 +1,349 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
#include "mlx/backend/cuda/kernels/utils.cuh"
namespace mlx::core::cu {
struct Abs {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_unsigned_v<T>) {
return x;
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return {sqrt(cuCrealf(x) * cuCrealf(x) + cuCimagf(x) * cuCimagf(x)), 0};
} else {
return abs(x);
}
}
};
struct ArcCos {
template <typename T>
__device__ T operator()(T x) {
return acos(x);
}
};
struct ArcCosh {
template <typename T>
__device__ T operator()(T x) {
return acosh(x);
}
};
struct ArcSin {
template <typename T>
__device__ T operator()(T x) {
return asin(x);
}
};
struct ArcSinh {
template <typename T>
__device__ T operator()(T x) {
return asinh(x);
}
};
struct ArcTan {
template <typename T>
__device__ T operator()(T x) {
return atan(x);
}
};
struct ArcTanh {
template <typename T>
__device__ T operator()(T x) {
return atanh(x);
}
};
struct BitwiseInvert {
template <typename T>
__device__ T operator()(T x) {
return ~x;
}
};
struct Ceil {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_integral_v<T>) {
return x;
} else {
return ceil(x);
}
}
};
struct Conjugate {
__device__ cuComplex operator()(cuComplex x) {
return {cuCrealf(x), -cuCimagf(x)};
}
};
struct Cos {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return {
cos(cuCrealf(x)) * cosh(cuCimagf(x)),
-sin(cuCrealf(x)) * sinh(cuCimagf(x))};
} else {
return cos(x);
}
}
};
struct Cosh {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return {
cosh(cuCrealf(x)) * cos(cuCimagf(x)),
sinh(cuCrealf(x)) * sin(cuCimagf(x))};
} else {
return cosh(x);
}
}
};
struct Erf {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, __half>) {
return erf(__half2float(x));
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
return erf(__bfloat162float(x));
} else {
return erf(x);
}
}
};
struct ErfInv {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, __half>) {
return erfinv(__half2float(x));
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
return erfinv(__bfloat162float(x));
} else {
return erfinv(x);
}
}
};
struct Exp {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
auto m = exp(cuCrealf(x));
return {m * cos(cuCimagf(x)), m * sin(cuCimagf(x))};
} else {
return exp(x);
}
}
};
struct Expm1 {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, __half>) {
return expm1(__half2float(x));
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
return expm1(__bfloat162float(x));
} else {
return expm1(x);
}
}
};
struct Floor {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_integral_v<T>) {
return x;
} else {
return floor(x);
}
}
};
struct Imag {
__device__ float operator()(cuComplex x) {
return cuCimagf(x);
}
};
struct Log {
template <typename T>
__device__ T operator()(T x) {
return log(x);
}
};
struct Log2 {
template <typename T>
__device__ T operator()(T x) {
return log2(x);
}
};
struct Log10 {
template <typename T>
__device__ T operator()(T x) {
return log10(x);
}
};
struct Log1p {
template <typename T>
__device__ T operator()(T x) {
return log1p(x);
}
};
struct LogicalNot {
__device__ bool operator()(bool x) {
return !x;
}
};
struct Negative {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return 0 - x;
} else {
return -x;
}
}
};
struct Real {
__device__ float operator()(cuComplex x) {
return cuCrealf(x);
}
};
struct Round {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return {rint(cuCrealf(x)), rint(cuCimagf(x))};
} else {
return rint(x);
}
}
};
struct Rsqrt {
template <typename T>
__device__ T operator()(T x) {
return rsqrt(x);
}
};
struct Sigmoid {
template <typename T>
__device__ T operator()(T x) {
T y = 1 / (1 + exp(-abs(x)));
return (x < 0) ? 1 - y : y;
}
};
struct Sign {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_unsigned_v<T>) {
return x != 0;
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
if (cuCrealf(x) == 0 && cuCimagf(x) == 0) {
return x;
} else {
return x / Abs()(x);
}
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
return static_cast<float>((x > T(0.f)) - (x < T(0.f)));
} else {
return (x > T(0)) - (x < T(0));
}
}
};
struct Sin {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return {
sin(cuCrealf(x)) * cosh(cuCimagf(x)),
cos(cuCrealf(x)) * sinh(cuCimagf(x))};
} else {
return sin(x);
}
}
};
struct Sinh {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return {
sinh(cuCrealf(x)) * cos(cuCimagf(x)),
cosh(cuCrealf(x)) * sin(cuCimagf(x))};
} else {
return sinh(x);
}
}
};
struct Square {
template <typename T>
__device__ T operator()(T x) {
return x * x;
}
};
struct Sqrt {
template <typename T>
__device__ T operator()(T x) {
return sqrt(x);
}
};
struct Tan {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
float tan_a = tan(cuCrealf(x));
float tanh_b = tanh(cuCimagf(x));
float t1 = tan_a * tanh_b;
float denom = 1. + t1 * t1;
return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
} else {
return tan(x);
}
}
};
struct Tanh {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
float tanh_a = tanh(cuCrealf(x));
float tan_b = tan(cuCimagf(x));
float t1 = tanh_a * tan_b;
float denom = 1. + t1 * t1;
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
} else {
return tanh(x);
}
}
};
} // namespace mlx::core::cu
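One detail worth noting in the Sigmoid functor above: computing 1 / (1 + exp(-|x|)) and reflecting for negative inputs keeps the exp() argument non-positive, so it never overflows for large-magnitude x. An equivalent host-side sketch (stable_sigmoid is illustrative only):

float stable_sigmoid(float x) {
  float y = 1.0f / (1.0f + expf(-fabsf(x)));  // exp argument is always <= 0
  return x < 0 ? 1.0f - y : y;                // sigmoid(-x) == 1 - sigmoid(x)
}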


@@ -0,0 +1,104 @@
// Copyright © 2025 Apple Inc.
// This file must not include any host-only code; utilities that work under
// both host and device can be put here.
//
// See more about the requirements at:
// https://docs.nvidia.com/cuda/nvrtc/#language
#pragma once
#include <cuComplex.h>
#include <cuda/std/array>
#include <cuda/std/limits>
#include <cuda/std/tuple>
namespace mlx::core::cu {
///////////////////////////////////////////////////////////////////////////////
// CUDA kernel utils
///////////////////////////////////////////////////////////////////////////////
// To pass shape/strides to kernels via constant memory, their size must be
// known at compile time.
#define MAX_NDIM 8
using Shape = cuda::std::array<int32_t, MAX_NDIM>;
using Strides = cuda::std::array<int64_t, MAX_NDIM>;
///////////////////////////////////////////////////////////////////////////////
// Indexing utils
///////////////////////////////////////////////////////////////////////////////
template <typename IdxT = int64_t>
inline __host__ __device__ IdxT
elem_to_loc(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
IdxT loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * IdxT(strides[i]);
elem /= shape[i];
}
return loc;
}
// Optimize when the ndim is known at compile time.
template <int NDIM, typename IdxT = int64_t>
inline __host__ __device__ IdxT
elem_to_loc_nd(IdxT elem, const int* shape, const int64_t* strides) {
IdxT loc = 0;
#pragma unroll
for (int i = NDIM - 1; i >= 0; --i) {
loc += (elem % shape[i]) * IdxT(strides[i]);
elem /= shape[i];
}
return loc;
}
template <int NDIM, typename IdxT = int64_t>
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
IdxT elem,
const int* shape,
const int64_t* a_strides,
const int64_t* b_strides) {
IdxT a_loc = 0;
IdxT b_loc = 0;
#pragma unroll
for (int i = NDIM - 1; i >= 0; --i) {
int dim_idx = elem % shape[i];
a_loc += dim_idx * a_strides[i];
b_loc += dim_idx * b_strides[i];
elem /= shape[i];
}
return cuda::std::make_tuple(a_loc, b_loc);
}
// Optimized version when ndim is larger than 4.
template <typename IdxT = int64_t>
inline __host__ __device__ IdxT
elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
IdxT loc = elem_to_loc_nd<3>(elem, shape, strides);
for (int i = ndim - 1; i >= 3; --i) {
loc += (elem % shape[i]) * IdxT(strides[i]);
elem /= shape[i];
}
return loc;
}
template <typename IdxT = int64_t>
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
IdxT elem,
const int* shape,
const int64_t* a_strides,
const int64_t* b_strides,
int ndim) {
auto [a_loc, b_loc] = elem_to_loc_nd<3>(elem, shape, a_strides, b_strides);
for (int i = ndim - 1; i >= 3; --i) {
int dim_idx = elem % shape[i];
a_loc += dim_idx * a_strides[i];
b_loc += dim_idx * b_strides[i];
elem /= shape[i];
}
return cuda::std::make_tuple(a_loc, b_loc);
}
} // namespace mlx::core::cu
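A worked example for elem_to_loc above (it is __host__ __device__, so it can be checked on the host): for a 2x3 array stored column-major, the strides are {1, 2}, and the element visited 5th in row-major order (elem = 4, i.e. row 1, column 1) lives at memory offset 1*1 + 1*2 = 3. A hypothetical host-side check, assuming <cassert> is included:

void elem_to_loc_example() {
  int shape[] = {2, 3};
  int64_t strides[] = {1, 2};  // column-major layout of a 2x3 array
  auto loc = mlx::core::cu::elem_to_loc<int64_t>(4, shape, strides, 2);
  assert(loc == 3);
}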

mlx/backend/cuda/matmul.cpp

@@ -0,0 +1,474 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/matmul.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cublasLt.h>
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>
#include <numeric>
namespace mlx::core {
namespace cu {
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
void check_cublas_error(const char* name, cublasStatus_t err) {
if (err != CUBLAS_STATUS_SUCCESS) {
// TODO: Use cublasGetStatusString when it is widely available.
throw std::runtime_error(
fmt::format("{} failed with code: {}.", name, static_cast<int>(err)));
}
}
class MatMul {
public:
MatMul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
uint64_t a_cols,
int64_t lda,
bool b_transposed,
uint64_t b_rows,
uint64_t b_cols,
int64_t ldb,
int32_t batch_count,
int64_t a_batch_stride,
int64_t b_batch_stride) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
auto type = dtype_to_cuda_type(dtype);
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
&matmul_desc_, dtype_to_compute_type(dtype), type));
int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_POINTER_MODE,
&pointer_mode,
sizeof(int32_t)));
cublasOperation_t op = CUBLAS_OP_N;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSA,
&op,
sizeof(cublasOperation_t)));
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSB,
&op,
sizeof(cublasOperation_t)));
a_desc_ = create_matrix_layout(
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
b_desc_ = create_matrix_layout(
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
out_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
uint64_t MiB = 1024 * 1024;
uint64_t workspace_size =
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
pref_,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size,
sizeof(uint64_t)));
}
MatMul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
uint64_t a_cols,
int64_t lda,
bool b_transposed,
uint64_t b_rows,
uint64_t b_cols,
int64_t ldb,
bool c_transposed,
int64_t ldc,
int32_t batch_count,
int64_t a_batch_stride,
int64_t b_batch_stride,
int64_t c_batch_stride)
: MatMul(
device,
dtype,
a_transposed,
a_rows,
a_cols,
lda,
b_transposed,
b_rows,
b_cols,
ldb,
batch_count,
a_batch_stride,
b_batch_stride) {
auto type = dtype_to_cuda_type(dtype);
c_desc_ = create_matrix_layout(
type, a_rows, b_cols, c_transposed, ldc, batch_count, c_batch_stride);
}
~MatMul() {
cublasLtMatrixLayoutDestroy(a_desc_);
cublasLtMatrixLayoutDestroy(b_desc_);
cublasLtMatrixLayoutDestroy(c_desc_);
cublasLtMatrixLayoutDestroy(out_desc_);
cublasLtMatmulDescDestroy(matmul_desc_);
}
void run(
cu::CommandEncoder& encoder,
void* out,
void* a,
void* b,
void* c = nullptr,
float alpha = 1,
float beta = 0) {
if (heuristic_.state != CUBLAS_STATUS_SUCCESS) {
int ret = 0;
CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(
encoder.device().lt_handle(),
matmul_desc_,
a_desc_,
b_desc_,
out_desc_,
out_desc_,
pref_,
1,
&heuristic_,
&ret));
if (ret == 0) {
throw std::runtime_error("Can not find algorithm for matmul.");
}
}
array workspace(
allocator::malloc(heuristic_.workspaceSize),
{static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);
encoder.launch_kernel([&](cudaStream_t stream) {
CHECK_CUBLAS_ERROR(cublasLtMatmul(
encoder.device().lt_handle(),
matmul_desc_,
&alpha,
a,
a_desc_,
b,
b_desc_,
&beta,
c ? c : out,
c ? c_desc_ : out_desc_,
out,
out_desc_,
&heuristic_.algo,
workspace.data<void>(),
workspace.nbytes(),
stream));
});
}
private:
cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
switch (dtype) {
case uint8:
case uint16:
case int8:
case int16:
case int32:
return CUBLAS_COMPUTE_32I;
case float16:
case bfloat16:
return CUBLAS_COMPUTE_16F;
case float32:
return CUBLAS_COMPUTE_32F;
case float64:
case complex64:
return CUBLAS_COMPUTE_64F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in MatMul: {}.", dtype_to_string(dtype)));
}
}
cudaDataType_t dtype_to_cuda_type(Dtype dtype) {
switch (dtype) {
case uint8:
return CUDA_R_8U;
case uint16:
return CUDA_R_16U;
case int8:
return CUDA_R_8I;
case int16:
return CUDA_R_16I;
case int32:
return CUDA_R_32I;
case float16:
return CUDA_R_16F;
case bfloat16:
return CUDA_R_16BF;
case float32:
return CUDA_R_32F;
case float64:
return CUDA_R_64F;
case complex64:
return CUDA_C_32F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in MatMul: {}.", dtype_to_string(dtype)));
}
}
cublasLtMatrixLayout_t create_matrix_layout(
cudaDataType_t type,
uint64_t rows,
uint64_t cols,
bool transposed,
int64_t ld,
int32_t batch_count,
int64_t batch_stride) {
cublasLtMatrixLayout_t desc;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
cublasLtOrder_t order =
transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
if (batch_count > 1) {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
&batch_count,
sizeof(int32_t)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&batch_stride,
sizeof(int64_t)));
}
return desc;
}
cublasLtMatmulDesc_t matmul_desc_{nullptr};
cublasLtMatmulPreference_t pref_{nullptr};
cublasLtMatrixLayout_t a_desc_{nullptr};
cublasLtMatrixLayout_t b_desc_{nullptr};
cublasLtMatrixLayout_t c_desc_{nullptr};
cublasLtMatrixLayout_t out_desc_{nullptr};
cublasLtMatmulHeuristicResult_t heuristic_;
};
} // namespace cu
namespace {
std::tuple<bool, int64_t, array>
check_transpose(std::vector<array>& copies, const Stream& s, const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (sty == 1 && stx == arr.shape(-1)) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
return std::make_tuple(false, arr.shape(-1), arr_copy);
}
}
} // namespace
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Matmul::eval_gpu");
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
assert(inputs.size() == 2);
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
// Return 0s if either input is empty.
if (a_pre.size() == 0 || b_pre.size() == 0) {
array zero(0, a_pre.dtype());
encoder.add_temporary(zero);
fill_gpu(zero, out, s);
return;
}
out.set_data(allocator::malloc(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
int M = a_pre.shape(-2);
int N = b_pre.shape(-1);
int K = a_pre.shape(-1);
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre);
for (auto& temp : copies) {
encoder.add_temporary(temp);
}
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
auto batch_count = out.size() / (M * N);
// Collapse batches into M if needed
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
b_batch_strides.back() == 0) {
M *= batch_shape.back();
batch_count = 1;
a_batch_strides = {0};
b_batch_strides = {0};
batch_shape = {1};
}
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
cu::MatMul matmul(
encoder.device(),
a.dtype(),
a_transposed,
M,
K,
lda,
b_transposed,
K,
N,
ldb,
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
matmul.run(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc);
a_it.step();
b_it.step();
}
}
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("AddMM::eval_gpu");
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
assert(inputs.size() == 3);
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
auto& c_pre = inputs[2];
out.set_data(allocator::malloc(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
int M = a_pre.shape(-2);
int N = b_pre.shape(-1);
int K = a_pre.shape(-1);
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre);
auto [c_transposed, ldc, c] = check_transpose(copies, s, c_pre);
for (auto& temp : copies) {
encoder.add_temporary(temp);
}
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
auto [batch_shape, a_batch_strides, b_batch_strides, c_batch_strides] =
collapse_batches(a, b, c);
auto batch_count = out.size() / (M * N);
// Collapse batches into M if needed
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
c_batch_strides.back() == M * c.strides()[c.ndim() - 2] &&
b_batch_strides.back() == 0) {
M *= batch_shape.back();
batch_count = 1;
a_batch_strides = {0};
b_batch_strides = {0};
c_batch_strides = {0};
batch_shape = {1};
}
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
cu::MatMul matmul(
encoder.device(),
a.dtype(),
a_transposed,
M,
K,
lda,
b_transposed,
K,
N,
ldb,
c_transposed,
ldc,
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back(),
c_batch_strides.back());
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
matmul.run(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc,
c.data<int8_t>() + c.itemsize() * c_it.loc,
alpha_,
beta_);
a_it.step();
b_it.step();
c_it.step();
}
}
} // namespace mlx::core
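For readers unfamiliar with cublasLt, the cublasLtMatmul call above computes, per batch entry, out = alpha * (a @ b) + beta * c; Matmul uses the default alpha = 1, beta = 0, and AddMM forwards its alpha_/beta_. A plain CPU reference of the same semantics for row-major, non-transposed operands (addmm_reference is a sketch, not patch code):

void addmm_reference(const float* a, const float* b, const float* c,
                     float* out, int M, int N, int K,
                     float alpha, float beta) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += a[m * K + k] * b[k * N + n];
      }
      out[m * N + n] = alpha * acc + beta * c[m * N + n];
    }
  }
}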


@@ -1,7 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/dtype_utils.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/arange.cuh"
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
#include "mlx/distributed/primitives.h"
@@ -43,111 +43,76 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
});
}
bool fast::ScaledDotProductAttention::use_fallback(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal,
Stream s) {
return true;
}
#define NO_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
throw std::runtime_error(#func " has no CUDA implementation."); \
}
#define NO_GPU_USE_FALLBACK(func) \
bool func::use_fallback(Stream s) { \
return true; \
} \
NO_GPU_MULTI(func)
#define NO_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
throw std::runtime_error(#func " has no CUDA implementation."); \
}
NO_GPU(Abs)
NO_GPU(Add)
NO_GPU(AddMM)
NO_GPU(ArcCos)
NO_GPU(ArcCosh)
NO_GPU(ArcSin)
NO_GPU(ArcSinh)
NO_GPU(ArcTan)
NO_GPU(ArcTan2)
NO_GPU(ArcTanh)
NO_GPU(ArgPartition)
NO_GPU(ArgReduce)
NO_GPU(ArgSort)
NO_GPU(BitwiseBinary)
NO_GPU(BitwiseInvert)
NO_GPU(BlockMaskedMM)
NO_GPU(Ceil)
NO_GPU_MULTI(Compiled)
NO_GPU(Conjugate)
NO_GPU(Convolution)
NO_GPU(Cos)
NO_GPU(Cosh)
NO_GPU(Divide)
NO_GPU_MULTI(DivMod)
NO_GPU(DynamicSlice)
NO_GPU(DynamicSliceUpdate)
NO_GPU(Remainder)
NO_GPU(Equal)
NO_GPU(Erf)
NO_GPU(ErfInv)
NO_GPU(Exp)
NO_GPU(Expm1)
NO_GPU(FFT)
NO_GPU(Floor)
NO_GPU(Gather)
NO_GPU(GatherAxis)
NO_GPU(GatherMM)
NO_GPU(GatherQMM)
NO_GPU(Greater)
NO_GPU(GreaterEqual)
NO_GPU(Hadamard)
NO_GPU(Imag)
NO_GPU(Less)
NO_GPU(LessEqual)
NO_GPU(Load)
NO_GPU(Log)
NO_GPU(Log1p)
NO_GPU(LogicalNot)
NO_GPU(LogicalAnd)
NO_GPU(LogicalOr)
NO_GPU(LogAddExp)
NO_GPU(LogSumExp)
NO_GPU_MULTI(LUF)
NO_GPU(Matmul)
NO_GPU(Maximum)
NO_GPU(Minimum)
NO_GPU(Multiply)
NO_GPU(Negative)
NO_GPU(NotEqual)
NO_GPU(Partition)
NO_GPU(Power)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(RandomBits)
NO_GPU(Real)
NO_GPU(Reduce)
NO_GPU(Round)
NO_GPU(Scan)
NO_GPU(Scatter)
NO_GPU(ScatterAxis)
NO_GPU(Select)
NO_GPU(Sigmoid)
NO_GPU(Sign)
NO_GPU(Sin)
NO_GPU(Sinh)
NO_GPU(SliceUpdate)
NO_GPU(Softmax)
NO_GPU(Sort)
NO_GPU(Square)
NO_GPU(Sqrt)
NO_GPU(Subtract)
NO_GPU_MULTI(SVD)
NO_GPU(Tan)
NO_GPU(Tanh)
NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
namespace fast {
NO_GPU_MULTI(LayerNorm)
NO_GPU_USE_FALLBACK(LayerNorm)
NO_GPU_MULTI(LayerNormVJP)
NO_GPU_MULTI(RMSNorm)
NO_GPU_USE_FALLBACK(RMSNorm)
NO_GPU_MULTI(RMSNormVJP)
NO_GPU_MULTI(RoPE)
NO_GPU_USE_FALLBACK(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
NO_GPU_MULTI(CustomKernel)

mlx/backend/cuda/unary.cu

@@ -0,0 +1,196 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/unary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/iterators/general_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
#include "mlx/backend/cuda/kernels/unary_ops.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
namespace mlx::core {
namespace cu {
template <typename Op, typename In, typename Out>
constexpr bool supports_unary_op() {
if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> ||
std::is_same_v<Op, Sign>) {
return std::is_same_v<In, Out>;
}
if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcCosh> ||
std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> ||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Log1p> ||
std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Sigmoid> ||
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
return std::is_same_v<In, Out> && is_floating_v<In>;
}
if (std::is_same_v<Op, BitwiseInvert>) {
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
!std::is_same_v<In, bool>;
}
if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor> ||
std::is_same_v<Op, Square>) {
return std::is_same_v<In, Out> && !std::is_same_v<In, complex64_t>;
}
if (std::is_same_v<Op, Conjugate>) {
return std::is_same_v<In, Out> && std::is_same_v<In, complex64_t>;
}
if (std::is_same_v<Op, Cos> || std::is_same_v<Op, Cosh> ||
std::is_same_v<Op, Exp> || std::is_same_v<Op, Round> ||
std::is_same_v<Op, Sin> || std::is_same_v<Op, Sinh> ||
std::is_same_v<Op, Tan> || std::is_same_v<Op, Tanh>) {
return std::is_same_v<In, Out> &&
(is_floating_v<In> || std::is_same_v<In, complex64_t>);
}
if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
return std::is_same_v<In, complex64_t> && std::is_same_v<Out, float>;
}
if (std::is_same_v<Op, LogicalNot>) {
return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
}
return false;
}
} // namespace cu
template <typename Op>
void unary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string& op,
const Stream& s) {
auto& in = inputs[0];
if (in.size() == 0) {
return;
}
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, {
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;
auto policy = cu::thrust_policy(stream);
auto in_ptr = thrust::device_pointer_cast(in.data<InType>());
auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
if (in.flags().contiguous) {
thrust::transform(
policy, in_ptr, in_ptr + in.data_size(), out_ptr, Op());
} else {
auto [shape, strides] = collapse_contiguous_dims(in);
auto [in_begin, in_end] = cu::make_general_iterators<int64_t>(
in_ptr, in.data_size(), shape, strides);
thrust::transform(policy, in_begin, in_end, out_ptr, Op());
}
} else {
throw std::runtime_error(fmt::format(
"Can not do unary op {} on input of {} with output of {}.",
op,
dtype_to_string(in.dtype()),
dtype_to_string(out.dtype())));
}
});
});
});
}
template <typename Op>
void unary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string& op,
const Stream& s) {
set_unary_output_data(inputs[0], out);
unary_op_gpu_inplace<Op>(inputs, out, op, s);
}
#define UNARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = out.primitive().stream(); \
unary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
}
UNARY_GPU(Abs)
UNARY_GPU(ArcCos)
UNARY_GPU(ArcCosh)
UNARY_GPU(ArcSin)
UNARY_GPU(ArcSinh)
UNARY_GPU(ArcTan)
UNARY_GPU(ArcTanh)
UNARY_GPU(BitwiseInvert)
UNARY_GPU(Ceil)
UNARY_GPU(Conjugate)
UNARY_GPU(Cos)
UNARY_GPU(Cosh)
UNARY_GPU(Erf)
UNARY_GPU(ErfInv)
UNARY_GPU(Exp)
UNARY_GPU(Expm1)
UNARY_GPU(Floor)
UNARY_GPU(Imag)
UNARY_GPU(Log1p)
UNARY_GPU(LogicalNot)
UNARY_GPU(Negative)
UNARY_GPU(Real)
UNARY_GPU(Sigmoid)
UNARY_GPU(Sign)
UNARY_GPU(Sin)
UNARY_GPU(Sinh)
UNARY_GPU(Square)
UNARY_GPU(Tan)
UNARY_GPU(Tanh)
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Log::eval_gpu");
auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
switch (base_) {
case Base::e:
unary_op_gpu<cu::Log>(inputs, out, op, s);
break;
case Base::two:
unary_op_gpu<cu::Log2>(inputs, out, op, s);
break;
case Base::ten:
unary_op_gpu<cu::Log10>(inputs, out, op, s);
break;
}
}
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Round::eval_gpu");
assert(inputs.size() == 1);
const auto& in = inputs[0];
auto& s = out.primitive().stream();
if (issubdtype(in.dtype(), inexact)) {
unary_op_gpu<cu::Round>(inputs, out, get_primitive_string(this), s);
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Sort::eval_gpu");
auto& s = out.primitive().stream();
if (recip_) {
unary_op_gpu<cu::Rsqrt>(inputs, out, "Rsqrt", s);
} else {
unary_op_gpu<cu::Sqrt>(inputs, out, "Sqrt", s);
}
}
} // namespace mlx::core
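A minimal CPU-side sketch of the dispatch pattern in this file, assuming simplified stand-in op structs rather than the real CUDA functors: supports_unary_op is evaluated at compile time inside the type switches, so the kernel body is only instantiated for valid (op, input, output) combinations and everything else falls through to the runtime error.

#include <cmath>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

// Simplified stand-ins for the op functors used by the CUDA backend.
struct Abs {
  template <typename T>
  T operator()(T x) const { return x < T(0) ? -x : x; }
};
struct Sqrt {
  template <typename T>
  T operator()(T x) const { return std::sqrt(x); }
};

// Compile-time gate: which (Op, In, Out) combinations have a kernel.
template <typename Op, typename In, typename Out>
constexpr bool supports_unary_op() {
  if (std::is_same_v<Op, Abs>) {
    return std::is_same_v<In, Out>;
  }
  if (std::is_same_v<Op, Sqrt>) {
    return std::is_same_v<In, Out> && std::is_floating_point_v<In>;
  }
  return false;
}

// Mirrors unary_op_gpu_inplace: the unsupported branch is discarded by
// `if constexpr`, so Op() is never instantiated for invalid type pairs.
template <typename Op, typename In, typename Out>
void unary_op(const std::vector<In>& in, std::vector<Out>& out, const std::string& name) {
  if constexpr (supports_unary_op<Op, In, Out>()) {
    out.resize(in.size());
    for (size_t i = 0; i < in.size(); ++i) {
      out[i] = Op()(in[i]);
    }
  } else {
    throw std::runtime_error("Can not do unary op " + name + " on this type.");
  }
}

int main() {
  std::vector<float> x{4.0f, 9.0f};
  std::vector<float> y;
  unary_op<Sqrt>(x, y, "Sqrt"); // ok: float -> float
  // unary_op<Sqrt, int, int> would take the throwing branch at runtime,
  // yet still compiles because the kernel body is compiled out.
}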

View File

@@ -1,5 +1,7 @@
// Copyright © 2025 Apple Inc.
// This file includes utilities that are used by C++ code (i.e. .cpp files).
#pragma once
#include <cuda_runtime.h>

View File

@@ -5,9 +5,17 @@
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
#if defined(MLX_USE_CUDA)
#include <nvtx3/nvtx3.hpp>
#endif
#include <cassert>
#if defined(MLX_USE_CUDA)
#define MLX_PROFILER_RANGE(message) nvtx3::scoped_range r(message)
#else
#define MLX_PROFILER_RANGE(message)
#endif
namespace mlx::core {

View File

@@ -30,141 +30,18 @@ void* Buffer::raw_ptr() {
namespace metal {
namespace {
BufferCache::BufferCache(ResidencySet& residency_set)
: head_(nullptr),
tail_(nullptr),
pool_size_(0),
residency_set_(residency_set) {}
BufferCache::~BufferCache() {
auto pool = metal::new_scoped_memory_pool();
clear();
}
int BufferCache::clear() {
int n_release = 0;
for (auto& [size, holder] : buffer_pool_) {
if (holder->buf) {
if (!holder->buf->heap()) {
residency_set_.erase(holder->buf);
}
holder->buf->release();
n_release++;
}
delete holder;
}
buffer_pool_.clear();
pool_size_ = 0;
head_ = nullptr;
tail_ = nullptr;
return n_release;
}
MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
// Find the closest buffer in pool
MTL::Buffer* pbuf = nullptr;
auto it = buffer_pool_.lower_bound(size);
// Make sure we use most of the available memory
while (!pbuf && it != buffer_pool_.end() &&
it->first < std::min(2 * size, size + 2 * vm_page_size)) {
// Collect from the cache
pbuf = it->second->buf;
// Remove from cache
remove_from_list(it->second);
delete it->second;
it = buffer_pool_.erase(it);
}
if (pbuf) {
pool_size_ -= pbuf->length();
}
return pbuf;
}
void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
// Add to cache
if (buf) {
BufferHolder* bh = new BufferHolder(buf);
add_at_head(bh);
pool_size_ += buf->length();
buffer_pool_.insert({buf->length(), bh});
}
}
int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
if (min_bytes_to_free >= 0.9 * pool_size_) {
return clear();
} else {
int n_release = 0;
size_t total_bytes_freed = 0;
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
if (tail_->buf) {
total_bytes_freed += tail_->buf->length();
if (!tail_->buf->heap()) {
residency_set_.erase(tail_->buf);
}
tail_->buf->release();
tail_->buf = nullptr;
n_release++;
}
remove_from_list(tail_);
}
pool_size_ -= total_bytes_freed;
return n_release;
}
}
void BufferCache::add_at_head(BufferCache::BufferHolder* to_add) {
if (!to_add)
return;
if (!head_) {
head_ = to_add;
tail_ = to_add;
} else {
head_->prev = to_add;
to_add->next = head_;
head_ = to_add;
}
}
void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
if (!to_remove) {
return;
}
// If in the middle
if (to_remove->prev && to_remove->next) {
to_remove->prev->next = to_remove->next;
to_remove->next->prev = to_remove->prev;
} else if (to_remove->prev && to_remove == tail_) { // If tail
tail_ = to_remove->prev;
tail_->next = nullptr;
} else if (to_remove == head_ && to_remove->next) { // If head
head_ = to_remove->next;
head_->prev = nullptr;
} else if (to_remove == head_ && to_remove == tail_) { // If only element
head_ = nullptr;
tail_ = nullptr;
}
to_remove->prev = nullptr;
to_remove->next = nullptr;
}
} // namespace
MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()),
residency_set_(device_),
buffer_cache_(residency_set_) {
buffer_cache_(
vm_page_size,
[](MTL::Buffer* buf) { return buf->length(); },
[this](MTL::Buffer* buf) {
if (!buf->heap()) {
residency_set_.erase(buf);
}
buf->release();
}) {
auto pool = metal::new_scoped_memory_pool();
auto memsize = std::get<size_t>(device_info().at("memory_size"));
auto max_rec_size =
@@ -193,6 +70,7 @@ MetalAllocator::~MetalAllocator() {
if (heap_) {
heap_->release();
}
buffer_cache_.clear();
}
size_t MetalAllocator::set_cache_limit(size_t limit) {
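The cache removed here moves into the shared BufferCache template in backend/common/buffer_cache.h, but the reuse policy stays the same: freed buffers sit in a multimap keyed by size, and a request only reuses one that is not much larger than asked for. A rough sketch of that size window, with the page size as an assumed constant:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <map>

// Sketch of the policy in BufferCache::reuse_from_cache: take the smallest
// cached buffer of at least `size` bytes, but refuse anything larger than
// min(2 * size, size + 2 * page) so memory is not wasted.
constexpr size_t page = 16384; // assumed stand-in for vm_page_size

size_t reuse_from_cache(std::multimap<size_t, int>& pool, size_t size) {
  auto it = pool.lower_bound(size);
  size_t limit = std::min(2 * size, size + 2 * page);
  if (it != pool.end() && it->first < limit) {
    size_t got = it->first;
    pool.erase(it); // the buffer leaves the cache and is handed to the caller
    return got;
  }
  return 0; // miss: the caller allocates a fresh buffer
}

int main() {
  std::multimap<size_t, int> pool{{4096, 0}, {1 << 20, 1}};
  std::cout << reuse_from_cache(pool, 3000) << "\n";   // 4096, within the window
  std::cout << reuse_from_cache(pool, 500000) << "\n"; // 0, the 1 MiB buffer is too big
}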

View File

@@ -7,6 +7,7 @@
#include <vector>
#include "mlx/allocator.h"
#include "mlx/backend/common/buffer_cache.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/resident.h"
@@ -14,43 +15,6 @@ namespace mlx::core::metal {
using allocator::Buffer;
namespace {
class BufferCache {
public:
BufferCache(ResidencySet& residency_set);
~BufferCache();
MTL::Buffer* reuse_from_cache(size_t size);
void recycle_to_cache(MTL::Buffer* buf);
int release_cached_buffers(size_t min_bytes_to_free);
size_t cache_size() {
return pool_size_;
}
int clear();
private:
struct BufferHolder {
public:
BufferHolder(MTL::Buffer* buf_) : buf(buf_), prev(nullptr), next(nullptr) {}
BufferHolder* prev;
BufferHolder* next;
MTL::Buffer* buf;
};
void add_at_head(BufferHolder* to_add);
void remove_from_list(BufferHolder* to_remove);
std::multimap<size_t, BufferHolder*> buffer_pool_;
BufferHolder* head_;
BufferHolder* tail_;
size_t pool_size_;
ResidencySet& residency_set_;
};
} // namespace
class MetalAllocator : public allocator::Allocator {
/** Allocator for Metal GPUs. */
public:
@@ -90,7 +54,7 @@ class MetalAllocator : public allocator::Allocator {
friend MetalAllocator& allocator();
// Caching allocator
BufferCache buffer_cache_;
BufferCache<MTL::Buffer> buffer_cache_;
ResidencySet residency_set_;

View File

@@ -31,13 +31,13 @@ std::string get_kernel_name(
kname = "ss";
break;
case BinaryOpType::ScalarVector:
kname = (large ? "sv2" : "sv");
kname = "sv";
break;
case BinaryOpType::VectorScalar:
kname = (large ? "vs2" : "vs");
kname = "vs";
break;
case BinaryOpType::VectorVector:
kname = (large ? "vv2" : "vv");
kname = "vv";
break;
case BinaryOpType::General:
kname = "g";
@@ -51,6 +51,13 @@ std::string get_kernel_name(
}
break;
}
if (bopt != BinaryOpType::General && bopt != BinaryOpType::ScalarScalar) {
if (large) {
kname += "2";
} else if (work_per_thread > 1) {
kname += "n";
}
}
concatenate(kname, "_", op, type_to_name(a));
return kname;
}
@@ -90,7 +97,7 @@ void binary_op_gpu_inplace(
work_per_thread = large ? 4 : 2;
} else {
large = out.data_size() > UINT32_MAX;
work_per_thread = get_work_per_thread(a.dtype());
work_per_thread = get_work_per_thread(a.dtype(), out.data_size());
}
std::string kernel_name =
get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);
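The hunk above folds the large-index and work-per-thread variants into one naming step: contiguous binary kernels gain a "2" suffix when 64-bit indexing is required and an "n" suffix when each thread handles several elements. A stand-alone sketch of that naming rule (the get_work_per_thread call that produces work_per_thread is assumed):

#include <iostream>
#include <string>

// Hypothetical stand-alone version of the suffix logic added to
// get_kernel_name for the sv/vs/vv kernels.
std::string elementwise_kernel_name(
    std::string base,          // "sv", "vs" or "vv"
    bool large,                // needs 64-bit indexing
    int work_per_thread,       // elements per thread
    const std::string& op,     // e.g. "Add"
    const std::string& type) { // e.g. "float32"
  if (large) {
    base += "2";
  } else if (work_per_thread > 1) {
    base += "n";
  }
  return base + "_" + op + type;
}

int main() {
  std::cout << elementwise_kernel_name("vv", false, 4, "Add", "float32") << "\n"; // vvn_Addfloat32
  std::cout << elementwise_kernel_name("vv", true, 4, "Add", "float32") << "\n";  // vv2_Addfloat32
}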

View File

@@ -11,8 +11,6 @@
#include "mlx/primitives.h"
#include "mlx/utils.h"
using namespace fmt::literals;
namespace mlx::core {
inline void build_kernel(
@@ -21,21 +19,12 @@ inline void build_kernel(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids,
const std::function<bool(size_t)>& is_constant,
bool contiguous,
int ndim,
bool dynamic_dims,
bool use_big_index = false,
int work_per_thread = 1) {
// All outputs should have the exact same shape and will be row contiguous
auto output_shape = outputs[0].shape();
auto output_strides = outputs[0].strides();
// Constants are scalars that are captured by value and cannot change
auto is_constant = [&constant_ids](const array& x) {
return constant_ids.find(x.id()) != constant_ids.end();
};
NodeNamer namer;
bool add_indices = false;
int cnt = 0;
@@ -45,14 +34,15 @@ inline void build_kernel(
"[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);
// Add the input arguments
for (auto& x : inputs) {
auto& xname = namer.get_name(x);
for (size_t i = 0; i < inputs.size(); ++i) {
// Skip constants from the input list
if (is_constant(x)) {
if (is_constant(i)) {
continue;
}
const auto& x = inputs[i];
auto& xname = namer.get_name(x);
// Scalars and contiguous need no strides
if (!is_scalar(x) && !contiguous) {
add_indices = true;
@@ -80,8 +70,6 @@ inline void build_kernel(
}
// Add output strides and shape to extract the indices.
if (!contiguous) {
os += fmt::format(
" constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
os += fmt::format(
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
} else {
@@ -125,7 +113,7 @@ inline void build_kernel(
auto& x = inputs[i];
auto& xname = namer.get_name(x);
if (is_constant(x)) {
if (is_constant(i)) {
auto type_str = get_type_string(x.dtype());
std::ostringstream ss;
print_constant(ss, x);
@@ -271,11 +259,6 @@ inline void build_kernel(
void Compiled::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
// Make the name for the kernel library
if (kernel_lib_.empty()) {
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
}
// Get the kernel if someone else built it already
auto& s = stream();
auto& d = metal::device(s.device);
@@ -290,19 +273,33 @@ void Compiled::eval_gpu(
inputs_,
outputs_,
tape_,
constant_ids_,
is_constant_,
/* contiguous = */ true,
/* ndim = */ 0,
/* dynamic_dims = */ false,
/* use_big_index = */ false,
/* work_per_thread = */ work_per_thread);
/* work_per_thread = */ 1);
if (work_per_thread > 1) {
build_kernel(
kernel,
kernel_lib_ + "_contiguous_n",
inputs_,
outputs_,
tape_,
is_constant_,
/* contiguous = */ true,
/* ndim = */ 0,
/* dynamic_dims = */ false,
/* use_big_index = */ false,
/* work_per_thread = */ work_per_thread);
}
build_kernel(
kernel,
kernel_lib_ + "_contiguous_large",
inputs_,
outputs_,
tape_,
constant_ids_,
is_constant_,
/* contiguous = */ true,
/* ndim = */ 0,
/* dynamic_dims = */ false,
@@ -315,7 +312,7 @@ void Compiled::eval_gpu(
inputs_,
outputs_,
tape_,
constant_ids_,
is_constant_,
/* contiguous = */ false,
/* ndim = */ i,
/* dynamic_dims = */ false,
@@ -328,7 +325,7 @@ void Compiled::eval_gpu(
inputs_,
outputs_,
tape_,
constant_ids_,
is_constant_,
/* contiguous = */ false,
/* ndim = */ i,
/* dynamic_dims = */ false,
@@ -342,7 +339,7 @@ void Compiled::eval_gpu(
inputs_,
outputs_,
tape_,
constant_ids_,
is_constant_,
/* contiguous = */ false,
/* ndim = */ 0,
/* dynamic_dims = */ true,
@@ -354,7 +351,7 @@ void Compiled::eval_gpu(
inputs_,
outputs_,
tape_,
constant_ids_,
is_constant_,
/* contiguous = */ false,
/* ndim = */ 0,
/* dynamic_dims = */ true,
@@ -363,81 +360,32 @@ void Compiled::eval_gpu(
return kernel;
});
// Figure out which kernel we are using
auto& output_shape = outputs[0].shape();
auto contiguous = compiled_check_contiguity(inputs, output_shape);
// Collapse contiguous dims to route to a faster kernel if possible. Also
// handle all broadcasting.
std::vector<Strides> initial_strides;
initial_strides.push_back(outputs[0].strides());
Shape shape;
std::vector<Strides> strides;
if (!contiguous) {
for (int i = 0; i < inputs.size(); i++) {
// Skip constants.
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
continue;
}
auto& x = inputs[i];
auto [contiguous, shape, strides] =
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
// Skip scalar inputs.
if (is_scalar(x)) {
continue;
}
// Broadcast the inputs to the output shape.
Strides xstrides;
int j = 0;
for (; j < output_shape.size() - x.ndim(); j++) {
if (output_shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
}
for (int i = 0; i < x.ndim(); i++, j++) {
if (x.shape(i) == 1) {
if (output_shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
} else {
xstrides.push_back(x.strides()[i]);
}
}
initial_strides.push_back(std::move(xstrides));
}
std::tie(shape, strides) =
collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX);
}
bool large;
if (contiguous) {
size_t max_size = 0;
for (auto& in : inputs) {
max_size = std::max(max_size, in.data_size());
}
large = (max_size > UINT32_MAX);
} else {
size_t max_size = 0;
for (auto& o : outputs) {
max_size = std::max(max_size, o.size());
}
large = (max_size > UINT32_MAX);
}
// Whether to use large index.
bool large = compiled_use_large_index(inputs, outputs, contiguous);
// Get the kernel from the lib
int ndim = shape.size();
bool dynamic = ndim >= 8;
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
int work_per_thread = 1;
if (!contiguous) {
if (dynamic) {
kernel_name += "dynamic";
} else {
kernel_name += std::to_string(shape.size());
}
work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
} else {
work_per_thread =
get_work_per_thread(outputs[0].dtype(), outputs[0].data_size());
if (work_per_thread > 1 && !large) {
kernel_name += "_n";
}
}
if (large) {
kernel_name += "_large";
@@ -451,7 +399,7 @@ void Compiled::eval_gpu(
int stride_idx = 1; // idx 0 is the output strides
Strides in_strides;
for (int i = 0; i < inputs.size(); i++) {
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
if (is_constant_(i)) {
continue;
}
auto& x = inputs[i];
@@ -468,8 +416,7 @@ void Compiled::eval_gpu(
compute_encoder.set_vector_bytes(in_strides, cnt++);
}
compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous);
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
// Put the outputs in
for (auto& x : outputs) {
@@ -478,7 +425,6 @@ void Compiled::eval_gpu(
// Put the output shape and strides in
if (!contiguous) {
compute_encoder.set_vector_bytes(strides[0], cnt++);
compute_encoder.set_vector_bytes(shape, cnt++);
} else {
auto size = outputs[0].data_size();
@@ -496,7 +442,6 @@ void Compiled::eval_gpu(
// Launch the kernel
if (contiguous) {
int work_per_thread = get_work_per_thread(outputs[0].dtype());
size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread);
MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
@@ -509,7 +454,6 @@ void Compiled::eval_gpu(
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = outputs[0].size() / (dim0 * dim1);
int work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
int pow2;
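Taken together, these hunks pick the compiled kernel variant from the contiguity flag, the collapsed ndim, the large-index flag, and the per-thread work. A hedged stand-alone sketch of that selection (the names follow the diff; the helper itself is illustrative):

#include <iostream>
#include <string>

std::string pick_kernel(
    const std::string& lib,
    bool contiguous,
    int ndim,
    bool large,
    int contiguous_work_per_thread) { // from get_work_per_thread(dtype, size)
  std::string name = lib + (contiguous ? "_contiguous" : "_strided_");
  if (!contiguous) {
    bool dynamic = ndim >= 8;
    name += dynamic ? "dynamic" : std::to_string(ndim);
  } else if (contiguous_work_per_thread > 1 && !large) {
    name += "_n";
  }
  if (large) {
    name += "_large";
  }
  return name;
}

int main() {
  std::cout << pick_kernel("add_mul", true, 0, false, 4) << "\n"; // add_mul_contiguous_n
  std::cout << pick_kernel("add_mul", false, 3, true, 1) << "\n"; // add_mul_strided_3_large
}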

View File

@@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <numeric>
@@ -178,83 +177,6 @@ void explicit_gemm_conv_group_ND_gpu(
/*copies = */ copies);
}
void conv_1D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
int groups,
bool flip) {
// Make conv params
MLXConvParams<1> conv_params{
/* const int N = */ static_cast<int>(in.shape(0)),
/* const int C = */ static_cast<int>(in.shape(2)),
/* const int O = */ static_cast<int>(wt.shape(0)),
/* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
/* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
/* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
/* const int str[NDIM] = */ {wt_strides[0]},
/* const int pad[NDIM] = */ {padding[0]},
/* const int kdil[NDIM] = */ {wt_dilation[0]},
/* const int idil[NDIM] = */ {in_dilation[0]},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0], in.strides()[1], in.strides()[2]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], wt.strides()[2]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2]},
/* const int groups = */ groups,
/* const bool flip = */ flip};
// Direct to explicit gemm conv
if (groups > 1) {
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
} else {
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
}
void slow_conv_2D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
int bm = 16, bn = 8;
int tm = 4, tn = 4;
std::ostringstream kname;
kname << "naive_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn" << bn
<< "_tm" << tm << "_tn" << tn;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
size_t n_pixels = conv_params.oS[0] * conv_params.oS[1];
size_t grid_dim_x = (n_pixels + (tm * bm) - 1) / (tm * bm);
size_t grid_dim_y = (conv_params.O + (tn * bn) - 1) / (tn * bn);
size_t grid_dim_z = conv_params.N;
MTL::Size group_dims = MTL::Size(bm, bn, 1);
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
compute_encoder.set_bytes(conv_params, 3);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void implicit_gemm_conv_2D_gpu(
const Stream& s,
metal::Device& d,
@@ -755,7 +677,7 @@ void depthwise_conv_2D_gpu(
std::string hash_name = kname.str();
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
@@ -771,6 +693,141 @@ void depthwise_conv_2D_gpu(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void dispatch_conv_2D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params,
std::vector<array>& copies) {
bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1;
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
if (is_idil_one && conv_params.groups > 1) {
const int C_per_group = conv_params.C / conv_params.groups;
const int O_per_group = conv_params.O / conv_params.groups;
if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
conv_params.wt_strides[1] == conv_params.wS[1] &&
conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
(O_per_group <= 16 || O_per_group % 16 == 0)) {
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
} else {
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
}
}
// Direct to winograd conv
bool inp_large =
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
bool channels_large = (conv_params.C + conv_params.O) >= 256;
if (!conv_params.flip && is_stride_one && is_kdil_one && is_idil_one &&
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
channels_large) {
return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
}
// Direct to implicit gemm conv
if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) &&
(conv_params.O <= 16 || conv_params.O % 16 == 0)) {
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
}
// Direct to explicit gemm conv
else {
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
}
void conv_1D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
int groups,
bool flip,
std::vector<array>& copies) {
bool is_idil_one = in_dilation[0] == 1;
int C = in.shape(2);
int O = wt.shape(0);
const int C_per_group = in.shape(2) / groups;
const int O_per_group = wt.shape(0) / groups;
// Direct to implicit gemm conv
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
(O_per_group <= 16 || O_per_group % 16 == 0)) {
MLXConvParams<2> conv_params{
/* const int N = */ static_cast<int>(in.shape(0)),
/* const int C = */ C,
/* const int O = */ O,
/* const int iS[NDIM] = */ {static_cast<int>(in.shape(1)), 1},
/* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1)), 1},
/* const int oS[NDIM] = */ {static_cast<int>(out.shape(1)), 1},
/* const int str[NDIM] = */ {wt_strides[0], 1},
/* const int pad[NDIM] = */ {padding[0], 0},
/* const int kdil[NDIM] = */ {wt_dilation[0], 1},
/* const int idil[NDIM] = */ {in_dilation[0], 1},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0], in.strides()[1], 0, in.strides()[2]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], 0, wt.strides()[2]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], 0, out.strides()[2]},
/* const int groups = */ groups,
/* const bool flip = */ flip};
dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
return;
}
// Make conv params
MLXConvParams<1> conv_params{
/* const int N = */ static_cast<int>(in.shape(0)),
/* const int C = */ static_cast<int>(in.shape(2)),
/* const int O = */ static_cast<int>(wt.shape(0)),
/* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
/* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
/* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
/* const int str[NDIM] = */ {wt_strides[0]},
/* const int pad[NDIM] = */ {padding[0]},
/* const int kdil[NDIM] = */ {wt_dilation[0]},
/* const int idil[NDIM] = */ {in_dilation[0]},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0], in.strides()[1], in.strides()[2]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], wt.strides()[2]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2]},
/* const int groups = */ groups,
/* const bool flip = */ flip};
// Direct to explicit gemm conv
if (groups > 1) {
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
} else {
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
}
void conv_2D_gpu(
const Stream& s,
metal::Device& d,
@@ -808,57 +865,7 @@ void conv_2D_gpu(
/* const int groups = */ groups,
/* const bool flip = */ flip,
};
bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1;
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
if (is_idil_one && groups > 1) {
const int C_per_group = conv_params.C / groups;
const int O_per_group = conv_params.O / groups;
if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
conv_params.wt_strides[1] == conv_params.wS[1] &&
conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
(O_per_group <= 16 || O_per_group % 16 == 0)) {
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
} else {
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
}
}
// Direct to winograd conv
bool inp_large =
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
bool channels_large = (conv_params.C + conv_params.O) >= 256;
if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
channels_large) {
return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
}
// Direct to implicit gemm conv
if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) &&
(conv_params.O <= 16 || conv_params.O % 16 == 0)) {
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
}
// Direct to explicit gemm conv
else {
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
}
void conv_3D_gpu(
@@ -952,7 +959,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
in,
wt,
out,
padding_,
padding_lo_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
@@ -967,7 +974,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
in,
wt,
out,
padding_,
padding_lo_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
@@ -983,12 +990,13 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
in,
wt,
out,
padding_,
padding_lo_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
groups_,
flip_);
flip_,
copies);
}
// Throw error
else {
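For reference, a simplified sketch of the routing that dispatch_conv_2D_gpu now centralizes for the non-grouped case: Winograd for 3x3 stride-1 convolutions with large, 32-divisible channel counts, implicit GEMM when channels are small or 16-divisible, and explicit GEMM as the fallback. The thresholds are taken from the diff; everything else here is illustrative.

#include <iostream>
#include <string>

std::string pick_conv_2d_kernel(
    int N, int iH, int iW, // batch and input spatial size
    int C, int O,          // input/output channels
    int wH, int wW,        // filter size
    bool stride_one, bool kdil_one, bool idil_one, bool flip) {
  bool inp_large = static_cast<unsigned long>(N) * iH * iW >= (1ul << 12);
  bool channels_large = (C + O) >= 256;
  if (!flip && stride_one && kdil_one && idil_one && wH == 3 && wW == 3 &&
      C % 32 == 0 && O % 32 == 0 && inp_large && channels_large) {
    return "winograd";
  }
  if (idil_one && (C <= 4 || C % 16 == 0) && (O <= 16 || O % 16 == 0)) {
    return "implicit_gemm";
  }
  if (C % 16 == 0 && O % 16 == 0) {
    return "implicit_gemm_general";
  }
  return "explicit_gemm";
}

int main() {
  std::cout << pick_conv_2d_kernel(1, 128, 128, 128, 128, 3, 3, true, true, true, false) << "\n"; // winograd
  std::cout << pick_conv_2d_kernel(1, 16, 16, 3, 7, 5, 5, true, true, true, false) << "\n";       // implicit_gemm
}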

View File

@@ -55,10 +55,10 @@ void copy_gpu_inplace(
std::string kernel_name;
switch (ctype) {
case CopyType::Scalar:
kernel_name = (large ? "s2" : "s");
kernel_name = large ? "s2" : "s";
break;
case CopyType::Vector:
kernel_name = (large ? "v2" : "v");
kernel_name = large ? "v2" : "v";
break;
case CopyType::General:
kernel_name = "g";
@@ -85,7 +85,10 @@ void copy_gpu_inplace(
}
}
} else {
work_per_thread = get_work_per_thread(in.dtype());
work_per_thread = get_work_per_thread(out.dtype(), out.data_size());
if (work_per_thread > 1) {
kernel_name += "n";
}
}
concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out)
@@ -170,9 +173,10 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
}
out.set_data(allocator::malloc(out.nbytes()));
bool large = out.data_size() > UINT32_MAX;
int work_per_thread = get_work_per_thread(out.dtype(), out.data_size());
auto& d = metal::device(s.device);
std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +
type_to_name(val) + type_to_name(out);
std::string kernel_name = large ? "s2" : (work_per_thread > 1 ? "sn" : "s");
concatenate(kernel_name, "_copy", type_to_name(val), type_to_name(out));
auto kernel = get_copy_kernel(d, kernel_name, val, out);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
@@ -180,7 +184,6 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
compute_encoder.set_input_array(val, 0);
compute_encoder.set_output_array(out, 1);
int work_per_thread = get_work_per_thread(val.dtype());
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
if (thread_group_size > nthreads) {

View File

@@ -1,12 +1,326 @@
// Copyright © 2024 Apple Inc.
#include <iostream>
#include <regex>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast.h"
#include "mlx/fast_primitives.h"
#include "mlx/utils.h"
namespace mlx::core::fast {
struct CustomKernelCache {
std::unordered_map<std::string, std::string> libraries;
};
static CustomKernelCache& cache() {
static CustomKernelCache cache_;
return cache_;
};
std::string write_signature(
std::string func_name,
const std::string& header,
const std::string& source,
const std::vector<std::string>& input_names,
const std::vector<array>& inputs,
const std::vector<std::string>& output_names,
const std::vector<Dtype>& output_dtypes,
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
const std::vector<std::string>& attributes,
const std::vector<CustomKernelShapeInfo>& shape_infos,
bool atomic_outputs) {
std::string kernel_source;
kernel_source.reserve(header.size() + source.size() + 16384);
kernel_source += header;
// Auto-generate a function signature based on `template_args`
// and the dtype/shape of the arrays passed as `inputs`.
if (!template_args.empty()) {
kernel_source += "template <";
int i = 0;
for (const auto& [name, arg] : template_args) {
std::string param_type;
if (std::holds_alternative<int>(arg)) {
param_type = "int";
} else if (std::holds_alternative<bool>(arg)) {
param_type = "bool";
} else if (std::holds_alternative<Dtype>(arg)) {
param_type = "typename";
}
if (i > 0) {
kernel_source += ", ";
}
kernel_source += param_type;
kernel_source += " ";
kernel_source += name;
i++;
}
kernel_source += ">\n";
}
kernel_source += "[[kernel]] void ";
kernel_source += func_name;
kernel_source += "(\n";
int index = 0;
constexpr int max_constant_array_size = 8;
// Add inputs
for (int i = 0; i < inputs.size(); ++i) {
const auto& name = input_names[i];
const auto& arr = inputs[i];
auto dtype = get_type_string(arr.dtype());
std::string location =
arr.size() < max_constant_array_size ? "constant" : "device";
std::string ref = arr.ndim() == 0 ? "&" : "*";
kernel_source += " const ";
kernel_source += location;
kernel_source += " ";
kernel_source += dtype;
kernel_source += ref;
kernel_source += " ";
kernel_source += name;
kernel_source += " [[buffer(";
kernel_source += std::to_string(index);
kernel_source += ")]],\n";
index++;
// Add input shape, strides and ndim if present in the source
if (arr.ndim() > 0) {
if (shape_infos[i].shape) {
kernel_source +=
(" const constant int* " + name + "_shape [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
if (shape_infos[i].strides) {
kernel_source +=
(" const constant int64_t* " + name + "_strides [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
if (shape_infos[i].ndim) {
kernel_source +=
(" const constant int& " + name + "_ndim [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
}
}
// Add outputs
for (int i = 0; i < output_names.size(); ++i) {
const auto& name = output_names[i];
const auto& dtype = output_dtypes[i];
kernel_source += " device ";
auto type_string = get_type_string(dtype);
if (atomic_outputs) {
kernel_source += "atomic<";
}
kernel_source += type_string;
if (atomic_outputs) {
kernel_source += ">";
}
kernel_source += "* ";
kernel_source += name;
kernel_source += " [[buffer(";
kernel_source += std::to_string(index);
kernel_source += ")]]";
if (index < inputs.size() + output_names.size() - 1 ||
attributes.size() > 0) {
kernel_source += ",\n";
} else {
kernel_source += ") {\n";
}
index++;
}
index = 0;
for (const auto& attr : attributes) {
kernel_source += attr;
if (index < attributes.size() - 1) {
kernel_source += ",\n";
} else {
kernel_source += ") {\n";
}
index++;
}
kernel_source += source;
kernel_source += "\n}\n";
return kernel_source;
}
std::string write_template(
const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
std::ostringstream template_def;
template_def << "<";
int i = 0;
for (const auto& [name, arg] : template_args) {
if (i > 0) {
template_def << ", ";
}
if (std::holds_alternative<int>(arg)) {
template_def << std::get<int>(arg);
} else if (std::holds_alternative<bool>(arg)) {
template_def << std::get<bool>(arg);
} else if (std::holds_alternative<Dtype>(arg)) {
template_def << get_type_string(std::get<Dtype>(arg));
}
i++;
}
template_def << ">";
return template_def.str();
}
MetalKernelFunction metal_kernel(
const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::string& source,
const std::string& header /* = "" */,
bool ensure_row_contiguous /* = true */,
bool atomic_outputs /* = false */) {
if (output_names.empty()) {
throw std::invalid_argument(
"[metal_kernel] Must specify at least one output.");
}
std::vector<CustomKernelShapeInfo> shape_infos;
for (auto& n : input_names) {
CustomKernelShapeInfo shape_info;
shape_info.shape = source.find(n + "_shape") != std::string::npos;
shape_info.strides = source.find(n + "_strides") != std::string::npos;
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
shape_infos.push_back(shape_info);
}
const std::vector<std::pair<std::string, std::string>> metal_attributes = {
{"dispatch_quadgroups_per_threadgroup", "uint"},
{"dispatch_simdgroups_per_threadgroup", "uint"},
{"dispatch_threads_per_threadgroup", "uint3"},
{"grid_origin", "uint3"},
{"grid_size", "uint3"},
{"quadgroup_index_in_threadgroup", "uint"},
{"quadgroups_per_threadgroup", "uint"},
{"simdgroup_index_in_threadgroup", "uint"},
{"simdgroups_per_threadgroup", "uint"},
{"thread_execution_width", "uint"},
{"thread_index_in_quadgroup", "uint"},
{"thread_index_in_simdgroup", "uint"},
{"thread_index_in_threadgroup", "uint"},
{"thread_position_in_grid", "uint3"},
{"thread_position_in_threadgroup", "uint3"},
{"threadgroup_position_in_grid", "uint3"},
{"threadgroups_per_grid", "uint3"},
{"threads_per_grid", "uint3"},
{"threads_per_simdgroup", "uint"},
{"threads_per_threadgroup", "uint3"},
};
std::vector<std::string> attributes;
for (const auto& [attr, dtype] : metal_attributes) {
if (source.find(attr) != std::string::npos) {
attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]");
}
}
return [=,
shape_infos = std::move(shape_infos),
attributes = std::move(attributes)](
const std::vector<array>& inputs,
const std::vector<Shape>& output_shapes,
const std::vector<Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
const std::vector<std::pair<std::string, TemplateArg>>&
template_args = {},
std::optional<float> init_value = std::nullopt,
bool verbose = false,
StreamOrDevice s_ = {}) {
if (inputs.size() != input_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `inputs` to have size "
<< input_names.size() << " but got size " << inputs.size() << "."
<< std::endl;
throw std::invalid_argument(msg.str());
}
if (output_shapes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `output_shapes` to have size "
<< output_names.size() << " but got size " << output_shapes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
if (output_dtypes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `output_dtypes` to have size "
<< output_names.size() << " but got size " << output_dtypes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
auto s = to_stream(s_);
if (s.device != Device::gpu) {
throw std::invalid_argument("[metal_kernel] Only supports the GPU.");
}
std::string kernel_name = "custom_kernel_" + name;
std::string template_def = "";
if (!template_args.empty()) {
std::regex disallowed_chars("\\<|\\>|(, )");
template_def = write_template(template_args);
auto template_hash =
std::regex_replace(template_def, disallowed_chars, "_");
template_hash.pop_back();
kernel_name += "_";
kernel_name += template_hash;
}
std::string kernel_source = write_signature(
kernel_name,
header,
source,
input_names,
inputs,
output_names,
output_dtypes,
template_args,
attributes,
shape_infos,
atomic_outputs);
if (!template_args.empty()) {
template_def = kernel_name + template_def;
kernel_source += "\ntemplate [[host_name(\"";
kernel_source += kernel_name;
kernel_source += "\")]] [[kernel]] decltype(";
kernel_source += template_def;
kernel_source += ") ";
kernel_source += template_def;
kernel_source += ";\n";
}
if (verbose) {
std::cout << "Generated source code for `" << name << "`:" << std::endl
<< "```" << std::endl
<< kernel_source << std::endl
<< "```" << std::endl;
}
return array::make_arrays(
std::move(output_shapes),
std::move(output_dtypes),
std::make_shared<CustomKernel>(
s,
std::move(kernel_name),
std::move(kernel_source),
grid,
threadgroup,
shape_infos,
ensure_row_contiguous,
init_value),
std::move(inputs));
};
}
void CustomKernel::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
@@ -39,9 +353,23 @@ void CustomKernel::eval_gpu(
}
auto& d = metal::device(s.device);
const auto& lib_name = name_;
auto lib =
d.get_library(lib_name, [this] { return metal::utils() + source_; });
{
// Clear kernels from the device library cache if needed
auto& kernel_cache = cache();
if (auto it = kernel_cache.libraries.find(name_);
it != kernel_cache.libraries.end()) {
if (it->second != source_) {
auto& d = metal::device(s.device);
d.clear_library(name_);
it->second = source_;
}
} else {
kernel_cache.libraries.emplace(name_, source_);
}
}
auto lib = d.get_library(name_, [this] { return metal::utils() + source_; });
auto kernel = d.get_kernel(name_, lib);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
@@ -73,6 +401,16 @@ void CustomKernel::eval_gpu(
}
const auto [tx, ty, tz] = threadgroup_;
auto tg_size = tx * ty * tz;
auto max_tg_size = kernel->maxTotalThreadsPerThreadgroup();
if (tg_size > max_tg_size) {
std::ostringstream msg;
msg << "Thread group size (" << tg_size << ") is greater than "
<< " the maximum allowed threads per threadgroup (" << max_tg_size
<< ").";
throw std::invalid_argument(msg.str());
}
const auto [gx, gy, gz] = grid_;
MTL::Size group_dims =
MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
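For context, a hedged C++ usage sketch of the fast::metal_kernel API whose signature and lambda parameters appear above. The kernel name, buffer names, and sizes are made up, and the nine call arguments follow the MetalKernelFunction lambda shown in the diff; treat this as a sketch, not the canonical way to use the API.

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // Illustrative kernel source; "my_exp", "inp" and "out" are invented names.
  auto kernel = fast::metal_kernel(
      "my_exp",
      /* input_names  = */ {"inp"},
      /* output_names = */ {"out"},
      /* source = */ R"(
        uint elem = thread_position_in_grid.x;
        out[elem] = metal::exp(inp[elem]);
      )");

  auto x = ones({1024});
  auto outs = kernel(
      /* inputs        = */ {x},
      /* output_shapes = */ {{1024}},
      /* output_dtypes = */ {float32},
      /* grid          = */ {1024, 1, 1},
      /* threadgroup   = */ {256, 1, 1},
      /* template_args = */ {},
      /* init_value    = */ std::nullopt,
      /* verbose       = */ false,
      /* stream        = */ {});
  eval(outs);
  return 0;
}

With the change above, rebuilding a kernel of the same name with different source now clears the stale library from the device cache instead of silently reusing the old pipeline.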

View File

@@ -295,7 +295,7 @@ void CommandEncoder::barrier() {
Device::Device() {
auto pool = new_scoped_memory_pool();
device_ = load_device();
library_map_ = {{"mlx", load_default_library(device_)}};
default_library_ = load_default_library(device_);
arch_ = std::string(device_->architecture()->name()->utf8String());
auto arch = arch_.back();
switch (arch) {
@@ -326,11 +326,11 @@ Device::Device() {
Device::~Device() {
auto pool = new_scoped_memory_pool();
for (auto& k : kernel_map_) {
k.second->release();
}
for (auto& l : library_map_) {
l.second->release();
for (auto& [l, kernel_map] : library_kernels_) {
l->release();
for (auto& [_, k] : kernel_map) {
k->release();
}
}
stream_map_.clear();
device_->release();
@@ -474,13 +474,24 @@ CommandEncoder& Device::get_command_encoder(int index) {
return *stream.encoder;
}
void Device::register_library(
const std::string& lib_name,
const std::string& lib_path) {
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
auto new_lib = load_library(device_, lib_name, lib_path.c_str());
library_map_.insert({lib_name, new_lib});
MTL::Library* Device::get_library(
const std::string& name,
const std::string& path /* = "" */) {
{
std::shared_lock rlock(library_mtx_);
if (auto it = library_map_.find(name); it != library_map_.end()) {
return it->second;
}
}
std::unique_lock wlock(library_mtx_);
if (auto it = library_map_.find(name); it != library_map_.end()) {
return it->second;
}
auto new_lib = load_library(device_, name, path.c_str());
library_map_.insert({name, new_lib});
return new_lib;
}
MTL::Library* Device::build_library_(const std::string& source_string) {
@@ -649,6 +660,19 @@ MTL::Library* Device::get_library(
return mtl_lib;
}
void Device::clear_library(const std::string& name) {
std::unique_lock wlock(library_mtx_);
if (auto it = library_map_.find(name); it != library_map_.end()) {
auto kernel_map_it = library_kernels_.find(it->second);
for (auto& [_, kernel] : kernel_map_it->second) {
kernel->release();
}
library_kernels_.erase(kernel_map_it);
it->second->release();
library_map_.erase(it);
}
}
MTL::LinkedFunctions* Device::get_linked_functions_(
const std::vector<MTL::Function*>& funcs) {
if (funcs.empty()) {
@@ -679,6 +703,7 @@ MTL::ComputePipelineState* Device::get_kernel_(
std::unique_lock wlock(kernel_mtx_);
// Try loading again to avoid loading twice
auto& kernel_map_ = library_kernels_[mtl_lib];
if (auto it = kernel_map_.find(hash_name); it != kernel_map_.end()) {
return it->second;
}
@@ -713,6 +738,7 @@ MTL::ComputePipelineState* Device::get_kernel(
std::shared_lock lock(kernel_mtx_);
// Look for cached kernel
auto& kernel_map_ = library_kernels_[mtl_lib];
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
}
@@ -722,23 +748,11 @@ MTL::ComputePipelineState* Device::get_kernel(
MTL::ComputePipelineState* Device::get_kernel(
const std::string& base_name,
const std::string& lib_name /* = "mlx" */,
const std::string& hash_name /* = "" */,
const MTLFCList& func_consts /* = {} */,
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
const auto& kname = hash_name.size() == 0 ? base_name : hash_name;
{
// Multiple readers allowed
std::shared_lock lock(kernel_mtx_);
// Look for cached kernel
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
}
}
// Search for cached metal lib
MTL::Library* mtl_lib = get_library_(lib_name);
return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
return get_kernel(
base_name, default_library_, hash_name, func_consts, linked_functions);
}
void Device::set_residency_set(const MTL::ResidencySet* residency_set) {
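The get_library path above uses the classic double-checked pattern on a std::shared_mutex: a shared lock serves the common cache-hit case, and only a miss upgrades to a unique lock, re-checking the map before doing the expensive load. A minimal sketch of that pattern, with an int standing in for MTL::Library*:

#include <shared_mutex>
#include <string>
#include <unordered_map>

class LibraryCache {
 public:
  int* get(const std::string& name) {
    {
      std::shared_lock rlock(mtx_); // many concurrent readers
      if (auto it = map_.find(name); it != map_.end()) {
        return &it->second;
      }
    }
    std::unique_lock wlock(mtx_); // single writer
    if (auto it = map_.find(name); it != map_.end()) {
      return &it->second; // another thread inserted while we waited
    }
    auto inserted = map_.emplace(name, load(name)); // expensive load happens once
    return &inserted.first->second;
  }

 private:
  int load(const std::string& name) { return static_cast<int>(name.size()); } // stand-in for load_library
  std::shared_mutex mtx_;
  std::unordered_map<std::string, int> map_;
};

int main() {
  LibraryCache cache;
  return *cache.get("mlx") == 3 ? 0 : 1;
}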

View File

@@ -95,6 +95,10 @@ struct CommandEncoder {
return enc_->setBytes(&v, sizeof(T), idx);
}
void set_threadgroup_memory_length(size_t length, int idx) {
enc_->setThreadgroupMemoryLength(length, idx);
}
ConcurrentContext start_concurrent() {
return ConcurrentContext(*this);
}
@@ -183,14 +187,16 @@ class Device {
CommandEncoder& get_command_encoder(int index);
void end_encoding(int index);
void register_library(
const std::string& lib_name,
const std::string& lib_path = "");
MTL::Library* get_library(
const std::string& name,
const std::string& path = "");
MTL::Library* get_library(
const std::string& name,
const std::function<std::string(void)>& builder);
void clear_library(const std::string& name);
MTL::ComputePipelineState* get_kernel(
const std::string& base_name,
MTL::Library* mtl_lib,
@@ -200,7 +206,6 @@ class Device {
MTL::ComputePipelineState* get_kernel(
const std::string& base_name,
const std::string& lib_name = "mlx",
const std::string& hash_name = "",
const MTLFCList& func_consts = {},
const std::vector<MTL::Function*>& linked_functions = {});
@@ -254,10 +259,13 @@ class Device {
std::unordered_map<int32_t, DeviceStream> stream_map_;
std::shared_mutex kernel_mtx_;
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
std::shared_mutex library_mtx_;
std::unordered_map<std::string, MTL::Library*> library_map_;
MTL::Library* default_library_;
std::unordered_map<
MTL::Library*,
std::unordered_map<std::string, MTL::ComputePipelineState*>>
library_kernels_;
const MTL::ResidencySet* residency_set_{nullptr};
std::string arch_;
int max_ops_per_buffer_;

View File

@@ -632,7 +632,7 @@ void fft_op(
func_consts.push_back(make_int(&rader_m, 3));
// The overall number of FFTs we're going to compute for this input
int size = out.dtype() == float32 ? out.size() : in.size();
size_t size = out.dtype() == float32 ? out.size() : in.size();
if (real && inverse && four_step_params.required) {
size = out.size();
}
@@ -659,8 +659,6 @@ void fft_op(
// We can perform 2 RFFTs at once so the batch size is halved.
batch_size = (batch_size + 2 - 1) / 2;
}
int out_buffer_size = out.size();
auto& compute_encoder = d.get_command_encoder(s.index);
auto in_type_str = in.dtype() == float32 ? "float" : "float2";
auto out_type_str = out.dtype() == float32 ? "float" : "float2";

View File

@@ -2,6 +2,7 @@
#include <fmt/format.h>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/jit/includes.h"
@@ -458,17 +459,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_output_array(out, 2);
// Set source info
auto shape = idx.shape();
shape.erase(shape.begin() + axis_);
compute_encoder.set_vector_bytes(shape, 3);
auto strides = src.strides();
strides.erase(strides.begin() + axis_);
compute_encoder.set_vector_bytes(strides, 4);
strides = idx.strides();
strides.erase(strides.begin() + axis_);
compute_encoder.set_vector_bytes(strides, 5);
compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);
compute_encoder.set_vector_bytes(remove_index(src.strides(), axis_), 4);
compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
compute_encoder.set_bytes(ndim - 1, 6);
compute_encoder.set_bytes(axis_, 7);
compute_encoder.set_bytes(src.shape(axis_), 8);
@@ -582,17 +575,9 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_output_array(out, 2);
// Set source info
auto shape = idx.shape();
shape.erase(shape.begin() + axis_);
compute_encoder.set_vector_bytes(shape, 3);
auto strides = upd.strides();
strides.erase(strides.begin() + axis_);
compute_encoder.set_vector_bytes(strides, 4);
strides = idx.strides();
strides.erase(strides.begin() + axis_);
compute_encoder.set_vector_bytes(strides, 5);
compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);
compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4);
compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
compute_encoder.set_bytes(ndim - 1, 6);
compute_encoder.set_bytes(axis_, 7);
compute_encoder.set_bytes(out.shape(axis_), 8);
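The repeated erase-then-upload sequences above collapse into a remove_index helper. A sketch of what that helper is assumed to do, returning a copy of a shape or strides vector with the reduced axis dropped:

#include <vector>

template <typename T>
std::vector<T> remove_index(std::vector<T> vec, int axis) {
  vec.erase(vec.begin() + axis); // drop the entry for the gathered/scattered axis
  return vec;
}

int main() {
  std::vector<int> shape{2, 5, 3};
  auto reduced = remove_index(shape, 1); // {2, 3}
  return reduced.size() == 2 ? 0 : 1;
}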

View File

@@ -41,7 +41,11 @@ MTL::ComputePipelineState* get_unary_kernel(
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::unary_ops(), metal::unary());
kernel_source +=
get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op);
get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op, 1);
if (get_work_per_thread(in_type) > 1) {
kernel_source +=
get_template_definition("vn_" + lib_name, "unary_v", in_t, out_t, op);
}
kernel_source +=
get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op);
kernel_source += get_template_definition(
@@ -59,11 +63,8 @@ void append_binary_kernels(
Dtype out_type,
const std::string op,
std::string& kernel_source) {
const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{
const std::array<std::pair<std::string, std::string>, 7> kernel_types = {{
{"ss", "binary_ss"},
{"vs", "binary_vs"},
{"sv", "binary_sv"},
{"vv", "binary_vv"},
{"vs2", "binary_vs2"},
{"sv2", "binary_sv2"},
{"vv2", "binary_vv2"},
@@ -78,6 +79,22 @@ void append_binary_kernels(
kernel_source +=
get_template_definition(name + "_" + lib_name, func, in_t, out_t, op);
}
kernel_source += get_template_definition(
"vs_" + lib_name, "binary_vs", in_t, out_t, op, 1);
kernel_source += get_template_definition(
"sv_" + lib_name, "binary_sv", in_t, out_t, op, 1);
kernel_source += get_template_definition(
"vv_" + lib_name, "binary_vv", in_t, out_t, op, 1);
if (get_work_per_thread(in_type) > 1) {
kernel_source += get_template_definition(
"vsn_" + lib_name, "binary_vs", in_t, out_t, op);
kernel_source += get_template_definition(
"svn_" + lib_name, "binary_sv", in_t, out_t, op);
kernel_source += get_template_definition(
"vvn_" + lib_name, "binary_vv", in_t, out_t, op);
}
kernel_source += get_template_definition(
"g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int");
kernel_source += get_template_definition(
@@ -133,8 +150,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
auto t_str = get_type_string(type);
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
{"v", "ternary_v"},
const std::array<std::pair<std::string, std::string>, 4> kernel_types = {{
{"v2", "ternary_v2"},
{"g1large", "ternary_g_nd1"},
{"g2large", "ternary_g_nd2"},
@@ -144,6 +160,13 @@ MTL::ComputePipelineState* get_ternary_kernel(
kernel_source +=
get_template_definition(name + "_" + lib_name, func, t_str, op);
}
if (get_work_per_thread(type) > 1) {
kernel_source +=
get_template_definition("vn_" + lib_name, "ternary_v", t_str, op);
}
kernel_source +=
get_template_definition("v_" + lib_name, "ternary_v", t_str, op, 1);
kernel_source += get_template_definition(
"g1_" + lib_name, "ternary_g_nd1", t_str, op, "int");
kernel_source += get_template_definition(
@@ -170,15 +193,22 @@ MTL::ComputePipelineState* get_copy_kernel(
kernel_source += metal::copy();
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
kernel_source +=
get_template_definition("s_" + lib_name, "copy_s", in_type, out_type);
kernel_source += get_template_definition(
"s_" + lib_name, "copy_s", in_type, out_type, 1);
kernel_source +=
get_template_definition("s2_" + lib_name, "copy_s2", in_type, out_type);
kernel_source +=
get_template_definition("v_" + lib_name, "copy_v", in_type, out_type);
kernel_source += get_template_definition(
"v_" + lib_name, "copy_v", in_type, out_type, 1);
kernel_source +=
get_template_definition("v2_" + lib_name, "copy_v2", in_type, out_type);
if (get_work_per_thread(out.dtype()) > 1) {
kernel_source += get_template_definition(
"sn_" + lib_name, "copy_s", in_type, out_type);
kernel_source += get_template_definition(
"vn_" + lib_name, "copy_v", in_type, out_type);
}
kernel_source += get_template_definition(
"g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int");
kernel_source += get_template_definition(

View File

@@ -80,9 +80,10 @@ template <typename T, typename Op, int N_READS = 4>
const constant size_t& ndim [[buffer(5)]],
const constant int64_t& axis_stride [[buffer(6)]],
const constant size_t& axis_size [[buffer(7)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint3 gid [[thread_position_in_grid]],
uint3 gsize [[threads_per_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
@@ -104,17 +105,18 @@ template <typename T, typename Op, int N_READS = 4>
// Compute the input/output index. There is one beginning and one output for
// the whole threadgroup.
auto in_idx = elem_to_loc(gid / lsize, shape, in_strides, ndim);
auto out_idx = elem_to_loc(gid / lsize, shape, out_strides, ndim);
int64_t row_idx = gid.y + static_cast<int64_t>(gsize.y) * gid.z;
auto in_idx = elem_to_loc(row_idx, shape, in_strides, ndim);
auto out_idx = elem_to_loc(row_idx, shape, out_strides, ndim);
IndexValPair<T> best{0, Op::init};
threadgroup IndexValPair<T> local_data[32];
// Loop over the reduction axis in lsize*N_READS buckets
for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize.x); r++) {
// Read the current value
uint32_t current_index = r * lsize * N_READS + lid * N_READS;
uint32_t current_index = r * lsize.x * N_READS + lid.x * N_READS;
uint32_t offset = current_index;
const device T* current_in = in + in_idx + current_index * axis_stride;
T vals[N_READS];
@@ -144,7 +146,7 @@ template <typename T, typename Op, int N_READS = 4>
}
// Read the appropriate value from local data and perform one simd reduction
uint simd_groups = ceildiv(lsize, simd_size);
uint simd_groups = ceildiv(lsize.x, simd_size);
if (simd_lane_id < simd_groups) {
best = local_data[simd_lane_id];
}
@@ -154,7 +156,7 @@ template <typename T, typename Op, int N_READS = 4>
}
// Finally write the output
if (lid == 0) {
if (lid.x == 0) {
out[out_idx] = best.index;
}
}
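The switch from a scalar gid to uint3 gid/gsize lets the reduction handle more rows than fit in a single grid dimension: rows are laid out over the y and z grid axes and recombined into a 64-bit linear index. A tiny sketch of that index math:

#include <cstdint>
#include <iostream>

int64_t row_index(uint32_t gid_y, uint32_t gid_z, uint32_t gsize_y) {
  // Same expression as in the kernel: gid.y + int64_t(gsize.y) * gid.z.
  return gid_y + static_cast<int64_t>(gsize_y) * gid_z;
}

int main() {
  // e.g. a grid y-extent of 65535 rows per z "slice"
  std::cout << row_index(10, 3, 65535) << "\n"; // 196615
}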

View File

@@ -17,8 +17,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
c[index + i] = Op()(a[0], b[index + i]);
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
c[index + i] = Op()(a[0], b[index + i]);
}
} else {
for (int i = 0; i < N; ++i) {
c[index + i] = Op()(a[0], b[index + i]);
}
}
}
@@ -30,8 +36,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
c[index + i] = Op()(a[index + i], b[0]);
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
c[index + i] = Op()(a[index + i], b[0]);
}
} else {
for (int i = 0; i < N; ++i) {
c[index + i] = Op()(a[index + i], b[0]);
}
}
}
@@ -43,8 +55,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
c[index + i] = Op()(a[index + i], b[index + i]);
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
c[index + i] = Op()(a[index + i], b[index + i]);
}
} else {
for (int i = 0; i < N; ++i) {
c[index + i] = Op()(a[index + i], b[index + i]);
}
}
}
@@ -57,8 +75,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
c[offset + i] = Op()(a[0], b[offset + i]);
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
c[offset + i] = Op()(a[0], b[offset + i]);
}
} else {
for (int i = 0; i < N; ++i) {
c[offset + i] = Op()(a[0], b[offset + i]);
}
}
}
@@ -71,8 +95,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
c[offset + i] = Op()(a[offset + i], b[0]);
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
c[offset + i] = Op()(a[offset + i], b[0]);
}
} else {
for (int i = 0; i < N; ++i) {
c[offset + i] = Op()(a[offset + i], b[0]);
}
}
}
@@ -85,8 +115,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
c[offset + i] = Op()(a[offset + i], b[offset + i]);
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
c[offset + i] = Op()(a[offset + i], b[offset + i]);
}
} else {
for (int i = 0; i < N; ++i) {
c[offset + i] = Op()(a[offset + i], b[offset + i]);
}
}
}
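The restructured loops above hoist the bounds check out of the hot path: only the final, partial vector of N elements takes the guarded loop, while full vectors run an unrolled, branch-free loop. A CPU sketch of the same pattern:

#include <vector>

template <int N>
void add_vv(const float* a, const float* b, float* c, unsigned size, unsigned index) {
  index *= N;
  if (N > 1 && index + N > size) {
    for (unsigned i = 0; index + i < size; ++i) { // ragged tail, per-element check
      c[index + i] = a[index + i] + b[index + i];
    }
  } else {
    for (int i = 0; i < N; ++i) { // full vector, no bounds check
      c[index + i] = a[index + i] + b[index + i];
    }
  }
}

int main() {
  std::vector<float> a(10, 1.f), b(10, 2.f), c(10);
  for (unsigned t = 0; t < 3; ++t) { // three "threads" with N = 4; the last is partial
    add_vv<4>(a.data(), b.data(), c.data(), 10, t);
  }
}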

View File

@@ -9,11 +9,16 @@
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary.h"
#define instantiate_binary_all(op, tname, itype, otype) \
#define instantiate_binary_work_per_thread(op, tname, itype, otype) \
instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op) \
#define instantiate_binary_base(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
@@ -26,15 +31,19 @@
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
#define instantiate_binary_integer(op) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
instantiate_binary_all(op, uint64, uint64_t, uint64_t) \
instantiate_binary_all(op, int8, int8_t, int8_t) \
instantiate_binary_all(op, int16, int16_t, int16_t) \
instantiate_binary_all(op, int32, int32_t, int32_t) \
instantiate_binary_all(op, int64, int64_t, int64_t)
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_binary_base(op, tname, itype, otype) \
instantiate_binary_work_per_thread(op, tname, itype, otype)
#define instantiate_binary_integer(op) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
instantiate_binary_base(op, uint64, uint64_t, uint64_t) \
instantiate_binary_all(op, int8, int8_t, int8_t) \
instantiate_binary_all(op, int16, int16_t, int16_t) \
instantiate_binary_all(op, int32, int32_t, int32_t) \
instantiate_binary_base(op, int64, int64_t, int64_t)
#define instantiate_binary_float(op) \
instantiate_binary_all(op, float16, half, half) \
@@ -44,7 +53,7 @@
#define instantiate_binary_types(op) \
instantiate_binary_all(op, bool_, bool, bool) \
instantiate_binary_integer(op) \
instantiate_binary_all(op, complex64, complex64_t, complex64_t) \
instantiate_binary_base(op, complex64, complex64_t, complex64_t)\
instantiate_binary_float(op)
#define instantiate_binary_types_bool(op) \
@@ -52,15 +61,15 @@
instantiate_binary_all(op, uint8, uint8_t, bool) \
instantiate_binary_all(op, uint16, uint16_t, bool) \
instantiate_binary_all(op, uint32, uint32_t, bool) \
instantiate_binary_all(op, uint64, uint64_t, bool) \
instantiate_binary_base(op, uint64, uint64_t, bool) \
instantiate_binary_all(op, int8, int8_t, bool) \
instantiate_binary_all(op, int16, int16_t, bool) \
instantiate_binary_all(op, int32, int32_t, bool) \
instantiate_binary_all(op, int64, int64_t, bool) \
instantiate_binary_base(op, int64, int64_t, bool) \
instantiate_binary_all(op, float16, half, bool) \
instantiate_binary_all(op, float32, float, bool) \
instantiate_binary_all(op, bfloat16, bfloat16_t, bool) \
instantiate_binary_all(op, complex64, complex64_t, bool)
instantiate_binary_base(op, complex64, complex64_t, bool)
instantiate_binary_types(Add)
instantiate_binary_types(Divide)
@@ -71,7 +80,7 @@ instantiate_binary_types_bool(Less)
instantiate_binary_types_bool(LessEqual)
instantiate_binary_types_bool(NotEqual)
instantiate_binary_float(LogAddExp)
instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t)
instantiate_binary_base(LogAddExp, complex64, complex64_t, complex64_t)
instantiate_binary_types(Maximum)
instantiate_binary_types(Minimum)
instantiate_binary_types(Multiply)
@@ -84,7 +93,7 @@ instantiate_binary_float(ArcTan2)
instantiate_binary_all(NaNEqual, float16, half, bool)
instantiate_binary_all(NaNEqual, float32, float, bool)
instantiate_binary_all(NaNEqual, bfloat16, bfloat16_t, bool)
instantiate_binary_all(NaNEqual, complex64, complex64_t, bool)
instantiate_binary_base(NaNEqual, complex64, complex64_t, bool)
instantiate_binary_all(LogicalOr, bool_, bool, bool)
instantiate_binary_all(LogicalAnd, bool_, bool, bool)
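Note: the instantiation macros are split so that `instantiate_binary_all` composes a base set (exact-size `ss/sv/vs/vv` plus the strided/general kernels) with a work-per-thread set, and 8-byte element types (`uint64`, `int64`, `complex64`) call only `instantiate_binary_base`. A toy C++ sketch of that macro composition, with made-up tag structs standing in for `instantiate_kernel`:

```cpp
// "Instantiating" a kernel here just declares a uniquely named tag struct.
#define INSTANTIATE_KERNEL(name) struct kernel_##name {};

#define INSTANTIATE_BINARY_BASE(op, tname) \
  INSTANTIATE_KERNEL(ss_##op##_##tname)    \
  INSTANTIATE_KERNEL(sv_##op##_##tname)    \
  INSTANTIATE_KERNEL(vs_##op##_##tname)    \
  INSTANTIATE_KERNEL(vv_##op##_##tname)

#define INSTANTIATE_BINARY_WORK_PER_THREAD(op, tname) \
  INSTANTIATE_KERNEL(svn_##op##_##tname)              \
  INSTANTIATE_KERNEL(vsn_##op##_##tname)              \
  INSTANTIATE_KERNEL(vvn_##op##_##tname)

// "all" = base + vectorized work-per-thread variants.
#define INSTANTIATE_BINARY_ALL(op, tname) \
  INSTANTIATE_BINARY_BASE(op, tname)      \
  INSTANTIATE_BINARY_WORK_PER_THREAD(op, tname)

INSTANTIATE_BINARY_ALL(add, float32)  // narrow type: all seven variants
INSTANTIATE_BINARY_BASE(add, int64)   // wide type: base variants only
```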


@@ -21,10 +21,18 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
auto out = Op()(a[0], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
auto out = Op()(a[0], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
}
} else {
for (int i = 0; i < N; ++i) {
auto out = Op()(a[0], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
}
}
}
@@ -37,10 +45,18 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
auto out = Op()(a[index + i], b[0]);
c[index + i] = out[0];
d[index + i] = out[1];
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
auto out = Op()(a[index + i], b[0]);
c[index + i] = out[0];
d[index + i] = out[1];
}
} else {
for (int i = 0; i < N; ++i) {
auto out = Op()(a[index + i], b[0]);
c[index + i] = out[0];
d[index + i] = out[1];
}
}
}
@@ -53,10 +69,18 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
auto out = Op()(a[index + i], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
auto out = Op()(a[index + i], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
}
} else {
for (int i = 0; i < N; ++i) {
auto out = Op()(a[index + i], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
}
}
}
@@ -69,11 +93,19 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
auto out = Op()(a[0], b[offset + i]);
c[offset + i] = out[0];
d[offset + i] = out[1];
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
auto out = Op()(a[0], b[offset + i]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
} else {
for (int i = 0; i < N; ++i) {
auto out = Op()(a[0], b[offset + i]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
}
}
@@ -86,11 +118,19 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
auto out = Op()(a[offset + i], b[0]);
c[offset + i] = out[0];
d[offset + i] = out[1];
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
auto out = Op()(a[offset + i], b[0]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
} else {
for (int i = 0; i < N; ++i) {
auto out = Op()(a[offset + i], b[0]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
}
}
@@ -103,11 +143,19 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
auto out = Op()(a[offset + i], b[offset + i]);
c[offset + i] = out[0];
d[offset + i] = out[1];
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
auto out = Op()(a[offset + i], b[offset + i]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
} else {
for (int i = 0; i < N; ++i) {
auto out = Op()(a[offset + i], b[offset + i]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
}
}
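Note: these `binary_two` kernels are the dual-output counterparts (e.g. DivMod): the op returns a pair and the kernel writes both `c` and `d`, with the same fast-path/tail split as above. A minimal C++ sketch; the `std::array` return type and `DivModRef` are assumptions for illustration, not the MLX definitions.

```cpp
#include <array>
#include <cstdint>

// Illustrative two-output op: integer quotient and remainder in one pass.
struct DivModRef {
  template <typename T>
  std::array<T, 2> operator()(T x, T y) const { return {T(x / y), T(x % y)}; }
};

// Reference body of a two-output vector-vector kernel, N elements per thread.
template <typename T, typename Op, int N = 4>
void binary_vv2_ref(const T* a, const T* b, T* c, T* d, uint32_t size, uint32_t thread_id) {
  uint32_t index = thread_id * N;
  auto emit = [&](uint32_t j) {
    auto out = Op()(a[j], b[j]);
    c[j] = out[0];  // first output, e.g. quotient
    d[j] = out[1];  // second output, e.g. remainder
  };
  if (N > 1 && index + N > size) {
    for (uint32_t i = 0; index + i < size; ++i) emit(index + i);  // ragged tail
  } else {
    for (int i = 0; i < N; ++i) emit(index + i);  // full block
  }
}
```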


@@ -7,11 +7,16 @@
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary_two.h"
#define instantiate_binary_all(op, tname, itype, otype) \
#define instantiate_binary_work_per_thread(op, tname, itype, otype) \
instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op)
#define instantiate_binary_base(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
@@ -24,22 +29,26 @@
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_binary_base(op, tname, itype, otype) \
instantiate_binary_work_per_thread(op, tname, itype, otype)
#define instantiate_binary_float(op) \
instantiate_binary_all(op, float16, half, half) \
instantiate_binary_all(op, float32, float, float) \
instantiate_binary_all(op, bfloat16, bfloat16_t, bfloat16_t)
#define instantiate_binary_types(op) \
instantiate_binary_all(op, bool_, bool, bool) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
instantiate_binary_all(op, uint64, uint64_t, uint64_t) \
instantiate_binary_all(op, int8, int8_t, int8_t) \
instantiate_binary_all(op, int16, int16_t, int16_t) \
instantiate_binary_all(op, int32, int32_t, int32_t) \
instantiate_binary_all(op, int64, int64_t, int64_t) \
instantiate_binary_all(op, complex64, complex64_t, complex64_t) \
#define instantiate_binary_types(op) \
instantiate_binary_all(op, bool_, bool, bool) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
instantiate_binary_base(op, uint64, uint64_t, uint64_t) \
instantiate_binary_all(op, int8, int8_t, int8_t) \
instantiate_binary_all(op, int16, int16_t, int16_t) \
instantiate_binary_all(op, int32, int32_t, int32_t) \
instantiate_binary_base(op, int64, int64_t, int64_t) \
instantiate_binary_base(op, complex64, complex64_t, complex64_t) \
instantiate_binary_float(op)
instantiate_binary_types(DivMod) // clang-format on


@@ -1,52 +1,76 @@
// Copyright © 2024 Apple Inc.
template <typename T, typename U, int N = WorkPerThread<T>::n>
template <typename T, typename U, int N = WorkPerThread<U>::n>
[[kernel]] void copy_s(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
dst[index + i] = static_cast<U>(src[0]);
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
dst[index + i] = static_cast<U>(src[0]);
}
} else {
for (int i = 0; i < N; ++i) {
dst[index + i] = static_cast<U>(src[0]);
}
}
}
template <typename T, typename U, int N = WorkPerThread<T>::n>
template <typename T, typename U, int N = WorkPerThread<U>::n>
[[kernel]] void copy_v(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
dst[index + i] = static_cast<U>(src[index + i]);
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
dst[index + i] = static_cast<U>(src[index + i]);
}
} else {
for (int i = 0; i < N; ++i) {
dst[index + i] = static_cast<U>(src[index + i]);
}
}
}
template <typename T, typename U, int N = WorkPerThread<T>::n>
template <typename T, typename U, int N = WorkPerThread<U>::n>
[[kernel]] void copy_s2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
dst[offset + i] = static_cast<U>(src[0]);
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
dst[offset + i] = static_cast<U>(src[0]);
}
} else {
for (int i = 0; i < N; ++i) {
dst[offset + i] = static_cast<U>(src[0]);
}
}
}
template <typename T, typename U, int N = WorkPerThread<T>::n>
template <typename T, typename U, int N = WorkPerThread<U>::n>
[[kernel]] void copy_v2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
dst[offset + i] = static_cast<U>(src[offset + i]);
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
dst[offset + i] = static_cast<U>(src[offset + i]);
}
} else {
for (int i = 0; i < N; ++i) {
dst[offset + i] = static_cast<U>(src[offset + i]);
}
}
}


@@ -4,9 +4,13 @@
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/copy.h"
#define instantiate_copy_all(tname, itype, otype) \
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
#define instantiate_copy_work_per_thread(tname, itype, otype) \
instantiate_kernel("sn_copy" #tname, copy_s, itype, otype) \
instantiate_kernel("vn_copy" #tname, copy_v, itype, otype)
#define instantiate_copy_base(tname, itype, otype) \
instantiate_kernel("s_copy" #tname, copy_s, itype, otype, 1) \
instantiate_kernel("v_copy" #tname, copy_v, itype, otype, 1) \
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \
@@ -18,6 +22,10 @@
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4)
#define instantiate_copy_all(tname, itype, otype) \
instantiate_copy_base(tname, itype, otype) \
instantiate_copy_work_per_thread(tname, itype, otype)
#define instantiate_copy_same(tname, type) \
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int) \
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, type, type, int) \
@@ -42,15 +50,15 @@
instantiate_copy_all(itname ##uint8, itype, uint8_t) \
instantiate_copy_all(itname ##uint16, itype, uint16_t) \
instantiate_copy_all(itname ##uint32, itype, uint32_t) \
instantiate_copy_all(itname ##uint64, itype, uint64_t) \
instantiate_copy_base(itname ##uint64, itype, uint64_t) \
instantiate_copy_all(itname ##int8, itype, int8_t) \
instantiate_copy_all(itname ##int16, itype, int16_t) \
instantiate_copy_all(itname ##int32, itype, int32_t) \
instantiate_copy_all(itname ##int64, itype, int64_t) \
instantiate_copy_base(itname ##int64, itype, int64_t) \
instantiate_copy_all(itname ##float16, itype, half) \
instantiate_copy_all(itname ##float32, itype, float) \
instantiate_copy_all(itname ##bfloat16, itype, bfloat16_t) \
instantiate_copy_all(itname ##complex64, itype, complex64_t)
instantiate_copy_base(itname ##complex64, itype, complex64_t)
instantiate_copy_itype(bool_, bool)
instantiate_copy_itype(uint8, uint8_t)


@@ -98,7 +98,7 @@ struct ReadWriter {
}
METAL_FUNC void load() const {
int batch_idx = elem.x * grid.y * n;
size_t batch_idx = size_t(elem.x * grid.y) * n;
short tg_idx = elem.y * grid.z + elem.z;
short max_index = grid.y * n - 2;
@@ -121,7 +121,7 @@ struct ReadWriter {
}
METAL_FUNC void write() const {
int batch_idx = elem.x * grid.y * n;
size_t batch_idx = size_t(elem.x * grid.y) * n;
short tg_idx = elem.y * grid.z + elem.z;
short max_index = grid.y * n - 2;
@@ -144,7 +144,7 @@ struct ReadWriter {
// Padded IO for Bluestein's algorithm
METAL_FUNC void load_padded(int length, const device float2* w_k) const {
int batch_idx = elem.x * grid.y * length + elem.y * length;
size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length;
int fft_idx = elem.z;
int m = grid.z;
@@ -161,7 +161,7 @@ struct ReadWriter {
}
METAL_FUNC void write_padded(int length, const device float2* w_k) const {
int batch_idx = elem.x * grid.y * length + elem.y * length;
size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length;
int fft_idx = elem.z;
int m = grid.z;
float2 inv_factor = {1.0f / n, -1.0f / n};
@@ -261,7 +261,7 @@ METAL_FUNC bool ReadWriter<float, float2>::out_of_bounds() const {
template <>
METAL_FUNC void ReadWriter<float, float2>::load() const {
int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2;
size_t batch_idx = size_t(elem.x * grid.y) * n * 2 + elem.y * n * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes
@@ -283,7 +283,8 @@ template <>
METAL_FUNC void ReadWriter<float, float2>::write() const {
short n_over_2 = (n / 2) + 1;
int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
size_t batch_idx =
size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
int grid_index = elem.x * grid.y + elem.y;
@@ -317,7 +318,7 @@ template <>
METAL_FUNC void ReadWriter<float, float2>::load_padded(
int length,
const device float2* w_k) const {
int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes
@@ -345,8 +346,8 @@ METAL_FUNC void ReadWriter<float, float2>::write_padded(
int length,
const device float2* w_k) const {
int length_over_2 = (length / 2) + 1;
int batch_idx =
elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
size_t batch_idx =
size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
int grid_index = elem.x * grid.y + elem.y;
@@ -397,7 +398,8 @@ METAL_FUNC bool ReadWriter<float2, float>::out_of_bounds() const {
template <>
METAL_FUNC void ReadWriter<float2, float>::load() const {
short n_over_2 = (n / 2) + 1;
int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
size_t batch_idx =
size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes
@@ -458,8 +460,8 @@ METAL_FUNC void ReadWriter<float2, float>::load_padded(
int n_over_2 = (n / 2) + 1;
int length_over_2 = (length / 2) + 1;
int batch_idx =
elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
size_t batch_idx =
size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes
@@ -503,7 +505,7 @@ template <>
METAL_FUNC void ReadWriter<float2, float>::write_padded(
int length,
const device float2* w_k) const {
int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2;
threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
int grid_index = elem.x * grid.y + elem.y;
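Note: the FFT read/write helpers widen `batch_idx` from `int` to `size_t` because products like `elem.x * grid.y * n * 2` can exceed 2^31 for large batches; the fix promotes to 64 bits before multiplying by the sequence length. A small C++ sketch of the wrap-around; the sizes are made up, chosen only to exceed 32 bits.

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical launch: a very large batch of length-4096 sequences.
  uint32_t elem_x = 700000;  // batch element index
  uint32_t grid_y = 4;       // sequences handled per threadgroup row
  uint32_t n = 4096;         // FFT length

  // 700000 * 4 * 4096 = 11,468,800,000, which does not fit in 32 bits.
  uint32_t wrapped = elem_x * grid_y * n;      // wraps modulo 2^32
  size_t exact = size_t(elem_x * grid_y) * n;  // promote, then scale

  std::printf("32-bit: %u\n64-bit: %zu\n", wrapped, exact);  // differ
}
```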


@@ -9,7 +9,41 @@ using namespace metal;
constant bool has_w [[function_constant(20)]];
template <typename T, int N_READS = RMS_N_READS>
template <int N = 1>
inline void initialize_buffer(
threadgroup float* xs,
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
if (simd_group_id == 0) {
for (int i = 0; i < N; i++) {
xs[N * simd_lane_id + i] = 0;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
template <int N = 1>
inline void threadgroup_sum(
thread float* x,
threadgroup float* xs,
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
for (int i = 0; i < N; i++) {
x[i] = simd_sum(x[i]);
}
if (simd_lane_id == 0) {
for (int i = 0; i < N; i++) {
xs[N * simd_group_id + i] = x[i];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int i = 0; i < N; i++) {
x[i] = xs[N * simd_lane_id + i];
x[i] = simd_sum(x[i]);
}
}
template <typename T, int N_READS = 8>
[[kernel]] void layer_norm_single_row(
const device T* x,
const device T* w,
@@ -23,90 +57,71 @@ template <typename T, int N_READS = RMS_N_READS>
uint lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
float sumx = 0;
float sumx2 = 0;
float thread_x[N_READS];
constexpr int SIMD_SIZE = 32;
threadgroup float local_sumx[SIMD_SIZE];
threadgroup float local_sumx2[SIMD_SIZE];
threadgroup float local_mean[1];
threadgroup float local_normalizer[1];
// Initialize the registers and threadgroup memory
float thread_x[N_READS] = {0};
threadgroup float local_buffer[SIMD_SIZE] = {0};
initialize_buffer(local_buffer, simd_lane_id, simd_group_id);
// Advance the pointers
x += gid * size_t(axis_size) + lid * N_READS;
w += w_stride * lid * N_READS;
b += b_stride * lid * N_READS;
out += gid * size_t(axis_size) + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
// Compute some variables for reading writing etc
const bool safe = lid * N_READS + N_READS <= axis_size;
const int n = axis_size - lid * N_READS;
// Read the inputs
if (safe) {
for (int i = 0; i < N_READS; i++) {
thread_x[i] = x[i];
sumx2 += thread_x[i] * thread_x[i];
sumx += thread_x[i];
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
thread_x[i] = x[i];
sumx2 += thread_x[i] * thread_x[i];
sumx += thread_x[i];
}
for (int i = 0; i < n; i++) {
thread_x[i] = x[i];
}
}
sumx = simd_sum(sumx);
sumx2 = simd_sum(sumx2);
// Initialize shared memory
if (simd_group_id == 0) {
local_sumx[simd_lane_id] = 0;
local_sumx2[simd_lane_id] = 0;
// Compute the mean
float mean = 0;
for (int i = 0; i < N_READS; i++) {
mean += thread_x[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id);
mean /= axis_size;
// Write simd accumulations into shared memory
if (simd_lane_id == 0) {
local_sumx[simd_group_id] = sumx;
local_sumx2[simd_group_id] = sumx2;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Accumulate over simd groups
if (simd_group_id == 0) {
sumx = simd_sum(local_sumx[simd_lane_id]);
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
if (simd_lane_id == 0) {
float mean = sumx / axis_size;
float variance = sumx2 / axis_size - mean * mean;
local_mean[0] = mean;
local_normalizer[0] = metal::precise::rsqrt(variance + eps);
// Compute the normalizer
float normalizer = 0;
if (!safe) {
for (int i = n; i < N_READS; i++) {
thread_x[i] = mean;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float mean = local_mean[0];
float normalizer = local_normalizer[0];
for (int i = 0; i < N_READS; i++) {
thread_x[i] -= mean;
normalizer += thread_x[i] * thread_x[i];
}
threadgroup_sum(&normalizer, local_buffer, simd_lane_id, simd_group_id);
normalizer = metal::precise::rsqrt(normalizer / axis_size + eps);
// Write the outputs
out += gid * size_t(axis_size) + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
if (safe) {
for (int i = 0; i < N_READS; i++) {
thread_x[i] = (thread_x[i] - mean) * normalizer;
thread_x[i] *= normalizer;
out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
thread_x[i] = (thread_x[i] - mean) * normalizer;
out[i] =
w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
}
for (int i = 0; i < n; i++) {
thread_x[i] *= normalizer;
out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
}
}
}
template <typename T, int N_READS = RMS_N_READS>
template <typename T, int N_READS = 4>
[[kernel]] void layer_norm_looped(
const device T* x,
const device T* w,
@@ -121,71 +136,52 @@ template <typename T, int N_READS = RMS_N_READS>
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
float sumx = 0;
float sumx2 = 0;
constexpr int SIMD_SIZE = 32;
threadgroup float local_sumx[SIMD_SIZE];
threadgroup float local_sumx2[SIMD_SIZE];
threadgroup float local_mean[1];
threadgroup float local_normalizer[1];
threadgroup float local_buffer[SIMD_SIZE];
initialize_buffer(local_buffer, simd_lane_id, simd_group_id);
x += gid * size_t(axis_size) + lid * N_READS;
w += w_stride * lid * N_READS;
b += b_stride * lid * N_READS;
// Compute the mean
float mean = 0;
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
float xi = x[i + r];
sumx2 += xi * xi;
sumx += xi;
mean += x[i + r];
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((r + lid * N_READS + i) < axis_size) {
float xi = x[i + r];
sumx2 += xi * xi;
sumx += xi;
mean += x[i + r];
}
}
}
}
threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id);
mean /= axis_size;
sumx = simd_sum(sumx);
sumx2 = simd_sum(sumx2);
// Initialize shared memory
if (simd_group_id == 0) {
local_sumx[simd_lane_id] = 0;
local_sumx2[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write simd accumulations into shared memory
if (simd_lane_id == 0) {
local_sumx[simd_group_id] = sumx;
local_sumx2[simd_group_id] = sumx2;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Accumulate over simd groups
if (simd_group_id == 0) {
sumx = simd_sum(local_sumx[simd_lane_id]);
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
if (simd_lane_id == 0) {
float mean = sumx / axis_size;
float variance = sumx2 / axis_size - mean * mean;
local_mean[0] = mean;
local_normalizer[0] = metal::precise::rsqrt(variance + eps);
// Compute the normalizer
float normalizer = 0;
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
float t = x[i + r] - mean;
normalizer += t * t;
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((r + lid * N_READS + i) < axis_size) {
float t = x[i + r] - mean;
normalizer += t * t;
}
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float mean = local_mean[0];
float normalizer = local_normalizer[0];
threadgroup_sum(&normalizer, local_buffer, simd_lane_id, simd_group_id);
normalizer = metal::precise::rsqrt(normalizer / axis_size + eps);
// Write the outputs
out += gid * size_t(axis_size) + lid * N_READS;
@@ -208,7 +204,7 @@ template <typename T, int N_READS = RMS_N_READS>
}
}
template <typename T, int N_READS = RMS_N_READS>
template <typename T, int N_READS = 8>
[[kernel]] void vjp_layer_norm_single_row(
const device T* x,
const device T* w,
@@ -222,133 +218,96 @@ template <typename T, int N_READS = RMS_N_READS>
uint lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
constexpr int SIMD_SIZE = 32;
// Advance the input pointers
x += gid * size_t(axis_size) + lid * N_READS;
g += gid * size_t(axis_size) + lid * N_READS;
w += w_stride * lid * N_READS;
// Allocate registers for the computation and accumulators
float thread_x[N_READS];
float thread_w[N_READS];
float thread_g[N_READS];
float sumx = 0;
float sumx2 = 0;
float sumwg = 0;
float sumwgx = 0;
// Initialize the registers and threadgroup memory
float thread_x[N_READS] = {0};
float thread_w[N_READS] = {0};
float thread_g[N_READS] = {0};
threadgroup float local_buffer[3 * SIMD_SIZE];
initialize_buffer<3>(local_buffer, simd_lane_id, simd_group_id);
constexpr int SIMD_SIZE = 32;
// Compute some variables for reading writing etc
const bool safe = lid * N_READS + N_READS <= axis_size;
const int n = axis_size - lid * N_READS;
threadgroup float local_sumx[SIMD_SIZE];
threadgroup float local_sumx2[SIMD_SIZE];
threadgroup float local_sumwg[SIMD_SIZE];
threadgroup float local_sumwgx[SIMD_SIZE];
threadgroup float local_mean[1];
threadgroup float local_normalizer[1];
threadgroup float local_meanwg[1];
threadgroup float local_meanwgx[1];
if (lid * N_READS + N_READS <= axis_size) {
// Read the inputs
if (safe) {
for (int i = 0; i < N_READS; i++) {
thread_x[i] = x[i];
thread_w[i] = w[i * w_stride];
thread_g[i] = g[i];
float wg = thread_w[i] * thread_g[i];
sumx += thread_x[i];
sumx2 += thread_x[i] * thread_x[i];
sumwg += wg;
sumwgx += wg * thread_x[i];
thread_w[i] = w[i * w_stride];
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
thread_x[i] = x[i];
thread_w[i] = w[i * w_stride];
thread_g[i] = g[i];
float wg = thread_w[i] * thread_g[i];
sumx += thread_x[i];
sumx2 += thread_x[i] * thread_x[i];
sumwg += wg;
sumwgx += wg * thread_x[i];
}
for (int i = 0; i < n; i++) {
thread_x[i] = x[i];
thread_g[i] = g[i];
thread_w[i] = w[i * w_stride];
}
}
sumx = simd_sum(sumx);
sumx2 = simd_sum(sumx2);
sumwg = simd_sum(sumwg);
sumwgx = simd_sum(sumwgx);
// Initialize shared memory
if (simd_group_id == 0) {
local_sumx[simd_lane_id] = 0;
local_sumx2[simd_lane_id] = 0;
local_sumwg[simd_lane_id] = 0;
local_sumwgx[simd_lane_id] = 0;
// Compute the mean
float mean = 0;
for (int i = 0; i < N_READS; i++) {
mean += thread_x[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id);
mean /= axis_size;
// Write simd accumulations into shared memory
if (simd_lane_id == 0) {
local_sumx[simd_group_id] = sumx;
local_sumx2[simd_group_id] = sumx2;
local_sumwg[simd_group_id] = sumwg;
local_sumwgx[simd_group_id] = sumwgx;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Accumulate over simd groups
if (simd_group_id == 0) {
sumx = simd_sum(local_sumx[simd_lane_id]);
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
sumwg = simd_sum(local_sumwg[simd_lane_id]);
sumwgx = simd_sum(local_sumwgx[simd_lane_id]);
if (simd_lane_id == 0) {
float mean = sumx / axis_size;
float variance = sumx2 / axis_size - mean * mean;
local_mean[0] = mean;
local_normalizer[0] = metal::precise::rsqrt(variance + eps);
local_meanwg[0] = sumwg / axis_size;
local_meanwgx[0] = sumwgx / axis_size;
// Compute the necessary scaling factors using the mean
if (!safe) {
for (int i = n; i < N_READS; i++) {
thread_x[i] = mean;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float mean = local_mean[0];
float normalizer = local_normalizer[0];
float meanwg = local_meanwg[0];
float meanwgxc = local_meanwgx[0] - meanwg * mean;
float normalizer2 = normalizer * normalizer;
float factors[3] = {0};
constexpr int meanwg = 0;
constexpr int meanwgxc = 1;
constexpr int normalizer2 = 2;
for (int i = 0; i < N_READS; i++) {
thread_x[i] -= mean;
factors[meanwg] += thread_w[i] * thread_g[i];
factors[meanwgxc] += thread_w[i] * thread_g[i] * thread_x[i];
factors[normalizer2] += thread_x[i] * thread_x[i];
}
threadgroup_sum<3>(factors, local_buffer, simd_lane_id, simd_group_id);
factors[meanwg] /= axis_size;
factors[meanwgxc] /= axis_size;
factors[normalizer2] = 1 / (factors[normalizer2] / axis_size + eps);
float normalizer = metal::precise::sqrt(factors[normalizer2]);
// Write the outputs
gx += gid * size_t(axis_size) + lid * N_READS;
gw += gid * size_t(axis_size) + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
if (safe) {
for (int i = 0; i < N_READS; i++) {
thread_x[i] = (thread_x[i] - mean) * normalizer;
thread_x[i] *= normalizer;
gx[i] = static_cast<T>(
normalizer * (thread_w[i] * thread_g[i] - meanwg) -
thread_x[i] * meanwgxc * normalizer2);
normalizer * (thread_w[i] * thread_g[i] - factors[meanwg]) -
thread_x[i] * factors[meanwgxc] * factors[normalizer2]);
if (has_w) {
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
}
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
thread_x[i] = (thread_x[i] - mean) * normalizer;
gx[i] = static_cast<T>(
normalizer * (thread_w[i] * thread_g[i] - meanwg) -
thread_x[i] * meanwgxc * normalizer2);
if (has_w) {
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
}
for (int i = 0; i < n; i++) {
thread_x[i] *= normalizer;
gx[i] = static_cast<T>(
normalizer * (thread_w[i] * thread_g[i] - factors[meanwg]) -
thread_x[i] * factors[meanwgxc] * factors[normalizer2]);
if (has_w) {
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
}
}
}
}
template <typename T, int N_READS = RMS_N_READS>
template <typename T, int N_READS = 4>
[[kernel]] void vjp_layer_norm_looped(
const device T* x,
const device T* w,
@@ -363,102 +322,69 @@ template <typename T, int N_READS = RMS_N_READS>
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
constexpr int SIMD_SIZE = 32;
// Advance the input pointers
x += gid * size_t(axis_size) + lid * N_READS;
g += gid * size_t(axis_size) + lid * N_READS;
w += w_stride * lid * N_READS;
// Allocate registers for the accumulators
float sumx = 0;
float sumx2 = 0;
float sumwg = 0;
float sumwgx = 0;
constexpr int SIMD_SIZE = 32;
threadgroup float local_sumx[SIMD_SIZE];
threadgroup float local_sumx2[SIMD_SIZE];
threadgroup float local_sumwg[SIMD_SIZE];
threadgroup float local_sumwgx[SIMD_SIZE];
threadgroup float local_mean[1];
threadgroup float local_normalizer[1];
threadgroup float local_meanwg[1];
threadgroup float local_meanwgx[1];
threadgroup float local_buffer[3 * SIMD_SIZE];
initialize_buffer<3>(local_buffer, simd_lane_id, simd_group_id);
// Compute the mean
float mean = 0;
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
float xi = x[i + r];
float wi = w[(i + r) * w_stride];
float gi = g[i + r];
float wg = wi * gi;
sumx += xi;
sumx2 += xi * xi;
sumwg += wg;
sumwgx += wg * xi;
mean += x[i + r];
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((r + lid * N_READS + i) < axis_size) {
float xi = x[i + r];
float wi = w[(i + r) * w_stride];
float gi = g[i + r];
float wg = wi * gi;
sumx += xi;
sumx2 += xi * xi;
sumwg += wg;
sumwgx += wg * xi;
mean += x[i + r];
}
}
}
}
threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id);
mean /= axis_size;
sumx = simd_sum(sumx);
sumx2 = simd_sum(sumx2);
sumwg = simd_sum(sumwg);
sumwgx = simd_sum(sumwgx);
// Initialize shared memory
if (simd_group_id == 0) {
local_sumx[simd_lane_id] = 0;
local_sumx2[simd_lane_id] = 0;
local_sumwg[simd_lane_id] = 0;
local_sumwgx[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write simd accumulations into shared memory
if (simd_lane_id == 0) {
local_sumx[simd_group_id] = sumx;
local_sumx2[simd_group_id] = sumx2;
local_sumwg[simd_group_id] = sumwg;
local_sumwgx[simd_group_id] = sumwgx;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Accumulate over simd groups
if (simd_group_id == 0) {
sumx = simd_sum(local_sumx[simd_lane_id]);
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
sumwg = simd_sum(local_sumwg[simd_lane_id]);
sumwgx = simd_sum(local_sumwgx[simd_lane_id]);
if (simd_lane_id == 0) {
float mean = sumx / axis_size;
float variance = sumx2 / axis_size - mean * mean;
local_mean[0] = mean;
local_normalizer[0] = metal::precise::rsqrt(variance + eps);
local_meanwg[0] = sumwg / axis_size;
local_meanwgx[0] = sumwgx / axis_size;
// Compute the necessary scaling factors using the mean
float factors[3] = {0};
constexpr int meanwg = 0;
constexpr int meanwgxc = 1;
constexpr int normalizer2 = 2;
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
float t = x[i + r] - mean;
float wi = w[(i + r) * w_stride];
float gi = g[i + r];
float wg = wi * gi;
factors[meanwg] += wg;
factors[meanwgxc] += wg * t;
factors[normalizer2] += t * t;
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((r + lid * N_READS + i) < axis_size) {
float t = x[i + r] - mean;
float wi = w[(i + r) * w_stride];
float gi = g[i + r];
float wg = wi * gi;
factors[meanwg] += wg;
factors[meanwgxc] += wg * t;
factors[normalizer2] += t * t;
}
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float mean = local_mean[0];
float normalizer = local_normalizer[0];
float meanwg = local_meanwg[0];
float meanwgxc = local_meanwgx[0] - meanwg * mean;
float normalizer2 = normalizer * normalizer;
threadgroup_sum<3>(factors, local_buffer, simd_lane_id, simd_group_id);
factors[meanwg] /= axis_size;
factors[meanwgxc] /= axis_size;
factors[normalizer2] = 1 / (factors[normalizer2] / axis_size + eps);
float normalizer = metal::precise::sqrt(factors[normalizer2]);
// Write the outputs
gx += gid * size_t(axis_size) + lid * N_READS;
@@ -470,7 +396,8 @@ template <typename T, int N_READS = RMS_N_READS>
float wi = w[(i + r) * w_stride];
float gi = g[i + r];
gx[i + r] = static_cast<T>(
normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
normalizer * (wi * gi - factors[meanwg]) -
xi * factors[meanwgxc] * factors[normalizer2]);
if (has_w) {
gw[i + r] = static_cast<T>(gi * xi);
}
@@ -482,7 +409,8 @@ template <typename T, int N_READS = RMS_N_READS>
float wi = w[(i + r) * w_stride];
float gi = g[i + r];
gx[i + r] = static_cast<T>(
normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
normalizer * (wi * gi - factors[meanwg]) -
xi * factors[meanwgxc] * factors[normalizer2]);
if (has_w) {
gw[i + r] = static_cast<T>(gi * xi);
}
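Note: the layer norm rewrite replaces the single-pass E[x²] − E[x]² variance with a two-pass reduction: first reduce the mean via `threadgroup_sum`, then reduce Σ(x − mean)². The one-pass form subtracts two large, nearly equal quantities and loses precision in float32 when the data has a large mean. A scalar C++ sketch of the difference; the data values are arbitrary and chosen only to exaggerate the cancellation.

```cpp
#include <cstdio>
#include <vector>

int main() {
  // Values near 1000 with tiny variation: the true variance is about 4e-4.
  std::vector<float> x(1024);
  for (size_t i = 0; i < x.size(); ++i) x[i] = 1000.0f + 0.01f * float(i % 7);
  float n = float(x.size());

  // One pass: accumulate sum and sum of squares together.
  float sumx = 0.0f, sumx2 = 0.0f;
  for (float v : x) { sumx += v; sumx2 += v * v; }
  float mean = sumx / n;
  float var_one_pass = sumx2 / n - mean * mean;  // large terms cancel badly

  // Two passes: mean first, then the centered second moment.
  float var_two_pass = 0.0f;
  for (float v : x) { float t = v - mean; var_two_pass += t * t; }
  var_two_pass /= n;

  std::printf("one-pass: %g  two-pass: %g\n", var_one_pass, var_two_pass);
}
```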


@@ -103,8 +103,8 @@ template <typename T, typename AccT = float, int N_READS = 4>
}
} else {
for (int i = 0; i < N_READS; i++) {
vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
: Limits<AccT>::finite_min;
vals[i] =
(offset + i < axis_size) ? AccT(in[offset + i]) : Limits<AccT>::min;
}
}
prevmax = maxval;
@@ -134,10 +134,7 @@ template <typename T, typename AccT = float, int N_READS = 4>
threadgroup_barrier(mem_flags::mem_threadgroup);
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_group_id == 0) {
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_lane_id == 0) {
out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
}
if (lid == 0) {
out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
}
}
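Note: this logsumexp hunk (and the matching softmax hunk later) changes the fill value for out-of-range lanes from the most negative finite float to `Limits<AccT>::min`. Assuming that constant denotes negative infinity for floating types here (which the existing `isinf(maxval)` short-circuit suggests), a −∞ pad contributes exactly zero to the normalizer even when every real input is −∞, whereas a finite pad can become the running maximum and leak `exp(0) = 1` into the sum. A standalone C++ sketch of that failure mode, using standard-library constants as stand-ins for the MLX `Limits` values:

```cpp
#include <cmath>
#include <cstdio>
#include <limits>

// logsumexp over n real values, padded out to `width` lanes with `pad`.
float logsumexp_padded(const float* x, int n, int width, float pad) {
  float maxval = -std::numeric_limits<float>::infinity();
  for (int i = 0; i < width; ++i) maxval = std::fmax(maxval, i < n ? x[i] : pad);
  if (std::isinf(maxval)) return maxval;  // mirrors the kernel's early out
  float norm = 0.0f;
  for (int i = 0; i < width; ++i) norm += std::exp((i < n ? x[i] : pad) - maxval);
  return std::log(norm) + maxval;
}

int main() {
  const float x[3] = {-INFINITY, -INFINITY, -INFINITY};  // true result: -inf
  float finite_pad = std::numeric_limits<float>::lowest();
  float inf_pad = -std::numeric_limits<float>::infinity();
  // A finite pad becomes the max and contributes exp(0) = 1 to the sum,
  // so the result is lowest() instead of -inf; the -inf pad is exact.
  std::printf("finite pad: %g\n", logsumexp_padded(x, 3, 4, finite_pad));
  std::printf("-inf pad:   %g\n", logsumexp_padded(x, 3, 4, inf_pad));
}
```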


@@ -14,11 +14,23 @@ using namespace metal;
MLX_MTL_CONST int SIMD_SIZE = 32;
MLX_MTL_CONST int QUAD_SIZE = 4;
template <int bits, int wsize = 8>
inline constexpr short get_pack_factor() {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
template <int bits, int wsize = 8>
inline constexpr short get_bytes_per_pack() {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector(const device T* x, thread U* x_thread) {
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
bits == 8,
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
U sum = 0;
@@ -57,6 +69,21 @@ inline U load_vector(const device T* x, thread U* x_thread) {
}
}
else if (bits == 5) {
for (int i = 0; i < values_per_thread; i += 8) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
x[i + 6] + x[i + 7];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 32.0f;
x_thread[i + 2] = x[i + 2] / 4.0f;
x_thread[i + 3] = x[i + 3] / 128.0f;
x_thread[i + 4] = x[i + 4] / 16.0f;
x_thread[i + 5] = x[i + 5] / 2.0f;
x_thread[i + 6] = x[i + 6] / 64.0f;
x_thread[i + 7] = x[i + 7] / 8.0f;
}
}
else if (bits == 6) {
for (int i = 0; i < values_per_thread; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
@@ -80,8 +107,9 @@ inline U load_vector(const device T* x, thread U* x_thread) {
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
bits == 8,
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
U sum = 0;
@@ -121,6 +149,21 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
}
}
else if (bits == 5) {
for (int i = 0; i < N; i += 8) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
x[i + 6] + x[i + 7];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 32.0f;
x_thread[i + 2] = x[i + 2] / 4.0f;
x_thread[i + 3] = x[i + 3] / 128.0f;
x_thread[i + 4] = x[i + 4] / 16.0f;
x_thread[i + 5] = x[i + 5] / 2.0f;
x_thread[i + 6] = x[i + 6] / 64.0f;
x_thread[i + 7] = x[i + 7] / 8.0f;
}
}
else if (bits == 6) {
for (int i = 0; i < N; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
@@ -153,8 +196,9 @@ inline U qdot(
U bias,
U sum) {
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
bits == 8,
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
U accum = 0;
@@ -199,6 +243,26 @@ inline U qdot(
}
}
else if (bits == 5) {
for (int i = 0; i < (values_per_thread / 8); i++) {
x_thread += 8 * i;
w += 5 * i;
accum += (w[0] & 0x1f) * x_thread[0];
accum += (w[0] & 0xe0) * x_thread[1];
accum += (w[1] & 0x3) * (x_thread[1] * 256.0f);
accum += (w[1] & 0x7c) * x_thread[2];
accum += (w[1] & 0x80) * x_thread[3];
accum += (w[2] & 0xf) * (x_thread[3] * 256.0f);
accum += (w[2] & 0xf0) * x_thread[4];
accum += (w[3] & 0x1) * (x_thread[4] * 256.0f);
accum += (w[3] & 0x3e) * x_thread[5];
accum += (w[3] & 0xc0) * x_thread[6];
accum += (w[4] & 0x7) * (x_thread[6] * 256.0f);
accum += (w[4] & 0xf8) * x_thread[7];
}
}
else if (bits == 6) {
for (int i = 0; i < (values_per_thread / 4); i++) {
x_thread += 4 * i;
@@ -234,8 +298,9 @@ inline U qdot_safe(
U sum,
int N) {
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
bits == 8,
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
U accum = 0;
@@ -280,6 +345,26 @@ inline U qdot_safe(
}
}
else if (bits == 5) {
for (int i = 0; i < (N / 8); i++) {
x_thread += 8 * i;
w += 5 * i;
accum += (w[0] & 0x1f) * x_thread[0];
accum += (w[0] & 0xe0) * x_thread[1];
accum += (w[1] & 0x3) * (x_thread[1] * 256.0f);
accum += (w[1] & 0x7c) * x_thread[2];
accum += (w[1] & 0x80) * x_thread[3];
accum += (w[2] & 0xf) * (x_thread[3] * 256.0f);
accum += (w[2] & 0xf0) * x_thread[4];
accum += (w[3] & 0x1) * (x_thread[4] * 256.0f);
accum += (w[3] & 0x3e) * x_thread[5];
accum += (w[3] & 0xc0) * x_thread[6];
accum += (w[4] & 0x7) * (x_thread[6] * 256.0f);
accum += (w[4] & 0xf8) * x_thread[7];
}
}
else if (bits == 6) {
for (int i = 0; i < (N / 4); i++) {
x_thread += 4 * i;
@@ -310,8 +395,9 @@ template <typename U, int values_per_thread, int bits>
inline void
qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
bits == 8,
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
if (bits == 2) {
U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
@@ -348,8 +434,31 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
}
}
} else if (bits == 6) {
else if (bits == 5) {
for (int i = 0; i < (values_per_thread / 8); i++) {
uint8_t w0 = w[5 * i];
uint8_t w1 = w[5 * i + 1];
uint8_t w2 = w[5 * i + 2];
uint8_t w3 = w[5 * i + 3];
uint8_t w4 = w[5 * i + 4];
result[8 * i] += x * ((w0 & 0x1f) * scale + bias);
result[8 * i + 1] +=
x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias);
result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias);
result[8 * i + 3] +=
x * ((((w1 & 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias);
result[8 * i + 4] +=
x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias);
result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias);
result[8 * i + 6] +=
x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias);
result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias);
}
}
else if (bits == 6) {
for (int i = 0; i < (values_per_thread / 4); i++) {
uint8_t w0 = w[3 * i];
uint8_t w1 = w[3 * i + 1];
@@ -375,8 +484,9 @@ template <typename U, int N, int bits>
inline void
dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
bits == 8,
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
if (bits == 2) {
U s[4] = {
@@ -416,11 +526,26 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
}
}
else if (bits == 5) {
for (int i = 0; i < (N / 8); i++) {
w_local += 8 * i;
w += 5 * i;
w_local[0] = (w[0] & 0x1f) * scale + bias;
w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;
w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias;
w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias;
w_local[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias;
w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias;
w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias;
w_local[7] = ((w[4] & 0xf8) >> 3) * scale + bias;
}
}
else if (bits == 6) {
for (int i = 0; i < (N / 4); i++) {
w_local += 4 * i;
w += 3 * i;
w_local[0] = (w[0] & 0x3f) * scale + bias;
w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
@@ -452,11 +577,12 @@ struct QuantizedBlockLoader {
group_size % BCOLS == 0,
"The group size should be divisible by the columns");
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
bits == 8,
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
MLX_MTL_CONST short bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
MLX_MTL_CONST short pack_factor = get_pack_factor<bits, 8>();
MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack<bits>();
MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
MLX_MTL_CONST short n_reads =
(BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
@@ -632,12 +758,11 @@ METAL_FUNC void qmv_fast_impl(
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int packs_per_thread = bits == 2 ? 1 : 2;
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
constexpr int pack_factor = get_pack_factor<bits, 32>();
constexpr int bytes_per_pack = get_bytes_per_pack<bits, 32>();
constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int scale_step_per_thread = group_size / values_per_thread;
@@ -700,12 +825,12 @@ METAL_FUNC void qmv_impl(
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int packs_per_thread = 1;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
constexpr int pack_factor = get_pack_factor<bits, 32>();
constexpr int bytes_per_pack = get_bytes_per_pack<bits, 32>();
constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int scale_step_per_thread = group_size / values_per_thread;
@@ -857,8 +982,9 @@ METAL_FUNC void qvm_impl(
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int num_simdgroups = 2;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
constexpr int pack_factor = get_pack_factor<bits, 32>();
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
constexpr int tn = 32 / pack_factor;
constexpr int block_size = SIMD_SIZE;
@@ -981,9 +1107,10 @@ METAL_FUNC void qmm_t_impl(
constexpr int WM = 2;
constexpr int WN = 2;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int pack_factor = get_pack_factor<bits, 8>();
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
// Instantiate the appropriate BlockMMA and Loader
using mma_t = mlx::steel::
@@ -1106,11 +1233,11 @@ METAL_FUNC void qmm_n_impl(
constexpr int WM = 2;
constexpr int WN = 2;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int pack_factor = get_pack_factor<bits, 8>();
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int BN_padded = (BN + 16 / sizeof(T));
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
// Instantiate the appropriate BlockMMA and Loader
using mma_t = mlx::steel::
@@ -2120,11 +2247,10 @@ template <
uint3 tid [[threadgroup_position_in_grid]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]]) {
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int pack_factor = get_pack_factor<bits, 8>();
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int BN_padded = (BN + 16 / sizeof(T));
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
using mma_t = mlx::steel::BlockMMA<
T,
@@ -2305,13 +2431,13 @@ template <typename T, const int group_size, const int bits>
constexpr float eps = 1e-7;
constexpr int simd_size = 32;
constexpr float n_bins = (1 << bits) - 1;
constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int pack_factor = get_pack_factor<bits, 8>();
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
constexpr int values_per_reduce = group_size / simd_size;
constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
constexpr int writes_per_reduce = pack_factor / values_per_reduce;
constexpr int writes_per_pack =
writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor;
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
static_assert(
group_size % simd_size == 0,
@@ -2354,8 +2480,8 @@ template <typename T, const int group_size, const int bits>
biases[gindex] = static_cast<T>(bias);
}
// We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t
uint32_t output = 0;
using OutType = metal::conditional_t<bits == 5, uint64_t, uint32_t>;
OutType output = 0;
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_reduce; i++) {
@@ -2363,27 +2489,35 @@ template <typename T, const int group_size, const int bits>
if (bits == 8) {
output = val;
} else {
output += val << (bits * (i % packs_per_int));
output |= val << (bits * (i % pack_factor));
}
if (packs_per_int < values_per_reduce &&
i % packs_per_int == packs_per_int - 1) {
out[out_index + i / packs_per_int] = output;
if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) {
out[out_index + i / pack_factor] = output;
output = 0;
} else {
#pragma clang loop unroll(full)
for (int j = 1; j < writes_per_reduce; j++) {
uint8_t sval = simd_shuffle_down(val, j);
output += sval << (bits * (j * values_per_reduce + i));
output |= static_cast<OutType>(sval)
<< (bits * (j * values_per_reduce + i));
}
}
}
if (bits == 3 || bits == 6) {
if (in_index % packs_per_int == 0 && out_index % bytes_per_pack == 0) {
if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {
out[out_index] = output & 0xff;
out[out_index + 1] = (output & 0xff00) >> 8;
out[out_index + 2] = (output & 0xff0000) >> 16;
}
} else if (bits == 5) {
if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {
out[out_index] = output & 0xff;
out[out_index + 1] = (output & 0xff00) >> 8;
out[out_index + 2] = (output & 0xff0000) >> 16;
out[out_index + 3] = (output & 0xff000000) >> 24;
out[out_index + 4] = (output & 0xff00000000) >> 32;
}
} else {
if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
out[out_index / writes_per_reduce] = output;
@@ -2399,12 +2533,11 @@ template <typename T, const int group_size, const int bits>
device T* out [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
constexpr int pack_factor = get_pack_factor<bits, 8>();
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
size_t offset = index.x + grid_dim.x * size_t(index.y);
size_t oindex = offset * packs_per_int;
size_t oindex = offset * pack_factor;
size_t gindex = oindex / group_size;
T scale = scales[gindex];
T bias = biases[gindex];
@@ -2421,7 +2554,16 @@ template <typename T, const int group_size, const int bits>
out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
} else if (bits == 5) {
w += offset * bytes_per_pack;
out[0] = (w[0] & 0x1f) * scale + bias;
out[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;
out[2] = ((w[1] & 0x7c) >> 2) * scale + bias;
out[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias;
out[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias;
out[5] = ((w[3] & 0x3e) >> 1) * scale + bias;
out[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias;
out[7] = ((w[4] & 0xf8) >> 3) * scale + bias;
} else if (bits == 6) {
w += offset * bytes_per_pack;
out[0] = (w[0] & 0x3f) * scale + bias;
@@ -2431,7 +2573,7 @@ template <typename T, const int group_size, const int bits>
} else {
uint val = w[offset];
#pragma clang loop unroll(full)
for (int i = 0; i < packs_per_int; i++) {
for (int i = 0; i < pack_factor; i++) {
uint8_t d;
if (bits == 2) {
d = (val >> (bits * i)) & 0x03;
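Note: the quantization kernels gain a 5-bit path: `get_pack_factor` returns 8 and `get_bytes_per_pack` returns 5, so eight 5-bit values are packed little-endian into 40 bits (5 bytes), and the masks/shifts in `dequantize`, `qouter`, and the `bits == 5` branch above stitch each value back together across byte boundaries. A self-contained C++ round-trip sketch of that exact bit layout; the helper names are illustrative.

```cpp
#include <cassert>
#include <cstdint>

// Pack eight 5-bit values (0..31) into 5 bytes, least significant bits first.
void pack5(const uint8_t v[8], uint8_t out[5]) {
  uint64_t bits = 0;
  for (int i = 0; i < 8; ++i) bits |= uint64_t(v[i] & 0x1f) << (5 * i);
  for (int i = 0; i < 5; ++i) out[i] = uint8_t(bits >> (8 * i));
}

// Unpack with the same byte/mask arithmetic the kernel uses.
void unpack5(const uint8_t w[5], uint8_t v[8]) {
  v[0] = w[0] & 0x1f;
  v[1] = ((w[0] & 0xe0) >> 5) | ((w[1] & 0x03) << 3);
  v[2] = (w[1] & 0x7c) >> 2;
  v[3] = ((w[1] & 0x80) >> 7) | ((w[2] & 0x0f) << 1);
  v[4] = ((w[2] & 0xf0) >> 4) | ((w[3] & 0x01) << 4);
  v[5] = (w[3] & 0x3e) >> 1;
  v[6] = ((w[3] & 0xc0) >> 6) | ((w[4] & 0x07) << 2);
  v[7] = (w[4] & 0xf8) >> 3;
}

int main() {
  uint8_t v[8] = {1, 31, 0, 17, 9, 22, 5, 30}, w[5], r[8];
  pack5(v, w);
  unpack5(w, r);
  for (int i = 0; i < 8; ++i) assert(v[i] == r[i]);  // round trip is exact
}
```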


@@ -136,6 +136,7 @@
instantiate_quantized_groups(2) \
instantiate_quantized_groups(3) \
instantiate_quantized_groups(4) \
instantiate_quantized_groups(5) \
instantiate_quantized_groups(6) \
instantiate_quantized_groups(8)


@@ -224,7 +224,7 @@ template <
if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {
// Simple loop over non_row_reductions and reduce the row in the thread.
IdxT out_idx = tid.x + tsize.y * IdxT(tid.y);
IdxT out_idx = tid.x + tsize.x * IdxT(tid.y);
in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
for (uint r = 0; r < non_row_reductions; r++) {
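Note: the one-line reduce fix above corrects the linearization of the 2-D thread id: the y coordinate must be scaled by the extent along x (`tsize.x`), not along y. A one-function C++ sketch of the intended mapping:

```cpp
#include <cstdint>

// Row-major linearization of a 2-D id: y strides by the x extent.
// With extent_x = 5, (x=3, y=2) -> 13; scaling by the y extent instead
// produces collisions whenever the two extents differ.
inline int64_t linear_index(uint32_t x, uint32_t y, uint32_t extent_x) {
  return int64_t(x) + int64_t(extent_x) * int64_t(y);
}
```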


@@ -128,8 +128,8 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
}
} else {
for (int i = 0; i < N_READS; i++) {
vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
: Limits<AccT>::finite_min;
vals[i] =
(offset + i < axis_size) ? AccT(in[offset + i]) : Limits<AccT>::min;
}
}
prevmax = maxval;


@@ -381,6 +381,7 @@ struct Conv2DWeightBlockLoader {
const constant MLXConvParams<2>* params;
int weight_hw;
int weight_step;
const int read_n;
const bool do_read;
@@ -402,6 +403,7 @@ struct Conv2DWeightBlockLoader {
src(src_ + bi * src_ld + bj),
params(params_),
weight_hw(0),
weight_step(params->C / params->groups),
read_n(offsets.y + bi),
do_read(read_n + n_rows * TROWS <= gemm_params_->N) {}
@@ -435,15 +437,15 @@ struct Conv2DWeightBlockLoader {
/* Iteration helper */
METAL_FUNC void next() {
if (++weight_hw < (params->wS[1] * params->wS[0])) {
src += params->wt_strides[2];
src += weight_step;
return;
}
weight_hw = 0;
src += BK - (params->wS[1] * params->wS[0] - 1) * params->wt_strides[2];
src += BK - (params->wS[1] * params->wS[0] - 1) * weight_step;
}
};
} // namespace steel
} // namespace mlx
} // namespace mlx


@@ -272,7 +272,7 @@ struct Conv2DWeightBlockLoaderSmallChannels {
return;
}
const device T* curr_src = src + weight_hw * params->wt_strides[2];
const device T* curr_src = src + weight_hw * (params->C / params->groups);
if (BN != 8 || do_read) {
STEEL_PRAGMA_UNROLL
@@ -316,4 +316,4 @@ struct Conv2DWeightBlockLoaderSmallChannels {
};
} // namespace steel
} // namespace mlx
} // namespace mlx


@@ -9,8 +9,14 @@ template <typename T, typename Op, int N = WorkPerThread<T>::n>
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
}
} else {
for (int i = 0; i < N; ++i) {
d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
}
}
}
@@ -23,9 +29,15 @@ template <typename T, typename Op, int N = WorkPerThread<T>::n>
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
}
} else {
for (int i = 0; i < N; ++i) {
d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
}
}
}
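The rewritten kernels here (and the unary kernels further down) split the old single bounds-checked loop into two paths: a tail path that checks the bound on every element, and a fast path that runs exactly N iterations when the whole block is known to be in range, so only the final block pays for per-element checks. A rough CPU-side C++ sketch of the same idea, with illustrative names:

#include <cstdint>

template <typename T, typename Op, int N>
void apply_blocked(const T* in, T* out, int64_t size, int64_t index) {
  int64_t offset = int64_t(N) * index;
  if (N > 1 && offset + N > size) {
    // Tail block: bounds-check every element.
    for (int i = 0; offset + i < size; ++i) {
      out[offset + i] = Op()(in[offset + i]);
    }
  } else {
    // Full block: a fixed trip count of N, which the compiler can unroll
    // and vectorize without per-element comparisons.
    for (int i = 0; i < N; ++i) {
      out[offset + i] = Op()(in[offset + i]);
    }
  }
}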


@@ -8,8 +8,8 @@
#include "mlx/backend/metal/kernels/ternary_ops.h"
#include "mlx/backend/metal/kernels/ternary.h"
#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
#define instantiate_ternary_base(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op, 1) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 2, int) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int) \
@@ -20,19 +20,23 @@
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \
#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("vn_" #op #tname, ternary_v, type, op) \
instantiate_ternary_base(op, tname, type)
#define instantiate_ternary_types(op) \
instantiate_ternary_all(op, bool_, bool) \
instantiate_ternary_all(op, uint8, uint8_t) \
instantiate_ternary_all(op, uint16, uint16_t) \
instantiate_ternary_all(op, uint32, uint32_t) \
instantiate_ternary_all(op, uint64, uint64_t) \
instantiate_ternary_base(op, uint64, uint64_t) \
instantiate_ternary_all(op, int8, int8_t) \
instantiate_ternary_all(op, int16, int16_t) \
instantiate_ternary_all(op, int32, int32_t) \
instantiate_ternary_all(op, int64, int64_t) \
instantiate_ternary_base(op, int64, int64_t) \
instantiate_ternary_all(op, float16, half) \
instantiate_ternary_all(op, float32, float) \
instantiate_ternary_all(op, bfloat16, bfloat16_t) \
instantiate_ternary_all(op, complex64, complex64_t) // clang-format on
instantiate_ternary_base(op, complex64, complex64_t) // clang-format on
instantiate_ternary_types(Select)


@@ -7,8 +7,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
out[index + i] = Op()(in[index + i]);
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
out[index + i] = Op()(in[index + i]);
}
} else {
for (int i = 0; i < N; ++i) {
out[index + i] = Op()(in[index + i]);
}
}
}
@@ -19,9 +25,15 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
out[offset + i] = Op()(in[offset + i]);
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
out[offset + i] = Op()(in[offset + i]);
}
} else {
for (int i = 0; i < N; ++i) {
out[offset + i] = Op()(in[offset + i]);
}
}
}


@@ -5,31 +5,41 @@
#include "mlx/backend/metal/kernels/unary_ops.h"
#include "mlx/backend/metal/kernels/unary.h"
#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
instantiate_kernel( \
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \
instantiate_kernel( \
#define instantiate_unary_work_per_thread(op, in_tname, out_tname, in_type, out_type) \
instantiate_kernel("vn_" #op #in_tname #out_tname, unary_v, in_type, out_type, op)
#define instantiate_unary_base(op, in_tname, out_tname, in_type, out_type) \
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op, 1) \
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
instantiate_kernel( \
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \
instantiate_kernel( \
"gn4large_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
instantiate_unary_base(op, in_tname, out_tname, in_type, out_type) \
instantiate_unary_work_per_thread(op, in_tname, out_tname, in_type, out_type)
#define instantiate_unary_all_same(op, tname, type) \
instantiate_unary_all(op, tname, tname, type, type)
#define instantiate_unary_base_same(op, tname, type) \
instantiate_unary_base(op, tname, tname, type, type)
#define instantiate_unary_float(op) \
instantiate_unary_all_same(op, float16, half) \
instantiate_unary_all_same(op, float32, float) \
instantiate_unary_all_same(op, bfloat16, bfloat16_t)
#define instantiate_unary_int(op) \
instantiate_unary_all_same(op, uint8, uint8_t) \
instantiate_unary_all_same(op, uint16, uint16_t) \
instantiate_unary_all_same(op, uint32, uint32_t) \
instantiate_unary_all_same(op, uint64, uint64_t) \
instantiate_unary_all_same(op, int8, int8_t) \
instantiate_unary_all_same(op, int16, int16_t) \
instantiate_unary_all_same(op, int32, int32_t) \
instantiate_unary_all_same(op, int64, int64_t)
#define instantiate_unary_int(op) \
instantiate_unary_all_same(op, uint8, uint8_t) \
instantiate_unary_all_same(op, uint16, uint16_t) \
instantiate_unary_all_same(op, uint32, uint32_t) \
instantiate_unary_base_same(op, uint64, uint64_t) \
instantiate_unary_all_same(op, int8, int8_t) \
instantiate_unary_all_same(op, int16, int16_t) \
instantiate_unary_all_same(op, int32, int32_t) \
instantiate_unary_base_same(op, int64, int64_t)
#define instantiate_unary_types(op) \
instantiate_unary_all_same(op, bool_, bool) \
@@ -68,29 +78,29 @@ instantiate_unary_float(Tanh)
instantiate_unary_float(Round)
instantiate_unary_int(BitwiseInvert)
instantiate_unary_all_same(Abs, complex64, complex64_t)
instantiate_unary_all_same(ArcCos, complex64, complex64_t)
instantiate_unary_all_same(ArcSin, complex64, complex64_t)
instantiate_unary_all_same(ArcTan, complex64, complex64_t)
instantiate_unary_all_same(Conjugate, complex64, complex64_t)
instantiate_unary_all_same(Cos, complex64, complex64_t)
instantiate_unary_all_same(Cosh, complex64, complex64_t)
instantiate_unary_all_same(Exp, complex64, complex64_t)
instantiate_unary_all_same(Log, complex64, complex64_t)
instantiate_unary_all_same(Log1p, complex64, complex64_t)
instantiate_unary_all_same(Log2, complex64, complex64_t)
instantiate_unary_all_same(Log10, complex64, complex64_t)
instantiate_unary_all_same(Negative, complex64, complex64_t)
instantiate_unary_all_same(Sign, complex64, complex64_t)
instantiate_unary_all_same(Sin, complex64, complex64_t)
instantiate_unary_all_same(Sinh, complex64, complex64_t)
instantiate_unary_all_same(Square, complex64, complex64_t)
instantiate_unary_all_same(Sqrt, complex64, complex64_t)
instantiate_unary_all_same(Rsqrt, complex64, complex64_t)
instantiate_unary_all_same(Tan, complex64, complex64_t)
instantiate_unary_all_same(Tanh, complex64, complex64_t)
instantiate_unary_all_same(Round, complex64, complex64_t)
instantiate_unary_all(Real, complex64, float32, complex64_t, float)
instantiate_unary_all(Imag, complex64, float32, complex64_t, float)
instantiate_unary_base_same(Abs, complex64, complex64_t)
instantiate_unary_base_same(ArcCos, complex64, complex64_t)
instantiate_unary_base_same(ArcSin, complex64, complex64_t)
instantiate_unary_base_same(ArcTan, complex64, complex64_t)
instantiate_unary_base_same(Conjugate, complex64, complex64_t)
instantiate_unary_base_same(Cos, complex64, complex64_t)
instantiate_unary_base_same(Cosh, complex64, complex64_t)
instantiate_unary_base_same(Exp, complex64, complex64_t)
instantiate_unary_base_same(Log, complex64, complex64_t)
instantiate_unary_base_same(Log1p, complex64, complex64_t)
instantiate_unary_base_same(Log2, complex64, complex64_t)
instantiate_unary_base_same(Log10, complex64, complex64_t)
instantiate_unary_base_same(Negative, complex64, complex64_t)
instantiate_unary_base_same(Sign, complex64, complex64_t)
instantiate_unary_base_same(Sin, complex64, complex64_t)
instantiate_unary_base_same(Sinh, complex64, complex64_t)
instantiate_unary_base_same(Square, complex64, complex64_t)
instantiate_unary_base_same(Sqrt, complex64, complex64_t)
instantiate_unary_base_same(Rsqrt, complex64, complex64_t)
instantiate_unary_base_same(Tan, complex64, complex64_t)
instantiate_unary_base_same(Tanh, complex64, complex64_t)
instantiate_unary_base_same(Round, complex64, complex64_t)
instantiate_unary_base(Real, complex64, float32, complex64_t, float)
instantiate_unary_base(Imag, complex64, float32, complex64_t, float)
instantiate_unary_all_same(LogicalNot, bool_, bool) // clang-format on


@@ -6,7 +6,7 @@
#include <sstream>
#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/common/matmul.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
@@ -21,69 +21,6 @@ namespace mlx::core {
namespace {
inline auto collapse_batches(const array& a, const array& b) {
// Get and check the shape for the batched dims
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
if (A_bshape != B_bshape) {
std::ostringstream msg;
msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A "
<< a.shape() << ", B " << b.shape() << ".";
throw std::runtime_error(msg.str());
}
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
auto [batch_shape, batch_strides] =
collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
auto A_batch_stride = batch_strides[0];
auto B_batch_stride = batch_strides[1];
if (batch_shape.empty()) {
batch_shape.push_back(1);
A_batch_stride.push_back(0);
B_batch_stride.push_back(0);
}
return std::make_tuple(batch_shape, A_batch_stride, B_batch_stride);
}
inline auto collapse_batches(const array& a, const array& b, const array& c) {
// Get and check the shape for the batched dims
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
Shape C_bshape{c.shape().begin(), c.shape().end() - 2};
if (A_bshape != B_bshape || A_bshape != C_bshape) {
std::ostringstream msg;
msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A "
<< a.shape() << ", B " << b.shape() << ", B " << c.shape() << ".";
throw std::runtime_error(msg.str());
}
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
Strides C_bstride{c.strides().begin(), c.strides().end() - 2};
auto [batch_shape, batch_strides] = collapse_contiguous_dims(
A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
auto A_batch_stride = batch_strides[0];
auto B_batch_stride = batch_strides[1];
auto C_batch_stride = batch_strides[2];
if (batch_shape.empty()) {
batch_shape.push_back(1);
A_batch_stride.push_back(0);
B_batch_stride.push_back(0);
C_batch_stride.push_back(0);
}
return std::make_tuple(
batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
}
std::tuple<bool, int64_t, array> check_transpose(
std::vector<array>& copies,
const Stream& s,
@@ -716,6 +653,23 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error(
"[matmul] Does not yet support non-floating point types.");
}
// Return 0s if either input is empty
if (out.size() == 0) {
out.set_data(allocator::malloc(out.nbytes()));
return;
}
// Copy c into out and return
if (inputs[0].shape(-1) == 0) {
copy_gpu(
inputs[2],
out,
inputs[2].flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream());
return;
}
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);


@@ -146,7 +146,7 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
int,
int,
int) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
return d.get_kernel(kernel_name, hash_name, func_consts);
}
MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
@@ -207,7 +207,7 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
int,
int,
bool) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
return d.get_kernel(kernel_name, hash_name, func_consts);
}
MTL::ComputePipelineState* get_gemv_masked_kernel(
@@ -259,7 +259,7 @@ MTL::ComputePipelineState* get_fft_kernel(
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const std::string&) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
return d.get_kernel(kernel_name, hash_name, func_consts);
}
MTL::ComputePipelineState* get_quantized_kernel(
@@ -283,7 +283,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
int,
int,
bool) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
return d.get_kernel(kernel_name, hash_name, func_consts);
}
} // namespace mlx::core


@@ -10,6 +10,10 @@
namespace mlx::core::fast {
bool RMSNorm::use_fallback(Stream s) {
return s.device == Device::cpu;
}
void RMSNorm::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
@@ -168,7 +172,7 @@ void RMSNormVJP::eval_gpu(
auto& compute_encoder = d.get_command_encoder(s.index);
{
auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts);
auto kernel = d.get_kernel(op_name, hash_name, func_consts);
MTL::Size grid_dims, group_dims;
if (axis_size <= looped_limit) {
@@ -207,6 +211,10 @@ void RMSNormVJP::eval_gpu(
}
}
bool LayerNorm::use_fallback(Stream s) {
return s.device == Device::cpu;
}
void LayerNorm::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
@@ -247,12 +255,13 @@ void LayerNorm::eval_gpu(
auto axis_size = static_cast<uint32_t>(x.shape().back());
int n_rows = x.data_size() / axis_size;
const int simd_size = 32;
const int n_reads = RMS_N_READS;
const int looped_limit = RMS_LOOPED_LIMIT;
int simd_size = 32;
int n_reads = 8;
int looped_limit = 6656;
std::string op_name = "layer_norm";
if (axis_size > looped_limit) {
op_name += "_looped";
n_reads = 4;
}
op_name += type_to_name(out);
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -264,7 +273,13 @@ void LayerNorm::eval_gpu(
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
size_t threadgroup_size = simd_size * simds_needed;
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
if (threadgroup_size > kernel->maxTotalThreadsPerThreadgroup()) {
std::ostringstream msg;
msg << "[layer_norm] Threadgroup size " << threadgroup_size
<< " is larger than the maximum allowed threadgroup size "
<< kernel->maxTotalThreadsPerThreadgroup();
throw std::runtime_error(msg.str());
}
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
@@ -364,12 +379,13 @@ void LayerNormVJP::eval_gpu(
g, gb, "sum", plan, {0}, compute_encoder, d, s);
}
const int simd_size = 32;
const int n_reads = RMS_N_READS;
const int looped_limit = RMS_LOOPED_LIMIT;
int simd_size = 32;
int n_reads = 8;
int looped_limit = 8192;
std::string op_name = "vjp_layer_norm";
if (axis_size > looped_limit) {
op_name += "_looped";
n_reads = 4;
}
op_name += type_to_name(gx);
@@ -379,14 +395,20 @@ void LayerNormVJP::eval_gpu(
};
{
auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts);
auto kernel = d.get_kernel(op_name, hash_name, func_consts);
MTL::Size grid_dims, group_dims;
if (axis_size <= looped_limit) {
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
size_t threadgroup_size = simd_size * simds_needed;
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
if (threadgroup_size > kernel->maxTotalThreadsPerThreadgroup()) {
std::ostringstream msg;
msg << "[vjp_layer_norm] Threadgroup size " << threadgroup_size
<< " is larger than the maximum allowed threadgroup size "
<< kernel->maxTotalThreadsPerThreadgroup();
throw std::runtime_error(msg.str());
}
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
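As a quick worked example of the sizing in these hunks, assume axis_size = 4096 with the new n_reads = 8: threadgroup_needed = ceil(4096 / 8) = 512, simds_needed = ceil(512 / 32) = 16, so threadgroup_size = 32 * 16 = 512. That value is now validated against maxTotalThreadsPerThreadgroup with a thrown error rather than an assert, which would compile out in release builds.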


@@ -182,8 +182,8 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
(thread_group_size + simd_size - 1) / simd_size * simd_size;
assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
size_t n_threads = out.size() * thread_group_size;
MTL::Size grid_dims = MTL::Size(n_threads, 1, 1);
auto gd = get_2d_grid_dims(out.shape(), out.strides());
MTL::Size grid_dims = MTL::Size(thread_group_size, gd.width, gd.height);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
@@ -378,10 +378,16 @@ void Cholesky::eval_gpu(const std::vector<array>& inputs, array& out) {
"[Cholesky::eval_gpu] Metal Cholesky decomposition NYI.");
}
void Eig::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("[Eig::eval_gpu] Metal Eig NYI.");
}
void Eigh::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("[Eigvalsh::eval_gpu] Metal Eigh NYI.");
throw std::runtime_error("[Eigh::eval_gpu] Metal Eigh NYI.");
}
void LUF::eval_gpu(


@@ -976,7 +976,9 @@ void fast::AffineQuantize::eval_gpu(
// Treat uint32 as uint8 in kernel
constexpr int uint8_per_uint32 = 4;
constexpr int simd_size = 32;
int packs_per_int = bits_ == 3 ? 8 : bits_ == 6 ? 4 : 8 / bits_;
int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8
: bits_ == 6 ? 4
: 8 / bits_;
int per_thread = dequantize_ ? packs_per_int : group_size_ / simd_size;
size_t nthreads =
dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;
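For reference, the updated expression gives packs_per_int = 8 for bits_ in {3, 5} (eight values per 3- or 5-byte pack), 4 for bits_ == 6 (four values per 3 bytes), and 8 / bits_ otherwise (e.g. two values per byte at 4 bits, one at 8 bits).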


@@ -7,6 +7,10 @@ namespace mlx::core::fast {
constexpr int n_per_thread = 4;
bool RoPE::use_fallback(Stream s) {
return s.device == Device::cpu;
}
void RoPE::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {


@@ -4,10 +4,10 @@
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/steel/attn/params.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/transforms_impl.h"
#include "mlx/utils.h"
namespace mlx::core::fast {
@@ -73,7 +73,7 @@ void sdpa_full_self_attention_metal(
std::string hash_name = kname.str();
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);
const int NQ = (qL + bq - 1) / bq;
@@ -180,7 +180,7 @@ void sdpa_vector(
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts);
auto kernel = d.get_kernel(kname, hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);
// Set its arguments
@@ -281,7 +281,7 @@ void sdpa_vector_2pass(
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts);
auto kernel = d.get_kernel(kname, hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);
@@ -339,6 +339,46 @@ void sdpa_vector_2pass(
} // namespace
bool ScaledDotProductAttention::use_fallback(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal,
Stream s) {
if (detail::in_grad_tracing()) {
return true;
}
if (s.device == Device::cpu) {
return true;
}
const int value_head_dim = v.shape(-1);
const int query_head_dim = q.shape(-1);
const int query_sequence_length = q.shape(2);
const int key_sequence_length = k.shape(2);
const bool sdpa_vector_supported_head_dim =
query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
query_head_dim == 256);
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
(query_sequence_length <= key_sequence_length && do_causal);
const bool supports_sdpa_full = query_sequence_length > 8 &&
sdpa_full_supported_mask && sdpa_full_supported_head_dim;
const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
(query_sequence_length <= key_sequence_length) &&
sdpa_vector_supported_head_dim;
return !(supports_sdpa_full || supports_sdpa_vector);
}
void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
array& out) {


@@ -45,7 +45,7 @@ void ternary_op_gpu_inplace(
work_per_thread = large ? 4 : 2;
} else {
large = out.data_size() > INT32_MAX;
work_per_thread = get_work_per_thread(b.dtype());
work_per_thread = get_work_per_thread(b.dtype(), out.data_size());
}
std::string kernel_name;
if (topt == TernaryOpType::General) {
@@ -60,6 +60,8 @@ void ternary_op_gpu_inplace(
}
} else if (large) {
kernel_name = "v2";
} else if (work_per_thread > 1) {
kernel_name = "vn";
} else {
kernel_name = "v";
}


@@ -1,5 +1,6 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/common/unary.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
@@ -43,8 +44,8 @@ void unary_op_gpu_inplace(
int work_per_thread;
std::string kernel_name;
if (contig) {
work_per_thread = get_work_per_thread(in.dtype());
kernel_name = (large ? "v2" : "v");
work_per_thread = get_work_per_thread(in.dtype(), in.data_size());
kernel_name = (large ? "v2" : (work_per_thread > 1 ? "vn" : "v"));
} else {
work_per_thread = large ? 4 : 1;
kernel_name = "gn" + std::to_string(work_per_thread);
@@ -99,21 +100,7 @@ void unary_op_gpu(
array& out,
const std::string op,
const Stream& s) {
auto& in = inputs[0];
bool contig = in.flags().contiguous;
if (contig) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
set_unary_output_data(inputs[0], out);
unary_op_gpu_inplace(inputs, out, op, s);
}


@@ -1,8 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/utils.h"
using namespace mlx;
#include "mlx/backend/common/utils.h"
namespace mlx::core {
@@ -59,109 +58,20 @@ std::string type_to_name(const array& a) {
return type_to_name(a.dtype());
}
MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
int pows[3] = {0, 0, 0};
int sum = 0;
while (true) {
int presum = sum;
// Check all the pows
if (dim0 >= (1 << (pows[0] + 1))) {
pows[0]++;
sum++;
}
if (sum == 10) {
break;
}
if (dim1 >= (1 << (pows[1] + 1))) {
pows[1]++;
sum++;
}
if (sum == 10) {
break;
}
if (dim2 >= (1 << (pows[2] + 1))) {
pows[2]++;
sum++;
}
if (sum == presum || sum == pow2) {
break;
}
}
return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2) {
Dims dims = get_block_dims_common(dim0, dim1, dim2, pow2);
return MTL::Size(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
}
MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides) {
// Dims with strides of 0 are ignored as they
// correspond to broadcasted dimensions
size_t grid_x = 1;
size_t grid_y = 1;
for (int i = 0; i < shape.size(); ++i) {
if (strides[i] == 0) {
continue;
}
if (grid_x * shape[i] < UINT32_MAX) {
grid_x *= shape[i];
} else {
grid_y *= shape[i];
}
}
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
throw std::runtime_error("Unable to safely factor shape.");
}
if (grid_y > grid_x) {
std::swap(grid_x, grid_y);
}
return MTL::Size(
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
Dims dims = get_2d_grid_dims_common(shape, strides);
return MTL::Size(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
}
MTL::Size
get_2d_grid_dims(const Shape& shape, const Strides& strides, size_t divisor) {
// Compute the 2d grid dimensions such that the total size of the grid is
// divided by divisor.
size_t grid_x = 1;
size_t grid_y = 1;
for (int i = 0; i < shape.size(); ++i) {
if (strides[i] == 0) {
continue;
}
// No need to add this shape we can just remove it from the divisor.
if (divisor % shape[i] == 0) {
divisor /= shape[i];
continue;
}
if (grid_x * shape[i] < UINT32_MAX) {
grid_x *= shape[i];
} else {
grid_y *= shape[i];
}
if (divisor > 1) {
if (grid_x % divisor == 0) {
grid_x /= divisor;
divisor = 1;
} else if (grid_y % divisor == 0) {
grid_y /= divisor;
divisor = 1;
}
}
}
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) {
throw std::runtime_error("Unable to safely factor shape.");
}
if (grid_y > grid_x) {
std::swap(grid_x, grid_y);
}
return MTL::Size(
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}
std::string get_primitive_string(Primitive* primitive) {
std::ostringstream op_t;
primitive->print(op_t);
return op_t.str();
Dims dims = get_2d_grid_dims_common(shape, strides, divisor);
return MTL::Size(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
}
} // namespace mlx::core


@@ -13,22 +13,9 @@ namespace mlx::core {
std::string type_to_name(const Dtype& t);
std::string type_to_name(const array& a);
// Compute the thread block dimensions which fit the given
// input dimensions.
// - The thread block dimensions will be powers of two
// - The thread block size will be less than 2^pow2
// Compute the grid and block dimensions, check backend/common/utils.h for docs.
MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
// Computes a 2D grid where each element is < UINT_MAX
// Assumes:
// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
// - shape and strides correspond to a contiguous (no holes) but
// possibly broadcasted array
MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides);
// Same as above but we do an implicit division with divisor.
// Basically, equivalent to factorizing
// Prod(s \forall s in shape if strides[s] > 0) / divisor.
MTL::Size
get_2d_grid_dims(const Shape& shape, const Strides& strides, size_t divisor);
@@ -58,8 +45,6 @@ inline void debug_set_primitive_buffer_label(
#endif
}
std::string get_primitive_string(Primitive* primitive);
template <typename T>
constexpr bool is_numeric_except_char = std::is_arithmetic_v<T> &&
!std::is_same_v<T, char> && !std::is_same_v<T, signed char> &&
@@ -87,6 +72,10 @@ void concatenate(std::string& acc, T first, Args... args) {
inline int get_work_per_thread(Dtype dtype) {
return std::max(1, 8 / dtype.size());
}
inline int get_work_per_thread(Dtype dtype, size_t size) {
constexpr size_t wpt_threshold = 1 << 16;
return size < wpt_threshold ? 1 : std::max(1, 8 / dtype.size());
}
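The new overload keeps small launches on the plain per-element kernels: anything below 2^16 = 65536 elements gets work_per_thread = 1, while larger contiguous arrays get 8 / itemsize (2 for float32, 4 for float16/bfloat16, 8 for int8). This is what selects the "vn" kernels in the unary and ternary dispatch changes earlier in this diff, and it also matches the instantiation changes that leave the 64-bit and complex64 types on the base variants, since 8 / 8 = 1 there anyway.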
inline size_t ceildiv(size_t n, size_t m) {
return (n + m - 1) / m;


@@ -55,6 +55,7 @@ NO_CPU(DynamicSlice)
NO_CPU(DynamicSliceUpdate)
NO_CPU(NumberOfElements)
NO_CPU(Remainder)
NO_CPU_MULTI(Eig)
NO_CPU_MULTI(Eigh)
NO_CPU(Equal)
NO_CPU(Erf)


@@ -2,6 +2,7 @@
#include "mlx/primitives.h"
#include "mlx/distributed/primitives.h"
#include "mlx/fast.h"
#include "mlx/fast_primitives.h"
#define NO_GPU_MULTI(func) \
@@ -10,6 +11,12 @@
throw std::runtime_error(#func " has no GPU implementation."); \
}
#define NO_GPU_USE_FALLBACK(func) \
bool func::use_fallback(Stream s) { \
return true; \
} \
NO_GPU_MULTI(func)
#define NO_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
throw std::runtime_error(#func " has no GPU implementation."); \
@@ -17,6 +24,17 @@
namespace mlx::core {
bool fast::ScaledDotProductAttention::use_fallback(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal,
Stream s) {
return true;
}
NO_GPU(Abs)
NO_GPU(Add)
NO_GPU(AddMM)
@@ -126,17 +144,30 @@ NO_GPU(Unflatten)
NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU_MULTI(Eigh)
NO_GPU_MULTI(Eig)
NO_GPU(View)
namespace fast {
NO_GPU_MULTI(LayerNorm)
NO_GPU_USE_FALLBACK(LayerNorm)
NO_GPU_MULTI(LayerNormVJP)
NO_GPU_MULTI(RMSNorm)
NO_GPU_USE_FALLBACK(RMSNorm)
NO_GPU_MULTI(RMSNormVJP)
NO_GPU_MULTI(RoPE)
NO_GPU_USE_FALLBACK(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
NO_GPU_MULTI(CustomKernel)
MetalKernelFunction metal_kernel(
const std::string&,
const std::vector<std::string>&,
const std::vector<std::string>&,
const std::string&,
const std::string&,
bool ensure_row_contiguous,
bool atomic_outputs) {
throw std::runtime_error("[metal_kernel] No GPU back-end.");
}
} // namespace fast
namespace distributed {


@@ -1,16 +1,20 @@
// Copyright © 2023-2024 Apple Inc.
#include <cstdlib>
#include <map>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include "mlx/allocator.h"
#include "mlx/backend/common/compiled.h"
#include "mlx/compile.h"
#include "mlx/compile_impl.h"
#include "mlx/fast_primitives.h"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
#include "mlx/utils.h"
namespace mlx::core {
@@ -82,7 +86,54 @@ Compiled::Compiled(
inputs_(std::move(inputs)),
outputs_(std::move(outputs)),
tape_(std::move(tape)),
constant_ids_(std::move(constant_ids)) {}
constant_ids_(std::move(constant_ids)),
is_constant_([this](size_t i) {
return constant_ids_.find(inputs_[i].id()) != constant_ids_.end();
}) {
// Build the kernel name.
NodeNamer namer;
std::ostringstream os;
std::ostringstream constant_hasher;
// Fill the input names. This is not really necessary, I just like having A,
// B, C, ... as the inputs.
for (const auto& x : inputs_) {
namer.get_name(x);
}
// The primitives describing the tape. For unary and binary primitives this
// must be enough to describe the full computation.
for (const auto& a : tape_) {
// name and type of output
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
// computation performed
a.primitive().print(os);
// name of inputs to the function
for (auto& inp : a.inputs()) {
os << namer.get_name(inp);
}
}
os << "_";
for (const auto& x : inputs_) {
if (constant_ids_.find(x.id()) != constant_ids_.end()) {
os << "C";
print_constant(constant_hasher, x);
} else {
os << (is_scalar(x) ? "S" : "V");
}
}
os << "_";
for (const auto& x : inputs) {
if (constant_ids.find(x.id()) != constant_ids.end()) {
continue;
}
os << kindof(x.dtype()) << x.itemsize();
}
os << "_" << std::hash<std::string>{}(constant_hasher.str());
kernel_lib_ = os.str();
}
std::vector<array> Compiled::vjp(
const std::vector<array>&,


@@ -266,6 +266,7 @@ struct PrimitiveFactory {
SERIALIZE_PRIMITIVE(Floor),
SERIALIZE_PRIMITIVE(Full),
SERIALIZE_PRIMITIVE(Gather),
SERIALIZE_PRIMITIVE(GatherAxis),
SERIALIZE_PRIMITIVE(GatherMM),
SERIALIZE_PRIMITIVE(Greater),
SERIALIZE_PRIMITIVE(GreaterEqual),
@@ -307,6 +308,7 @@ struct PrimitiveFactory {
"CumMax",
"CumLogaddexp"),
SERIALIZE_PRIMITIVE(Scatter),
SERIALIZE_PRIMITIVE(ScatterAxis),
SERIALIZE_PRIMITIVE(Select),
SERIALIZE_PRIMITIVE(Sigmoid),
SERIALIZE_PRIMITIVE(Sign),
@@ -331,6 +333,7 @@ struct PrimitiveFactory {
SERIALIZE_PRIMITIVE(SVD),
SERIALIZE_PRIMITIVE(Inverse),
SERIALIZE_PRIMITIVE(Cholesky),
SERIALIZE_PRIMITIVE(Eig),
SERIALIZE_PRIMITIVE(Eigh),
SERIALIZE_PRIMITIVE(AffineQuantize),
SERIALIZE_PRIMITIVE(RMSNorm),

Some files were not shown because too many files have changed in this diff.