Compare commits


16 Commits

Author SHA1 Message Date
Cheng
99c33d011d rebase + nit (#2260)
Co-authored-by: Awni Hannun <awni@apple.com>
2025-06-10 10:51:51 -07:00
Awni Hannun
62fecf3e13 fix conv export (#2265) 2025-06-10 09:34:01 -07:00
Cheng
7c4eb5d03e CUDA backend: random (#2261) 2025-06-10 08:59:56 -07:00
Cheng
bae9a6b404 CUDA backend: sort (#2262)
Co-authored-by: Awni Hannun <awni@apple.com>
2025-06-10 08:59:47 -07:00
Christopher Fleetwood
004c1d8ef2 Report number of missing parameters (#2264)
* chore: inform

* chore: format

---------

Co-authored-by: FL33TW00D <FL33TW00D@users.noreply.github.com>
2025-06-10 06:37:50 -07:00
Cheng
7ebb2e0193 CUDA backend: binary ops (#2259) 2025-06-10 06:37:40 -07:00
Awni Hannun
9ce77798b1 fix export to work with gather/scatter axis (#2263) 2025-06-09 20:37:27 -07:00
Cheng
f8bad60609 CUDA backend: unary ops (#2158) 2025-06-09 06:45:08 -07:00
Emmanuel Ferdman
5866b3857b Refactor the lu test (#2250)
Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
2025-06-07 06:12:08 -07:00
Awni Hannun
1ca616844b Fix unintuitive metal kernel caching (#2242)
* Fix unintuitive metal kernel caching

* alternative solution
2025-06-06 20:08:15 -07:00
Angelos Katharopoulos
2e8cf0b450 Change layernorms to two pass algorithm (#2246) 2025-06-06 13:34:56 -07:00
Cheng
24f89173d1 CUDA backend: matmul (#2241) 2025-06-06 12:24:04 -07:00
Awni Hannun
c6a20b427a Improve metal elementwise kernels (#2247)
* improve metal elementwise kernels

* compile and copy

* fix jit
2025-06-06 11:37:40 -07:00
Awni Hannun
a5ac9244c4 fix linux linking error (#2248) 2025-06-06 10:41:51 -07:00
Awni Hannun
c763fe1be0 default strict mode for module update and update_modules (#2239) 2025-06-05 15:27:02 -07:00
Cheng
52dc8c8cd5 Add profiler annotations in common primitives for CUDA backend (#2244) 2025-06-04 19:55:12 -07:00
70 changed files with 4927 additions and 1308 deletions

View File

@@ -1,5 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.

+from functools import partial
+
 import mlx.core as mx
 import mlx.nn as nn
 from time_utils import time_fn
@@ -18,51 +20,63 @@ def layer_norm(x, w, b, eps):
     return y


-def time_layer_norm():
+def time_layer_norm(N, dt):
+    L = 1024
     f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
     f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
     g1 = mx.grad(f1, argnums=(0, 1, 2))
     g2 = mx.grad(f2, argnums=(0, 1, 2))

-    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
-    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
-    b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
-    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
+    x = mx.random.uniform(shape=(8, L, N)).astype(dt)
+    w = mx.random.uniform(shape=(N,)).astype(dt)
+    b = mx.random.uniform(shape=(N,)).astype(dt)
+    y = mx.random.uniform(shape=(8, L, N)).astype(dt)
     mx.eval(x, w, b, y)

-    def layer_norm_loop(g, x, w, b):
+    def layer_norm_loop(f, x, w, b):
+        for _ in range(32):
+            x = f(x, w, b)
+        return x
+
+    time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b)
+    time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b)
+
+    def layer_norm_grad_loop(g, x, w, b):
         gx, gw, gb = x, w, b
         for _ in range(32):
             gx, gw, gb = g(gx, gw, gb, y)
         return gx, gw, gb

-    time_fn(layer_norm_loop, g1, x, w, b)
-    time_fn(layer_norm_loop, g2, x, w, b)
-    time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
-    time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
+    time_fn(layer_norm_grad_loop, g1, x, w, b)
+    time_fn(layer_norm_grad_loop, g2, x, w, b)
+    time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b)
+    time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b)

     f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
     f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
     g1 = mx.grad(f1, argnums=(0,))
     g2 = mx.grad(f2, argnums=(0,))

-    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
-    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
-    b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
-    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
+    x = mx.random.uniform(shape=(8, L, N)).astype(dt)
+    w = mx.random.uniform(shape=(N,)).astype(dt)
+    b = mx.random.uniform(shape=(N,)).astype(dt)
+    y = mx.random.uniform(shape=(8, L, N)).astype(dt)
     mx.eval(x, w, b, y)

-    def layer_norm_loop(g, x):
+    def layer_norm_grad_x_loop(g, x):
         gx = x
         for _ in range(32):
             gx = g(gx, y)
         return gx

-    time_fn(layer_norm_loop, g1, x)
-    time_fn(layer_norm_loop, g2, x)
-    time_fn(layer_norm_loop, mx.compile(g1), x)
-    time_fn(layer_norm_loop, mx.compile(g2), x)
+    time_fn(layer_norm_grad_x_loop, g1, x)
+    time_fn(layer_norm_grad_x_loop, g2, x)
+    time_fn(layer_norm_grad_x_loop, mx.compile(g1), x)
+    time_fn(layer_norm_grad_x_loop, mx.compile(g2), x)


 if __name__ == "__main__":
-    time_layer_norm()
+    for dt in [mx.float32, mx.float16, mx.bfloat16]:
+        for n in [1024, 2048, 4096, 8192, 8192 + 1024]:
+            print(dt, n)
+            time_layer_norm(n, dt)
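For context on commit 2e8cf0b450 ("Change layernorms to two pass algorithm"), which this benchmark exercises: a minimal sketch of the two-pass formulation, assuming the same semantics as the `layer_norm` reference above — one pass for the statistics, one for the normalization. This is an editorial illustration, not code from the diff:

def layer_norm_two_pass(x, w, b, eps=1e-5):
    # Pass 1: per-row mean and variance, accumulated in float32 for stability.
    x32 = x.astype(mx.float32)
    mean = x32.mean(axis=-1, keepdims=True)
    var = ((x32 - mean) ** 2).mean(axis=-1, keepdims=True)
    # Pass 2: normalize, then apply the affine parameters if present.
    y = ((x32 - mean) * mx.rsqrt(var + eps)).astype(x.dtype)
    if w is not None:
        y = y * w
    if b is not None:
        y = y + b
    return y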

View File

@@ -8,23 +8,26 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs.
 Simple Example
 --------------

+.. currentmodule:: mlx.core
+
 Let's write a custom kernel that computes ``exp`` elementwise:

 .. code-block:: python

-  def exp_elementwise(a: mx.array):
-      source = """
-          uint elem = thread_position_in_grid.x;
-          T tmp = inp[elem];
-          out[elem] = metal::exp(tmp);
-      """
+  source = """
+      uint elem = thread_position_in_grid.x;
+      T tmp = inp[elem];
+      out[elem] = metal::exp(tmp);
+  """

-      kernel = mx.fast.metal_kernel(
-          name="myexp",
-          input_names=["inp"],
-          output_names=["out"],
-          source=source,
-      )
+  kernel = mx.fast.metal_kernel(
+      name="myexp",
+      input_names=["inp"],
+      output_names=["out"],
+      source=source,
+  )
+
+  def exp_elementwise(a: mx.array):
       outputs = kernel(
           inputs=[a],
           template=[("T", mx.float32)],
@@ -39,8 +42,13 @@ Let's write a custom kernel that computes ``exp`` elementwise:
   b = exp_elementwise(a)
   assert mx.allclose(b, mx.exp(a))

+Every time you make a kernel, a new Metal library is created and possibly
+JIT compiled. To reduce the overhead from that, build the kernel once with
+:func:`fast.metal_kernel` and then use it many times.
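A minimal sketch of that pattern, reusing ``source`` and ``a`` from the example above:

.. code-block:: python

   # Build the kernel once, e.g. at module scope...
   exp_kernel = mx.fast.metal_kernel(
       name="myexp",
       input_names=["inp"],
       output_names=["out"],
       source=source,
   )

   # ...then call it many times; only the first call may pay the JIT cost.
   for _ in range(100):
       b = exp_kernel(
           inputs=[a],
           template=[("T", mx.float32)],
           grid=(a.size, 1, 1),
           threadgroup=(256, 1, 1),
           output_shapes=[a.shape],
           output_dtypes=[mx.float32],
       )[0]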
 .. note::
-   We are only required to pass the body of the Metal kernel in ``source``.
+   Only pass the body of the Metal kernel in ``source``. The function
+   signature is generated automatically.

 The full function signature will be generated using:
@@ -78,44 +86,51 @@ Putting this all together, the generated function signature for ``myexp`` is as
   template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;

-Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
-This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
-For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
+Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
+<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
+function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
+``threadgroup`` size threadgroups. For optimal performance, each thread group
+dimension should be less than or equal to the corresponding grid dimension.
-Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
+Passing ``verbose=True`` to :func:`fast.metal_kernel.__call__` will print the
+generated code for debugging purposes.
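To make the dispatch arithmetic concrete, a hedged sketch for the ``myexp`` kernel above (the input size is invented for illustration):

.. code-block:: python

   a = mx.random.normal(shape=(4096,))

   # grid=(a.size, 1, 1) launches mx.prod(grid) == 4096 threads, split into
   # ceil(4096 / 256) == 16 threadgroups of 256 threads each.
   outputs = kernel(
       inputs=[a],
       template=[("T", mx.float32)],
       grid=(a.size, 1, 1),
       threadgroup=(256, 1, 1),
       output_shapes=[a.shape],
       output_dtypes=[mx.float32],
       verbose=True,  # prints the generated Metal source for debugging
   )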
 Using Shape/Strides
 -------------------

-``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
-This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
-Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
-when indexing.
+:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
+is ``True`` by default. This will copy the array inputs if needed
+before the kernel is launched to ensure that the memory layout is row
+contiguous. Generally this makes writing the kernel easier, since we don't
+have to worry about gaps or the ordering of the dims when indexing.

-If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
-input array ``a`` if any are present in ``source``.
-We can then use MLX's built in indexing utils to fetch the right elements for each thread.
+If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
+``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
+present in ``source``. We can then use MLX's built in indexing utils to fetch
+the right elements for each thread.

-Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:
+Let's convert ``myexp`` above to support arbitrarily strided arrays without
+relying on a copy from ``ensure_row_contiguous``:
 .. code-block:: python

-  def exp_elementwise(a: mx.array):
-      source = """
-          uint elem = thread_position_in_grid.x;
-          // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
-          uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
-          T tmp = inp[loc];
-          // Output arrays are always row contiguous
-          out[elem] = metal::exp(tmp);
-      """
-
-      kernel = mx.fast.metal_kernel(
-          name="myexp_strided",
-          input_names=["inp"],
-          output_names=["out"],
-          source=source
-      )
+  source = """
+      uint elem = thread_position_in_grid.x;
+      // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
+      uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
+      T tmp = inp[loc];
+      // Output arrays are always row contiguous
+      out[elem] = metal::exp(tmp);
+  """
+
+  kernel = mx.fast.metal_kernel(
+      name="myexp_strided",
+      input_names=["inp"],
+      output_names=["out"],
+      source=source
+  )
+
+  def exp_elementwise(a: mx.array):
       outputs = kernel(
           inputs=[a],
           template=[("T", mx.float32)],
@@ -142,137 +157,139 @@ We'll start with the following MLX implementation using standard ops:
 .. code-block:: python

   def grid_sample_ref(x, grid):
       N, H_in, W_in, _ = x.shape
       ix = ((grid[..., 0] + 1) * W_in - 1) / 2
       iy = ((grid[..., 1] + 1) * H_in - 1) / 2

       ix_nw = mx.floor(ix).astype(mx.int32)
       iy_nw = mx.floor(iy).astype(mx.int32)

       ix_ne = ix_nw + 1
       iy_ne = iy_nw

       ix_sw = ix_nw
       iy_sw = iy_nw + 1

       ix_se = ix_nw + 1
       iy_se = iy_nw + 1

       nw = (ix_se - ix) * (iy_se - iy)
       ne = (ix - ix_sw) * (iy_sw - iy)
       sw = (ix_ne - ix) * (iy - iy_ne)
       se = (ix - ix_nw) * (iy - iy_nw)

       I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
       I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
       I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
       I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]

       mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
       mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
       mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
       mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)

       I_nw *= mask_nw[..., None]
       I_ne *= mask_ne[..., None]
       I_sw *= mask_sw[..., None]
       I_se *= mask_se[..., None]

       output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se

       return output
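A quick, hypothetical shape check for the reference (sizes invented): sampling a ``(N, H_in, W_in, C)`` input at a ``(N, gH, gW, 2)`` grid yields ``(N, gH, gW, C)``.

.. code-block:: python

   x = mx.random.uniform(shape=(1, 4, 4, 3))
   grid = mx.random.uniform(low=-1, high=1, shape=(1, 2, 2, 2))
   out = grid_sample_ref(x, grid)
   assert out.shape == (1, 2, 2, 3)  # (N, gH, gW, C)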
-Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
+Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
 to write a fast GPU kernel for both the forward and backward passes.

 First we'll implement the forward pass as a fused kernel:
 .. code-block:: python

   source = """
       uint elem = thread_position_in_grid.x;
       int H = x_shape[1];
       int W = x_shape[2];
       int C = x_shape[3];
       int gH = grid_shape[1];
       int gW = grid_shape[2];

       int w_stride = C;
       int h_stride = W * w_stride;
       int b_stride = H * h_stride;

       uint grid_idx = elem / C * 2;
       float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
       float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

       int ix_nw = floor(ix);
       int iy_nw = floor(iy);

       int ix_ne = ix_nw + 1;
       int iy_ne = iy_nw;

       int ix_sw = ix_nw;
       int iy_sw = iy_nw + 1;

       int ix_se = ix_nw + 1;
       int iy_se = iy_nw + 1;

       T nw = (ix_se - ix) * (iy_se - iy);
       T ne = (ix - ix_sw) * (iy_sw - iy);
       T sw = (ix_ne - ix) * (iy - iy_ne);
       T se = (ix - ix_nw) * (iy - iy_nw);

       int batch_idx = elem / C / gH / gW * b_stride;
       int channel_idx = elem % C;
       int base_idx = batch_idx + channel_idx;

       T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
       T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
       T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
       T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];

       I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
       I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
       I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
       I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;

       out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
   """

   kernel = mx.fast.metal_kernel(
       name="grid_sample",
       input_names=["x", "grid"],
       output_names=["out"],
       source=source,
   )

   @mx.custom_function
   def grid_sample(x, grid):

       assert x.ndim == 4, "`x` must be 4D."
       assert grid.ndim == 4, "`grid` must be 4D."

       B, _, _, C = x.shape
       _, gN, gM, D = grid.shape
       out_shape = (B, gN, gM, C)

       assert D == 2, "Last dim of `grid` must be size 2."

       outputs = kernel(
           inputs=[x, grid],
           template=[("T", x.dtype)],
           output_shapes=[out_shape],
           output_dtypes=[x.dtype],
           grid=(np.prod(out_shape), 1, 1),
           threadgroup=(256, 1, 1),
       )
       return outputs[0]
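A hedged sanity check (not in the docs; the tolerance is a guess) comparing the fused kernel against ``grid_sample_ref``:

.. code-block:: python

   x = mx.random.uniform(shape=(8, 1024, 1024, 64))
   grid = mx.random.uniform(low=-1, high=1, shape=(8, 256, 256, 2))
   mx.eval(x, grid)

   assert mx.allclose(grid_sample(x, grid), grid_sample_ref(x, grid), atol=1e-5)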
 For a reasonably sized input such as:

 .. code-block:: python

   x.shape = (8, 1024, 1024, 64)
   grid.shape = (8, 256, 256, 2)

 On an M1 Max, we see a big performance improvement:
@@ -281,11 +298,11 @@ On an M1 Max, we see a big performance improvement:
 Grid Sample VJP
 ---------------

-Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
-its custom vjp transform so MLX can differentiate it.
+Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
+define its custom vjp transform so MLX can differentiate it.

 The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
-requires a few extra ``mx.fast.metal_kernel`` features:
+requires a few extra :func:`fast.metal_kernel` features:

 * ``init_value=0``
   Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
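As a hedged illustration of the ``init_value`` feature just described (the toy kernel and sizes are invented for this sketch, not from the docs): a kernel that writes only the even slots of its output leaves the odd slots at the initial value instead of at uninitialized memory.

.. code-block:: python

   source = """
       uint elem = thread_position_in_grid.x;
       out[2 * elem] = inp[elem];  // only even slots are written
   """

   scatter_even = mx.fast.metal_kernel(
       name="scatter_even", input_names=["inp"], output_names=["out"], source=source
   )

   a = mx.arange(4, dtype=mx.float32)
   outputs = scatter_even(
       inputs=[a],
       grid=(a.size, 1, 1),
       threadgroup=(4, 1, 1),
       output_shapes=[(2 * a.size,)],
       output_dtypes=[mx.float32],
       init_value=0,  # odd slots come back as 0 instead of garbage
   )
   # outputs[0] == [0, 0, 1, 0, 2, 0, 3, 0]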
@@ -299,128 +316,129 @@ We can then implement the backwards pass as follows:
 .. code-block:: python

   source = """
       uint elem = thread_position_in_grid.x;
       int H = x_shape[1];
       int W = x_shape[2];
       int C = x_shape[3];
       // Pad C to the nearest larger simdgroup size multiple
       int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;

       int gH = grid_shape[1];
       int gW = grid_shape[2];

       int w_stride = C;
       int h_stride = W * w_stride;
       int b_stride = H * h_stride;

       uint grid_idx = elem / C_padded * 2;
       float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
       float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

       int ix_nw = floor(ix);
       int iy_nw = floor(iy);

       int ix_ne = ix_nw + 1;
       int iy_ne = iy_nw;

       int ix_sw = ix_nw;
       int iy_sw = iy_nw + 1;

       int ix_se = ix_nw + 1;
       int iy_se = iy_nw + 1;

       T nw = (ix_se - ix) * (iy_se - iy);
       T ne = (ix - ix_sw) * (iy_sw - iy);
       T sw = (ix_ne - ix) * (iy - iy_ne);
       T se = (ix - ix_nw) * (iy - iy_nw);

       int batch_idx = elem / C_padded / gH / gW * b_stride;
       int channel_idx = elem % C_padded;
       int base_idx = batch_idx + channel_idx;

       T gix = T(0);
       T giy = T(0);
       if (channel_idx < C) {
           int cot_index = elem / C_padded * C + channel_idx;
           T cot = cotangent[cot_index];
           if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
               int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
               atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);

               T I_nw = x[offset];
               gix -= I_nw * (iy_se - iy) * cot;
               giy -= I_nw * (ix_se - ix) * cot;
           }
           if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
               int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
               atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);

               T I_ne = x[offset];
               gix += I_ne * (iy_sw - iy) * cot;
               giy -= I_ne * (ix - ix_sw) * cot;
           }
           if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
               int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
               atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);

               T I_sw = x[offset];
               gix -= I_sw * (iy - iy_ne) * cot;
               giy += I_sw * (ix_ne - ix) * cot;
           }
           if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
               int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
               atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);

               T I_se = x[offset];
               gix += I_se * (iy - iy_nw) * cot;
               giy += I_se * (ix - ix_nw) * cot;
           }
       }

       T gix_mult = W / 2;
       T giy_mult = H / 2;

       // Reduce across each simdgroup first.
       // This is much faster than relying purely on atomics.
       gix = simd_sum(gix);
       giy = simd_sum(giy);

       if (thread_index_in_simdgroup == 0) {
           atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
           atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
       }
   """

   kernel = mx.fast.metal_kernel(
       name="grid_sample_grad",
       input_names=["x", "grid", "cotangent"],
       output_names=["x_grad", "grid_grad"],
       source=source,
       atomic_outputs=True,
   )

   @grid_sample.vjp
   def grid_sample_vjp(primals, cotangent, _):
       x, grid = primals
       B, _, _, C = x.shape
       _, gN, gM, D = grid.shape

       assert D == 2, "Last dim of `grid` must be size 2."

       # pad the output channels to simd group size
       # so that our `simd_sum`s don't overlap.
       simdgroup_size = 32
       C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
       grid_size = B * gN * gM * C_padded
       outputs = kernel(
           inputs=[x, grid, cotangent],
           template=[("T", x.dtype)],
           output_shapes=[x.shape, grid.shape],
           output_dtypes=[x.dtype, x.dtype],
           grid=(grid_size, 1, 1),
           threadgroup=(256, 1, 1),
           init_value=0,
       )
       return outputs[0], outputs[1]
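With the vjp registered, gradients flow through ``grid_sample`` like any built-in op; a hedged usage sketch (reusing ``x`` and ``grid`` from above):

.. code-block:: python

   loss = lambda x, grid: (grid_sample(x, grid) ** 2).sum()
   gx, ggrid = mx.grad(loss, argnums=(0, 1))(x, grid)
   mx.eval(gx, ggrid)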
There's an even larger speed up for the vjp:

View File

@@ -397,11 +397,11 @@ below.
std::ostringstream kname;
kname << "axpby_" << "general_" << type_to_name(out);
-  // Make sure the metal library is available
-  d.register_library("mlx_ext");
+  // Load the metal library
+  auto lib = d.get_library("mlx_ext");

   // Make a kernel from this metal library
-  auto kernel = d.get_kernel(kname.str(), "mlx_ext");
+  auto kernel = d.get_kernel(kname.str(), lib);
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);

View File

@@ -172,11 +172,11 @@ void Axpby::eval_gpu(
kname << (contiguous_kernel ? "contiguous_" : "general_");
kname << type_to_name(out);
-  // Make sure the metal library is available
-  d.register_library("mlx_ext");
+  // Load the metal library
+  auto lib = d.get_library("mlx_ext");

   // Make a kernel from this metal library
-  auto kernel = d.get_kernel(kname.str(), "mlx_ext");
+  auto kernel = d.get_kernel(kname.str(), lib);
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);

View File

@@ -2,7 +2,7 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
@@ -26,7 +26,7 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
if (ctype == CopyType::Vector) {
// If the input is donateable, we are doing a vector copy and the types
// have the same size, then the input buffer can hold the output.
-    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
+    if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
return true;
} else {

View File

@@ -0,0 +1,78 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/common/utils.h"
#include "mlx/utils.h"
#include <sstream>
namespace mlx::core {
inline std::tuple<Shape, Strides, Strides> collapse_batches(
const array& a,
const array& b) {
// Get and check the shape for the batched dims
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
if (A_bshape != B_bshape) {
std::ostringstream msg;
msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A "
<< a.shape() << ", B " << b.shape() << ".";
throw std::runtime_error(msg.str());
}
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
auto [batch_shape, batch_strides] =
collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
auto a_batch_strides = batch_strides[0];
auto b_batch_strides = batch_strides[1];
if (batch_shape.empty()) {
batch_shape.push_back(1);
a_batch_strides.push_back(0);
b_batch_strides.push_back(0);
}
return std::make_tuple(batch_shape, a_batch_strides, b_batch_strides);
}
inline std::tuple<Shape, Strides, Strides, Strides>
collapse_batches(const array& a, const array& b, const array& c) {
// Get and check the shape for the batched dims
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
Shape C_bshape{c.shape().begin(), c.shape().end() - 2};
if (A_bshape != B_bshape || A_bshape != C_bshape) {
std::ostringstream msg;
msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A "
<< a.shape() << ", B " << b.shape() << ", B " << c.shape() << ".";
throw std::runtime_error(msg.str());
}
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
Strides C_bstride{c.strides().begin(), c.strides().end() - 2};
auto [batch_shape, batch_strides] = collapse_contiguous_dims(
A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
auto A_batch_stride = batch_strides[0];
auto B_batch_stride = batch_strides[1];
auto C_batch_stride = batch_strides[2];
if (batch_shape.empty()) {
batch_shape.push_back(1);
A_batch_stride.push_back(0);
B_batch_stride.push_back(0);
C_batch_stride.push_back(0);
}
return std::make_tuple(
batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
}
} // namespace mlx::core

View File

@@ -0,0 +1,26 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/allocator.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
inline void set_unary_output_data(const array& in, array& out) {
if (in.flags().contiguous) {
if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
}
} // namespace mlx::core

View File

@@ -2,32 +2,13 @@
 #pragma once

 #include "mlx/allocator.h"
 #include "mlx/array.h"
-#include "mlx/backend/common/utils.h"
+#include "mlx/backend/common/unary.h"
 #include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/simd/simd.h"
 #include "mlx/utils.h"

 namespace mlx::core {

-void set_unary_output_data(const array& in, array& out) {
-  if (in.flags().contiguous) {
-    if (is_donatable(in, out)) {
-      out.copy_shared_buffer(in);
-    } else {
-      auto size = in.data_size();
-      out.set_data(
-          allocator::malloc(size * out.itemsize()),
-          size,
-          in.strides(),
-          in.flags());
-    }
-  } else {
-    out.set_data(allocator::malloc(out.nbytes()));
-  }
-}
-
 template <typename T, typename U = T, typename Op>
 void unary_op(const T* a, U* out, size_t shape, size_t stride) {
   for (size_t i = 0; i < shape; i += 1) {

View File

@@ -6,21 +6,41 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
# Enable defining device lambda functions.
target_compile_options(mlx
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
# CUDA 12.8 emits warning #20280-D for copy kernels, which is a false positive.
# Explicitly pass this flag to suppress the warning; setting it to true would
# also be safe, but then the warning would not be suppressed.
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
target_compile_options(
mlx
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--static-global-template-stub=false>")
endif()
# Compute capability 7 is required for synchronization between CPU/GPU with
# managed memory. TODO: Add more architectures for potential performance gain.
set(MLX_CUDA_ARCHITECTURES
@@ -51,6 +71,9 @@ target_link_libraries(mlx PUBLIC $<BUILD_INTERFACE:nvtx3-cpp>)
find_package(CUDAToolkit REQUIRED)
target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
# Use cublasLt.
target_link_libraries(mlx PRIVATE CUDA::cublasLt)
# Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>)

mlx/backend/cuda/binary.cu (new file, 305 lines)

View File

@@ -0,0 +1,305 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/binary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/binary_ops.cuh"
#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = Op{}(a[0], b[0]);
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = Op{}(a[0], b[index]);
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = Op{}(a[index], b[0]);
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = Op{}(a[index], b[index]);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
__global__ void binary_g_nd(
const In* a,
const In* b,
Out* out,
IdxT size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index, shape.data(), a_strides.data(), b_strides.data());
out[index] = Op{}(a[a_idx], b[b_idx]);
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_g(
const In* a,
const In* b,
Out* out,
IdxT size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides a_strides,
const __grid_constant__ Strides b_strides,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_4d(
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
out[index] = Op{}(a[a_idx], b[b_idx]);
}
}
template <typename Op, typename In, typename Out>
constexpr bool supports_binary_op() {
if (std::is_same_v<Op, Add> || std::is_same_v<Op, Divide> ||
std::is_same_v<Op, Maximum> || std::is_same_v<Op, Minimum> ||
std::is_same_v<Op, Multiply> || std::is_same_v<Op, Subtract> ||
std::is_same_v<Op, Power> || std::is_same_v<Op, Remainder>) {
return std::is_same_v<In, Out>;
}
if (std::is_same_v<Op, Equal> || std::is_same_v<Op, Greater> ||
std::is_same_v<Op, GreaterEqual> || std::is_same_v<Op, Less> ||
std::is_same_v<Op, LessEqual> || std::is_same_v<Op, NotEqual>) {
return std::is_same_v<Out, bool>;
}
if (std::is_same_v<Op, LogicalAnd> || std::is_same_v<Op, LogicalOr>) {
return std::is_same_v<Out, bool> && std::is_same_v<In, bool>;
}
if (std::is_same_v<Op, NaNEqual>) {
return std::is_same_v<Out, bool> &&
(is_floating_v<In> || std::is_same_v<In, complex64_t>);
}
if (std::is_same_v<Op, LogAddExp> || std::is_same_v<Op, ArcTan2>) {
return std::is_same_v<In, Out> && is_floating_v<In>;
}
if (std::is_same_v<Op, BitwiseAnd> || std::is_same_v<Op, BitwiseOr> ||
std::is_same_v<Op, BitwiseXor>) {
return std::is_same_v<In, Out> && std::is_integral_v<In>;
}
if (std::is_same_v<Op, LeftShift> || std::is_same_v<Op, RightShift>) {
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
!std::is_same_v<In, bool>;
}
return false;
}
} // namespace cu
template <typename Op>
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
const Stream& s) {
assert(inputs.size() > 1);
const auto& a = inputs[0];
const auto& b = inputs[1];
auto& out = outputs[0];
if (out.size() == 0) {
return;
}
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;
auto bopt = get_binary_op_type(a, b);
if (bopt == BinaryOpType::General) {
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& a_strides = strides[0];
auto& b_strides = strides[1];
bool large = a.data_size() > UINT32_MAX ||
b.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel =
&cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.data_size(),
const_param<NDIM>(shape),
const_param<NDIM>(a_strides),
const_param<NDIM>(b_strides));
});
} else {
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.data_size(),
const_param(shape),
const_param(a_strides),
const_param(b_strides),
ndim);
}
});
} else {
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
} else if (bopt == BinaryOpType::VectorScalar) {
kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
} else if (bopt == BinaryOpType::VectorVector) {
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
}
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, LARGE);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.data_size());
});
}
} else {
throw std::runtime_error(fmt::format(
"Can not do binary op {} on inputs of {} with result of {}.",
op,
dtype_to_string(a.dtype()),
dtype_to_string(out.dtype())));
}
});
});
});
}
template <typename Op>
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt);
set_binary_op_output_data(a, b, outputs[1], bopt);
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
}
template <typename Op>
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
std::string_view op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
std::vector<array> outputs{out};
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
}
#define BINARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = out.primitive().stream(); \
binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
}
#define BINARY_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = outputs[0].primitive().stream(); \
binary_op_gpu<cu::func>(inputs, outputs, get_primitive_string(this), s); \
}
BINARY_GPU(Add)
BINARY_GPU(ArcTan2)
BINARY_GPU(Divide)
BINARY_GPU(Remainder)
BINARY_GPU(Equal)
BINARY_GPU(Greater)
BINARY_GPU(GreaterEqual)
BINARY_GPU(Less)
BINARY_GPU(LessEqual)
BINARY_GPU(LogicalAnd)
BINARY_GPU(LogicalOr)
BINARY_GPU(LogAddExp)
BINARY_GPU(Maximum)
BINARY_GPU(Minimum)
BINARY_GPU(Multiply)
BINARY_GPU(NotEqual)
BINARY_GPU(Power)
BINARY_GPU(Subtract)
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu<cu::BitwiseAnd>(inputs, out, op, s);
break;
case BitwiseBinary::Or:
binary_op_gpu<cu::BitwiseOr>(inputs, out, op, s);
break;
case BitwiseBinary::Xor:
binary_op_gpu<cu::BitwiseXor>(inputs, out, op, s);
break;
case BitwiseBinary::LeftShift:
binary_op_gpu<cu::LeftShift>(inputs, out, op, s);
break;
case BitwiseBinary::RightShift:
binary_op_gpu<cu::RightShift>(inputs, out, op, s);
break;
}
}
} // namespace mlx::core

View File

@@ -1,26 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/gpu/copy.h"
namespace mlx::core {
void copy_gpu_inplace(
const array& in,
array& out,
const Shape& data_shape,
const Strides& strides_in_pre,
const Strides& strides_out_pre,
int64_t inp_offset,
int64_t out_offset,
CopyType ctype,
const Stream& s,
const std::optional<array>& dynamic_i_offset /* = std::nullopt */,
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
throw std::runtime_error("copy_gpu_inplace not implemented in CUDA backend.");
}
void fill_gpu(const array& val, array& out, const Stream& s) {
throw std::runtime_error("fill_gpu not implemented in CUDA backend.");
}
} // namespace mlx::core

mlx/backend/cuda/copy.cu (new file, 89 lines)

View File

@@ -0,0 +1,89 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/copy/copy.cuh"
namespace mlx::core {
void copy_gpu_inplace(
const array& in_,
array& out,
const Shape& shape,
const Strides& strides_in,
const Strides& strides_out,
int64_t offset_in,
int64_t offset_out,
CopyType ctype,
const Stream& s,
const std::optional<array>& dynamic_offset_in,
const std::optional<array>& dynamic_offset_out) {
if (out.size() == 0) {
return;
}
const array& in = in_.data_shared_ptr() ? in_ : out;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
return;
}
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
auto [shape_collapsed, strides_vec] = collapse_contiguous_dims(
shape, std::vector{strides_in, strides_out}, INT32_MAX);
if (ctype == CopyType::General) {
copy_general_input(
encoder,
ctype,
in,
out,
offset_in,
offset_out,
shape_collapsed,
strides_vec[0]);
} else {
if (dynamic_offset_in || dynamic_offset_out) {
copy_general_dynamic(
encoder,
ctype,
in,
out,
offset_in,
offset_out,
shape_collapsed,
strides_vec[0],
strides_vec[1],
dynamic_offset_in ? *dynamic_offset_in : array(0, int64),
dynamic_offset_out ? *dynamic_offset_out : array(0, int64));
} else {
copy_general(
encoder,
ctype,
in,
out,
offset_in,
offset_out,
shape_collapsed,
strides_vec[0],
strides_vec[1]);
}
}
return;
}
}
void fill_gpu(const array& in, array& out, const Stream& s) {
if (out.size() == 0) {
return;
}
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
}
} // namespace mlx::core

View File

@@ -0,0 +1,71 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/cast_op.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
namespace mlx::core {
#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
using InType = cuda_type_t<CTYPE_IN>; \
using OutType = cuda_type_t<CTYPE_OUT>; \
if constexpr (cu::CastOp<InType, OutType>::is_castable) { \
__VA_ARGS__; \
} else { \
throw std::runtime_error(fmt::format( \
"Can not copy data from dtype {} to {}.", \
dtype_to_string(in.dtype()), \
dtype_to_string(out.dtype()))); \
} \
}); \
})
void copy_contiguous(
cu::CommandEncoder& encoder,
CopyType ctype,
const array& in,
array& out,
int64_t offset_in,
int64_t offset_out);
void copy_general(
cu::CommandEncoder& encoder,
CopyType ctype,
const array& in,
array& out,
int64_t offset_in,
int64_t offset_out,
const Shape& shape,
const Strides& strides_in,
const Strides& strides_out);
void copy_general_dynamic(
cu::CommandEncoder& encoder,
CopyType ctype,
const array& in,
array& out,
int64_t offset_in,
int64_t offset_out,
const Shape& shape,
const Strides& strides_in,
const Strides& strides_out,
const array& dynamic_offset_in,
const array& dynamic_offset_out);
void copy_general_input(
cu::CommandEncoder& encoder,
CopyType ctype,
const array& in,
array& out,
int64_t offset_in,
int64_t offset_out,
const Shape& shape,
const Strides& strides_in);
} // namespace mlx::core

View File

@@ -0,0 +1,56 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/copy/copy.cuh"
#include <cooperative_groups.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename In, typename Out, typename IdxT>
__global__ void copy_s(const In* in, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = CastOp<In, Out>{}(in[0]);
}
}
template <typename In, typename Out, typename IdxT>
__global__ void copy_v(const In* in, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = CastOp<In, Out>{}(in[index]);
}
}
} // namespace cu
void copy_contiguous(
cu::CommandEncoder& encoder,
CopyType ctype,
const array& in,
array& out,
int64_t in_offset,
int64_t out_offset) {
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
auto kernel = cu::copy_s<InType, OutType, IdxT>;
if (ctype == CopyType::Vector) {
kernel = cu::copy_v<InType, OutType, IdxT>;
}
auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in.data<InType>() + in_offset,
out.data<OutType>() + out_offset,
out.data_size());
});
});
});
}
} // namespace mlx::core

View File

@@ -0,0 +1,95 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/copy/copy.cuh"
#include <cooperative_groups.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename In, typename Out, typename IdxT, int NDIM>
__global__ void copy_gg_nd(
const In* in,
Out* out,
IdxT size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
index, shape.data(), strides_in.data(), strides_out.data());
out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
}
}
template <typename In, typename Out, typename IdxT>
__global__ void copy_gg(
const In* in,
Out* out,
IdxT size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides strides_in,
const __grid_constant__ Strides strides_out,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [idx_in, idx_out] = elem_to_loc_4d(
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
}
}
} // namespace cu
void copy_general(
cu::CommandEncoder& encoder,
CopyType ctype,
const array& in,
array& out,
int64_t offset_in,
int64_t offset_out,
const Shape& shape,
const Strides& strides_in,
const Strides& strides_out) {
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
const_param<NDIM>(shape),
const_param<NDIM>(strides_in),
const_param<NDIM>(strides_out));
});
} else { // ndim >= 4
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
const_param(shape),
const_param(strides_in),
const_param(strides_out),
ndim);
}
});
});
});
}
} // namespace mlx::core

View File

@@ -0,0 +1,105 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/copy/copy.cuh"
#include <cooperative_groups.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename In, typename Out, typename IdxT, int NDIM>
__global__ void copy_gg_dynamic_nd(
const In* in,
Out* out,
IdxT size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out,
const int64_t* offset_in,
const int64_t* offset_out) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
index, shape.data(), strides_in.data(), strides_out.data());
out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
}
}
template <typename In, typename Out, typename IdxT>
__global__ void copy_gg_dynamic(
const In* in,
Out* out,
IdxT size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides strides_in,
const __grid_constant__ Strides strides_out,
int ndim,
const int64_t* offset_in,
const int64_t* offset_out) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [idx_in, idx_out] = elem_to_loc_4d(
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
}
}
} // namespace cu
void copy_general_dynamic(
cu::CommandEncoder& encoder,
CopyType ctype,
const array& in,
array& out,
int64_t offset_in,
int64_t offset_out,
const Shape& shape,
const Strides& strides_in,
const Strides& strides_out,
const array& dynamic_offset_in,
const array& dynamic_offset_out) {
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel = cu::copy_gg_dynamic_nd<InType, OutType, IdxT, NDIM>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
const_param<NDIM>(shape),
const_param<NDIM>(strides_in),
const_param<NDIM>(strides_out),
dynamic_offset_in.data<int64_t>(),
dynamic_offset_out.data<int64_t>());
});
} else { // ndim >= 4
auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
const_param(shape),
const_param(strides_in),
const_param(strides_out),
ndim,
dynamic_offset_in.data<int64_t>(),
dynamic_offset_out.data<int64_t>());
}
});
});
});
}
} // namespace mlx::core

View File

@@ -0,0 +1,88 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/copy/copy.cuh"
#include <cooperative_groups.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename In, typename Out, typename IdxT, int NDIM>
__global__ void copy_g_nd(
const In* in,
Out* out,
IdxT size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
IdxT idx_in = elem_to_loc_nd<NDIM>(index, shape.data(), strides_in.data());
out[index] = CastOp<In, Out>{}(in[idx_in]);
}
}
template <typename In, typename Out, typename IdxT>
__global__ void copy_g(
const In* in,
Out* out,
IdxT size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides strides_in,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
IdxT idx_in = elem_to_loc_4d(index, shape.data(), strides_in.data(), ndim);
out[index] = CastOp<In, Out>{}(in[idx_in]);
}
}
} // namespace cu
void copy_general_input(
cu::CommandEncoder& encoder,
CopyType ctype,
const array& in,
array& out,
int64_t offset_in,
int64_t offset_out,
const Shape& shape,
const Strides& strides_in) {
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel = cu::copy_g_nd<InType, OutType, IdxT, NDIM>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
const_param<NDIM>(shape),
const_param<NDIM>(strides_in));
});
} else { // ndim >= 4
auto kernel = cu::copy_g<InType, OutType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
const_param(shape),
const_param(strides_in),
ndim);
}
});
});
});
}
} // namespace mlx::core

View File

@@ -34,14 +34,26 @@ CommandEncoder& DeviceStream::get_encoder() {
}
 Device::Device(int device) : device_(device) {
+  CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
+      &compute_capability_major_, cudaDevAttrComputeCapabilityMajor, device_));
+  CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
+      &compute_capability_minor_, cudaDevAttrComputeCapabilityMinor, device_));
   // Validate the requirements of device.
   int attr = 0;
-  cudaDeviceGetAttribute(&attr, cudaDevAttrConcurrentManagedAccess, device_);
+  CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
+      &attr, cudaDevAttrConcurrentManagedAccess, device_));
   if (attr != 1) {
     throw std::runtime_error(fmt::format(
         "Device {} does not support synchronization in managed memory.",
         device_));
   }
+  // The cublasLt handle is used by matmul.
+  make_current();
+  cublasLtCreate(&lt_);
+}
+
+Device::~Device() {
+  cublasLtDestroy(lt_);
 }
void Device::make_current() {

View File

@@ -6,6 +6,7 @@
#include "mlx/backend/cuda/worker.h"
#include "mlx/stream.h"
#include <cublasLt.h>
#include <thrust/execution_policy.h>
#include <unordered_map>
@@ -46,6 +47,7 @@ class DeviceStream {
class Device {
public:
explicit Device(int device);
~Device();
Device(const Device&) = delete;
Device& operator=(const Device&) = delete;
@@ -58,9 +60,21 @@ class Device {
int cuda_device() const {
return device_;
}
int compute_capability_major() const {
return compute_capability_major_;
}
int compute_capability_minor() const {
return compute_capability_minor_;
}
cublasLtHandle_t lt_handle() const {
return lt_;
}
private:
int device_;
int compute_capability_major_;
int compute_capability_minor_;
cublasLtHandle_t lt_;
std::unordered_map<int, DeviceStream> streams_;
};

View File

@@ -0,0 +1,121 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <thrust/iterator/iterator_adaptor.h>
#include <cuda/std/utility>
#include "mlx/backend/cuda/kernel_utils.cuh"
namespace mlx::core::cu {
// Iterating non-contiguous array.
template <typename Iterator, typename IdxT = int64_t>
class general_iterator
: public thrust::
iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator> {
public:
using super_t =
thrust::iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator>;
using reference = typename super_t::reference;
using difference_type = typename super_t::difference_type;
__host__ __device__ general_iterator(
Iterator it,
IdxT index,
int ndim,
Shape shape,
Strides strides)
: super_t(it),
index_(index),
ndim_(ndim),
shape_(cuda::std::move(shape)),
strides_(cuda::std::move(strides)) {}
__host__ __device__ IdxT index() const {
return index_;
}
__host__ __device__ const Shape& shape() const {
return shape_;
}
__host__ __device__ const Strides& strides() const {
return strides_;
}
private:
friend class thrust::iterator_core_access;
__host__ __device__ bool equal(const general_iterator& other) const {
return this->base() == other.base() && this->index() == other.index();
}
__host__ __device__ void advance(difference_type n) {
this->index_ += n;
}
__host__ __device__ void increment() {
this->index_ += 1;
}
__host__ __device__ void decrement() {
this->index_ -= 1;
}
__host__ __device__ difference_type
distance_to(const general_iterator& other) const {
_CCCL_ASSERT(
this->base() == other.base(),
"Underlying iterator must point to same base iterator");
return other.index() - this->index();
}
// The dereference is device-only to avoid accidentally running on the host.
__device__ typename super_t::reference dereference() const {
IdxT offset = elem_to_loc(index_, shape_.data(), strides_.data(), ndim_);
return *(this->base() + offset);
}
IdxT index_;
int ndim_;
Shape shape_;
Strides strides_;
};
template <typename IdxT, typename Iterator>
__host__ __device__ auto make_general_iterator(
Iterator it,
IdxT index,
int ndim,
Shape shape,
Strides strides) {
return general_iterator<Iterator, IdxT>(
it, index, ndim, cuda::std::move(shape), cuda::std::move(strides));
}
template <typename IdxT, typename Iterator>
auto make_general_iterator(
Iterator it,
const std::vector<int32_t>& shape,
const std::vector<int64_t>& strides) {
return make_general_iterator<IdxT>(
it, 0, shape.size(), const_param(shape), const_param(strides));
}
template <typename IdxT, typename Iterator>
auto make_general_iterators(
Iterator it,
IdxT size,
const std::vector<int32_t>& shape,
const std::vector<int64_t>& strides) {
auto ndim = shape.size();
auto shape_arg = const_param(shape);
auto strides_arg = const_param(strides);
return std::make_pair(
make_general_iterator<IdxT>(it, 0, ndim, shape_arg, strides_arg),
make_general_iterator<IdxT>(it, size, ndim, shape_arg, strides_arg));
}
} // namespace mlx::core::cu
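For context, these iterators let Thrust algorithms consume strided MLX arrays directly. A minimal host-side sketch, assuming an illustrative 2x3 view with strides {6, 2} and the cu::thrust_policy helper from this PR's device header:

#include <thrust/device_ptr.h>
#include <thrust/transform.h>
#include <vector>

// Square every element of a strided float view into a contiguous output.
void square_strided_view(const float* in, float* out, cudaStream_t stream) {
  std::vector<int32_t> shape{2, 3};
  std::vector<int64_t> strides{6, 2}; // every other column of a 2x6 buffer
  auto [begin, end] = mlx::core::cu::make_general_iterators<int64_t>(
      thrust::device_pointer_cast(in), int64_t(6), shape, strides);
  thrust::transform(
      mlx::core::cu::thrust_policy(stream),
      begin,
      end,
      thrust::device_pointer_cast(out),
      [] __device__(float x) { return x * x; });
}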

View File

@@ -7,13 +7,46 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/kernels/utils.cuh"
#include <cuComplex.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <fmt/format.h>
#include <cuda/cmath>
namespace mlx::core {
// Convert a runtime number between 1 and 3 to a constexpr value.
#define MLX_SWITCH_1_2_3(N, NDIM, ...) \
switch (N) { \
case 1: { \
constexpr int NDIM = 1; \
__VA_ARGS__; \
break; \
} \
case 2: { \
constexpr int NDIM = 2; \
__VA_ARGS__; \
break; \
} \
case 3: { \
constexpr int NDIM = 3; \
__VA_ARGS__; \
break; \
} \
}
// Like MLX_SWITCH_ALL_TYPES but for booleans.
#define MLX_SWITCH_BOOL(BOOL, BOOL_ALIAS, ...) \
if (BOOL) { \
constexpr bool BOOL_ALIAS = true; \
__VA_ARGS__; \
} else { \
constexpr bool BOOL_ALIAS = false; \
__VA_ARGS__; \
}
// Maps CPU types to CUDA types.
template <typename T>
struct CTypeToCudaType {
@@ -38,6 +71,24 @@ struct CTypeToCudaType<complex64_t> {
template <typename T>
using cuda_type_t = typename CTypeToCudaType<T>::type;
// Type trait for detecting floating-point types.
template <typename T>
inline constexpr bool is_floating_v =
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
// Utility to copy data from a vector to a fixed-size array on the host.
template <int NDIM = MAX_NDIM, typename T = int32_t>
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
if (vec.size() > NDIM) {
throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", NDIM));
}
cuda::std::array<T, NDIM> result;
std::copy_n(vec.begin(), vec.size(), result.begin());
return result;
}
// Compute the grid and block dimensions, check backend/common/utils.h for docs.
dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides);
@@ -46,4 +97,35 @@ dim3 get_2d_grid_dims(
const Strides& strides,
size_t divisor);
// Return a block size that achieves maximum potential occupancy for kernel.
template <typename T>
inline uint max_occupancy_block_dim(T kernel) {
int _, block_dim;
CHECK_CUDA_ERROR(cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
return block_dim;
}
// Get the num_blocks and block_dims that maximize occupancy for |kernel|,
// assuming each thread handles |work_per_thread| elements of |arr|.
template <typename T>
inline std::tuple<dim3, uint> get_launch_args(
T kernel,
const array& arr,
bool large,
int work_per_thread = 1) {
size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread);
uint block_dim = max_occupancy_block_dim(kernel);
if (block_dim > nthreads) {
block_dim = nthreads;
}
dim3 num_blocks;
if (large) {
num_blocks = get_2d_grid_dims(arr.shape(), arr.strides(), work_per_thread);
num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
} else {
num_blocks.x = cuda::ceil_div(nthreads, block_dim);
}
return std::make_tuple(num_blocks, block_dim);
}
} // namespace mlx::core
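A usage sketch for the occupancy helpers above, assuming a hypothetical fill kernel; the large flag is what switches between the flat 1D grid and the 2D grid from get_2d_grid_dims:

// Hypothetical kernel for illustration (uint32_t size assumes the small path).
__global__ void fill_ones(float* out, uint32_t size) {
  uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < size) {
    out[i] = 1.0f;
  }
}

void launch_fill_ones(mlx::core::array& out, cudaStream_t stream) {
  bool large = out.data_size() > UINT32_MAX;
  auto [num_blocks, block_dims] =
      mlx::core::get_launch_args(fill_ones, out, large);
  fill_ones<<<num_blocks, block_dims, 0, stream>>>(
      out.data<float>(), static_cast<uint32_t>(out.size()));
}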

View File

@@ -0,0 +1,278 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
#include <cuComplex.h>
#include <cuda/std/array>
namespace mlx::core::cu {
struct Add {
template <typename T>
__device__ T operator()(T x, T y) {
return x + y;
}
};
struct FloorDivide {
template <typename T>
__device__ T operator()(T x, T y) {
if constexpr (cuda::std::is_integral_v<T>) {
return x / y;
} else {
return trunc(x / y);
}
}
};
struct Divide {
template <typename T>
__device__ T operator()(T x, T y) {
return x / y;
}
};
struct Remainder {
template <typename T>
__device__ T operator()(T x, T y) {
if constexpr (cuda::std::is_integral_v<T>) {
if constexpr (cuda::std::is_signed_v<T>) {
auto r = x % y;
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
} else {
return x % y;
}
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return x % y;
} else {
T r = fmod(x, y);
if (r != 0 && (r < 0 != y < 0)) {
r = r + y;
}
return r;
}
}
};
struct Equal {
template <typename T>
__device__ bool operator()(T x, T y) {
return x == y;
}
};
struct NaNEqual {
template <typename T>
__device__ bool operator()(T x, T y) {
if constexpr (std::is_same_v<T, cuComplex>) {
return x == y ||
(isnan(cuCrealf(x)) && isnan(cuCrealf(y)) && isnan(cuCimagf(x)) &&
isnan(cuCimagf(y))) ||
(cuCrealf(x) == cuCrealf(y) && isnan(cuCimagf(x)) &&
isnan(cuCimagf(y))) ||
(isnan(cuCrealf(x)) && isnan(cuCrealf(y)) &&
cuCimagf(x) == cuCimagf(y));
} else {
return x == y || (isnan(x) && isnan(y));
}
}
};
struct Greater {
template <typename T>
__device__ bool operator()(T x, T y) {
return x > y;
}
};
struct GreaterEqual {
template <typename T>
__device__ bool operator()(T x, T y) {
return x >= y;
}
};
struct Less {
template <typename T>
__device__ bool operator()(T x, T y) {
return x < y;
}
};
struct LessEqual {
template <typename T>
__device__ bool operator()(T x, T y) {
return x <= y;
}
};
struct LogAddExp {
template <typename T>
__device__ T operator()(T x, T y) {
if (isnan(x) || isnan(y)) {
return cuda::std::numeric_limits<T>::quiet_NaN();
}
T maxval = max(x, y);
T minval = min(x, y);
return (minval == -cuda::std::numeric_limits<T>::infinity() ||
maxval == cuda::std::numeric_limits<T>::infinity())
? maxval
: T(float(maxval) + log1p(expf(minval - maxval)));
};
};
struct Maximum {
template <typename T>
__device__ T operator()(T x, T y) {
if constexpr (cuda::std::is_integral_v<T>) {
return max(x, y);
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
return x;
}
return x > y ? x : y;
} else {
if (isnan(x)) {
return x;
}
return x > y ? x : y;
}
}
};
struct Minimum {
template <typename T>
__device__ T operator()(T x, T y) {
if constexpr (cuda::std::is_integral_v<T>) {
return min(x, y);
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
return x;
}
return x < y ? x : y;
} else {
if (isnan(x)) {
return x;
}
return x < y ? x : y;
}
}
};
struct Multiply {
template <typename T>
__device__ T operator()(T x, T y) {
return x * y;
}
};
struct NotEqual {
template <typename T>
__device__ bool operator()(T x, T y) {
if constexpr (std::is_same_v<T, cuComplex>) {
return cuCrealf(x) != cuCrealf(y) || cuCimagf(x) != cuCimagf(y);
} else {
return x != y;
}
}
};
struct Power {
template <typename T>
__device__ T operator()(T base, T exp) {
if constexpr (cuda::std::is_integral_v<T>) {
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
auto x_theta = atan2f(base.y, base.x);
auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y);
auto mag = expf(exp.x * x_ln_r - exp.y * x_theta);
auto phase = exp.y * x_ln_r + exp.x * x_theta;
return make_cuFloatComplex(mag * cosf(phase), mag * sinf(phase));
} else {
return powf(base, exp);
}
}
};
struct Subtract {
template <typename T>
__device__ T operator()(T x, T y) {
return x - y;
}
};
struct LogicalAnd {
template <typename T>
__device__ T operator()(T x, T y) {
return x && y;
};
};
struct LogicalOr {
template <typename T>
__device__ T operator()(T x, T y) {
return x || y;
};
};
struct BitwiseAnd {
template <typename T>
__device__ T operator()(T x, T y) {
return x & y;
};
};
struct BitwiseOr {
template <typename T>
__device__ T operator()(T x, T y) {
return x | y;
};
};
struct BitwiseXor {
template <typename T>
__device__ T operator()(T x, T y) {
return x ^ y;
};
};
struct LeftShift {
template <typename T>
__device__ T operator()(T x, T y) {
return x << y;
};
};
struct RightShift {
template <typename T>
__device__ T operator()(T x, T y) {
return x >> y;
};
};
struct ArcTan2 {
template <typename T>
__device__ T operator()(T y, T x) {
return atan2f(y, x);
}
};
struct DivMod {
template <typename T>
__device__ cuda::std::array<T, 2> operator()(T x, T y) {
return {FloorDivide{}(x, y), Remainder{}(x, y)};
};
};
} // namespace mlx::core::cu
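A quick note on the LogAddExp functor above: it uses the standard stable rewrite

  log(e^x + e^y) = max(x, y) + log1p(e^(min(x, y) - max(x, y))),

where the exponent min - max <= 0 can never overflow. The explicit NaN and infinity checks cover the cases the subtraction cannot: NaN inputs propagate, and inf - inf would otherwise produce NaN instead of the correct maxval.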

View File

@@ -0,0 +1,59 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cuComplex.h>
#include <thrust/iterator/transform_iterator.h>
namespace mlx::core::cu {
// An op that does static_cast, with custom conversions for some types.
template <typename SrcT, typename DstT, typename = void>
struct CastOp {
static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, DstT>;
__device__ DstT operator()(SrcT x) {
return static_cast<DstT>(x);
}
};
// Converting a complex number to a real number discards the imaginary part.
template <typename DstT>
struct CastOp<
cuComplex,
DstT,
cuda::std::enable_if_t<!cuda::std::is_same_v<cuComplex, DstT>>> {
static constexpr bool is_castable = cuda::std::is_convertible_v<float, DstT>;
__device__ DstT operator()(cuComplex x) {
static_assert(!cuda::std::is_same_v<cuComplex, DstT>);
return static_cast<DstT>(cuCrealf(x));
}
};
// Allow converting a real number to a complex number.
template <typename SrcT>
struct CastOp<
SrcT,
cuComplex,
cuda::std::enable_if_t<!cuda::std::is_same_v<SrcT, cuComplex>>> {
static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, float>;
__device__ cuComplex operator()(SrcT x) {
static_assert(!cuda::std::is_same_v<SrcT, cuComplex>);
return cuComplex{static_cast<float>(x), 0};
}
};
// Return an iterator that casts the value to DstT using CastOp.
template <typename DstT, typename Iterator>
__host__ __device__ auto make_cast_iterator(Iterator it) {
using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type;
if constexpr (std::is_same_v<SrcT, DstT>) {
return it;
} else {
return thrust::make_transform_iterator(it, CastOp<SrcT, DstT>{});
}
}
} // namespace mlx::core::cu
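A small host-side sketch of make_cast_iterator, assuming a float source copied into a __half destination (cu::thrust_policy is from this PR's device header):

#include <thrust/copy.h>
#include <thrust/device_ptr.h>

// Cast-on-read copy: float values are converted to __half as thrust reads
// them through the transform iterator.
void cast_copy(const float* in, __half* out, size_t n, cudaStream_t stream) {
  auto src = mlx::core::cu::make_cast_iterator<__half>(
      thrust::device_pointer_cast(in));
  thrust::copy_n(
      mlx::core::cu::thrust_policy(stream),
      src,
      n,
      thrust::device_pointer_cast(out));
}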

View File

@@ -0,0 +1,240 @@
// Copyright © 2025 Apple Inc.
// Copyright © 2017-2024 The Simons Foundation, Inc.
//
// FINUFFT is licensed under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance with the
// License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Forked from
// https://github.com/flatironinstitute/finufft/blob/main/include/cufinufft/contrib/helper_math.h
#pragma once
#include <cuComplex.h>
// This header provides some helper functions for cuComplex types.
// It mainly wraps existing CUDA implementations to provide operator
// overloads; e.g. cuAdd, cuSub, cuMul, cuDiv, cuCreal, cuCimag, cuCabs,
// cuCarg and cuConj are all provided by CUDA.
__forceinline__ __host__ __device__ cuDoubleComplex
operator+(const cuDoubleComplex& a, const cuDoubleComplex& b) {
return cuCadd(a, b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator-(const cuDoubleComplex& a, const cuDoubleComplex& b) {
return cuCsub(a, b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator*(const cuDoubleComplex& a, const cuDoubleComplex& b) {
return cuCmul(a, b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator/(const cuDoubleComplex& a, const cuDoubleComplex& b) {
return cuCdiv(a, b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator%(const cuDoubleComplex& a, const cuDoubleComplex& b) {
double r = cuCreal(a) - (floor(cuCreal(a) / cuCreal(b)) * cuCreal(b));
double i = cuCimag(a) - (floor(cuCimag(a) / cuCimag(b)) * cuCimag(b));
return make_cuDoubleComplex(r, i);
}
__forceinline__ __host__ __device__ bool operator==(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
return cuCreal(a) == cuCreal(b) && cuCimag(a) == cuCimag(b);
}
__forceinline__ __host__ __device__ bool operator!=(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
return !(a == b);
}
__forceinline__ __host__ __device__ bool operator>(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
double mag_a = sqrt(cuCreal(a) * cuCreal(a) + cuCimag(a) * cuCimag(a));
double mag_b = sqrt(cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b));
return mag_a > mag_b;
}
__forceinline__ __host__ __device__ bool operator>=(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
return a > b || a == b;
}
__forceinline__ __host__ __device__ bool operator<(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
return b > a;
}
__forceinline__ __host__ __device__ bool operator<=(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
return b > a || a == b;
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator+(const cuDoubleComplex& a, double b) {
return make_cuDoubleComplex(cuCreal(a) + b, cuCimag(a));
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator+(double a, const cuDoubleComplex& b) {
return make_cuDoubleComplex(a + cuCreal(b), cuCimag(b));
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator-(const cuDoubleComplex& a, double b) {
return make_cuDoubleComplex(cuCreal(a) - b, cuCimag(a));
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator-(double a, const cuDoubleComplex& b) {
return make_cuDoubleComplex(a - cuCreal(b), -cuCimag(b));
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator*(const cuDoubleComplex& a, double b) {
return make_cuDoubleComplex(cuCreal(a) * b, cuCimag(a) * b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator*(double a, const cuDoubleComplex& b) {
return make_cuDoubleComplex(a * cuCreal(b), a * cuCimag(b));
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator/(const cuDoubleComplex& a, double b) {
return make_cuDoubleComplex(cuCreal(a) / b, cuCimag(a) / b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator/(double a, const cuDoubleComplex& b) {
double denom = cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b);
return make_cuDoubleComplex(
(a * cuCreal(b)) / denom, (-a * cuCimag(b)) / denom);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator+(const cuFloatComplex& a, const cuFloatComplex& b) {
return cuCaddf(a, b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator-(const cuFloatComplex& a, const cuFloatComplex& b) {
return cuCsubf(a, b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator*(const cuFloatComplex& a, const cuFloatComplex& b) {
return cuCmulf(a, b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator/(const cuFloatComplex& a, const cuFloatComplex& b) {
return cuCdivf(a, b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator%(const cuFloatComplex& a, const cuFloatComplex& b) {
float r = cuCrealf(a) - (floorf(cuCrealf(a) / cuCrealf(b)) * cuCrealf(b));
float i = cuCimagf(a) - (floorf(cuCimagf(a) / cuCimagf(b)) * cuCimagf(b));
return make_cuFloatComplex(r, i);
}
__forceinline__ __host__ __device__ bool operator==(
const cuFloatComplex& a,
const cuFloatComplex& b) {
return cuCrealf(a) == cuCrealf(b) && cuCimagf(a) == cuCimagf(b);
}
__forceinline__ __host__ __device__ bool operator!=(
const cuFloatComplex& a,
const cuFloatComplex& b) {
return !(a == b);
}
__forceinline__ __host__ __device__ bool operator>(
const cuFloatComplex& a,
const cuFloatComplex& b) {
float mag_a = sqrt(cuCrealf(a) * cuCrealf(a) + cuCimagf(a) * cuCimagf(a));
float mag_b = sqrt(cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b));
return mag_a > mag_b;
}
__forceinline__ __host__ __device__ bool operator>=(
const cuFloatComplex& a,
const cuFloatComplex& b) {
return a > b || a == b;
}
__forceinline__ __host__ __device__ bool operator<(
const cuFloatComplex& a,
const cuFloatComplex& b) {
return b > a;
}
__forceinline__ __host__ __device__ bool operator<=(
const cuFloatComplex& a,
const cuFloatComplex& b) {
return b > a || a == b;
}
__forceinline__ __host__ __device__ cuFloatComplex
operator+(const cuFloatComplex& a, float b) {
return make_cuFloatComplex(cuCrealf(a) + b, cuCimagf(a));
}
__forceinline__ __host__ __device__ cuFloatComplex
operator+(float a, const cuFloatComplex& b) {
return make_cuFloatComplex(a + cuCrealf(b), cuCimagf(b));
}
__forceinline__ __host__ __device__ cuFloatComplex
operator-(const cuFloatComplex& a, float b) {
return make_cuFloatComplex(cuCrealf(a) - b, cuCimagf(a));
}
__forceinline__ __host__ __device__ cuFloatComplex
operator-(float a, const cuFloatComplex& b) {
return make_cuFloatComplex(a - cuCrealf(b), -cuCimagf(b));
}
__forceinline__ __host__ __device__ cuFloatComplex
operator*(const cuFloatComplex& a, float b) {
return make_cuFloatComplex(cuCrealf(a) * b, cuCimagf(a) * b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator*(float a, const cuFloatComplex& b) {
return make_cuFloatComplex(a * cuCrealf(b), a * cuCimagf(b));
}
__forceinline__ __host__ __device__ cuFloatComplex
operator/(const cuFloatComplex& a, float b) {
return make_cuFloatComplex(cuCrealf(a) / b, cuCimagf(a) / b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator/(float a, const cuFloatComplex& b) {
float denom = cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b);
return make_cuFloatComplex(
(a * cuCrealf(b)) / denom, (-a * cuCimagf(b)) / denom);
}
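A brief illustrative kernel exercising these overloads; note that the ordering operators compare magnitudes (the values here are illustrative):

__global__ void complex_ops_demo(cuFloatComplex* out) {
  cuFloatComplex a = make_cuFloatComplex(3.f, 4.f); // |a| = 5
  cuFloatComplex b = make_cuFloatComplex(1.f, 0.f); // |b| = 1
  out[0] = a + b;                                   // (4, 4)
  out[1] = a * 2.f;                                 // (6, 8)
  out[2] = make_cuFloatComplex(a > b ? 1.f : 0.f, 0.f); // 1: |a| > |b|
}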

View File

@@ -9,6 +9,124 @@
namespace mlx::core::cu {
///////////////////////////////////////////////////////////////////////////////
// Unary ops for half types.
///////////////////////////////////////////////////////////////////////////////
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
#define MLX_DEFINE_UNARY_OP(NAME, HALF_OP) \
template <typename T> \
__forceinline__ __device__ auto NAME(T x) { \
if constexpr (cuda::std::is_same_v<T, __half>) { \
return HALF_OP(x); \
} else { \
return ::NAME(x); \
} \
}
#else
#define MLX_DEFINE_UNARY_OP(NAME, HALF_OP) \
template <typename T> \
__forceinline__ __device__ auto NAME(T x) { \
if constexpr (cuda::std::is_same_v<T, __half>) { \
return HALF_OP(x); \
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) { \
return HALF_OP(x); \
} else { \
return ::NAME(x); \
} \
}
#endif
#define MLX_DEFINE_UNARY_OP_FALLBACK(NAME) \
template <typename T> \
__forceinline__ __device__ auto NAME(T x) { \
if constexpr (cuda::std::is_same_v<T, __half>) { \
return ::NAME(__half2float(x)); \
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) { \
return ::NAME(__bfloat162float(x)); \
} else { \
return ::NAME(x); \
} \
}
MLX_DEFINE_UNARY_OP(abs, __habs)
MLX_DEFINE_UNARY_OP(ceil, hceil)
MLX_DEFINE_UNARY_OP(cos, hcos)
MLX_DEFINE_UNARY_OP(exp, hexp)
MLX_DEFINE_UNARY_OP(floor, hfloor)
MLX_DEFINE_UNARY_OP(isnan, __hisnan)
MLX_DEFINE_UNARY_OP(log, hlog)
MLX_DEFINE_UNARY_OP(log2, hlog2)
MLX_DEFINE_UNARY_OP(log10, hlog10)
MLX_DEFINE_UNARY_OP(rint, hrint)
MLX_DEFINE_UNARY_OP(rsqrt, hrsqrt)
MLX_DEFINE_UNARY_OP(sin, hsin)
MLX_DEFINE_UNARY_OP(sqrt, hsqrt)
MLX_DEFINE_UNARY_OP_FALLBACK(acos)
MLX_DEFINE_UNARY_OP_FALLBACK(acosh)
MLX_DEFINE_UNARY_OP_FALLBACK(asin)
MLX_DEFINE_UNARY_OP_FALLBACK(asinh)
MLX_DEFINE_UNARY_OP_FALLBACK(atan)
MLX_DEFINE_UNARY_OP_FALLBACK(atanh)
MLX_DEFINE_UNARY_OP_FALLBACK(cosh)
MLX_DEFINE_UNARY_OP_FALLBACK(log1p)
MLX_DEFINE_UNARY_OP_FALLBACK(sinh)
MLX_DEFINE_UNARY_OP_FALLBACK(tan)
#if __CUDA_ARCH__ >= 1280
MLX_DEFINE_UNARY_OP(tanh, htanh)
#else
MLX_DEFINE_UNARY_OP_FALLBACK(tanh)
#endif
#undef MLX_DEFINE_UNARY_OP
#undef MLX_DEFINE_UNARY_OP_FALLBACK
///////////////////////////////////////////////////////////////////////////////
// Binary ops for half types.
///////////////////////////////////////////////////////////////////////////////
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
#define MLX_DEFINE_BINARY_OP(NAME, HALF_OP) \
template <typename T> \
__forceinline__ __device__ auto NAME(T x, T y) { \
if constexpr (cuda::std::is_same_v<T, __half>) { \
return HALF_OP(x, y); \
} else { \
return ::NAME(x, y); \
} \
}
#else
#define MLX_DEFINE_BINARY_OP(NAME, HALF_OP) \
template <typename T> \
__forceinline__ __device__ auto NAME(T x, T y) { \
if constexpr (cuda::std::is_same_v<T, __half>) { \
return HALF_OP(x, y); \
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) { \
return HALF_OP(x, y); \
} else { \
return ::NAME(x, y); \
} \
}
#endif
MLX_DEFINE_BINARY_OP(max, __hmax)
MLX_DEFINE_BINARY_OP(min, __hmin)
#undef MLX_DEFINE_BINARY_OP
template <typename T>
__forceinline__ __device__ T fmod(T x, T y) {
if constexpr (cuda::std::is_same_v<T, __half>) {
return __float2half(::fmod(__half2float(x), __half2float(y)));
#if CUDART_VERSION >= 12000 || __CUDA_ARCH__ >= 800
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
return __float2bfloat16(::fmod(__bfloat162float(x), __bfloat162float(y)));
#endif
} else {
return ::fmod(x, y);
}
}
///////////////////////////////////////////////////////////////////////////////
// Additional C++ operator overrides between half types and native types.
///////////////////////////////////////////////////////////////////////////////
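For reference, the fallback macro above computes through float; MLX_DEFINE_UNARY_OP_FALLBACK(tan), for example, expands in effect to:

template <typename T>
__forceinline__ __device__ auto tan(T x) {
  if constexpr (cuda::std::is_same_v<T, __half>) {
    return ::tan(__half2float(x)); // promote, compute in float
  } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
    return ::tan(__bfloat162float(x)); // promote, compute in float
  } else {
    return ::tan(x);
  }
}

Note the auto return type: half inputs yield float results here, and callers that return T (e.g. the unary functors below) rely on the implicit conversion back to the half type.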

View File

@@ -0,0 +1,349 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
#include "mlx/backend/cuda/kernels/utils.cuh"
namespace mlx::core::cu {
struct Abs {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_unsigned_v<T>) {
return x;
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return {sqrt(cuCrealf(x) * cuCrealf(x) + cuCimagf(x) * cuCimagf(x)), 0};
} else {
return abs(x);
}
}
};
struct ArcCos {
template <typename T>
__device__ T operator()(T x) {
return acos(x);
}
};
struct ArcCosh {
template <typename T>
__device__ T operator()(T x) {
return acosh(x);
}
};
struct ArcSin {
template <typename T>
__device__ T operator()(T x) {
return asin(x);
}
};
struct ArcSinh {
template <typename T>
__device__ T operator()(T x) {
return asinh(x);
}
};
struct ArcTan {
template <typename T>
__device__ T operator()(T x) {
return atan(x);
}
};
struct ArcTanh {
template <typename T>
__device__ T operator()(T x) {
return atanh(x);
}
};
struct BitwiseInvert {
template <typename T>
__device__ T operator()(T x) {
return ~x;
}
};
struct Ceil {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_integral_v<T>) {
return x;
} else {
return ceil(x);
}
}
};
struct Conjugate {
__device__ cuComplex operator()(cuComplex x) {
return {cuCrealf(x), -cuCimagf(x)};
}
};
struct Cos {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return {
cos(cuCrealf(x)) * cosh(cuCimagf(x)),
-sin(cuCrealf(x)) * sinh(cuCimagf(x))};
} else {
return cos(x);
}
}
};
struct Cosh {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return {
cosh(cuCrealf(x)) * cos(cuCimagf(x)),
sinh(cuCrealf(x)) * sin(cuCimagf(x))};
} else {
return cosh(x);
}
}
};
struct Erf {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, __half>) {
return erf(__half2float(x));
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
return erf(__bfloat162float(x));
} else {
return erf(x);
}
}
};
struct ErfInv {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, __half>) {
return erfinv(__half2float(x));
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
return erfinv(__bfloat162float(x));
} else {
return erfinv(x);
}
}
};
struct Exp {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
auto m = exp(cuCrealf(x));
return {m * cos(cuCimagf(x)), m * sin(cuCimagf(x))}; // e^a (cos b + i sin b)
} else {
return exp(x);
}
}
};
struct Expm1 {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, __half>) {
return expm1(__half2float(x));
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
return expm1(__bfloat162float(x));
} else {
return expm1(x);
}
}
};
struct Floor {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_integral_v<T>) {
return x;
} else {
return floor(x);
}
}
};
struct Imag {
__device__ float operator()(cuComplex x) {
return cuCimagf(x);
}
};
struct Log {
template <typename T>
__device__ T operator()(T x) {
return log(x);
}
};
struct Log2 {
template <typename T>
__device__ T operator()(T x) {
return log2(x);
}
};
struct Log10 {
template <typename T>
__device__ T operator()(T x) {
return log10(x);
}
};
struct Log1p {
template <typename T>
__device__ T operator()(T x) {
return log1p(x);
}
};
struct LogicalNot {
__device__ bool operator()(bool x) {
return !x;
}
};
struct Negative {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return 0 - x;
} else {
return -x;
}
}
};
struct Real {
__device__ float operator()(cuComplex x) {
return cuCrealf(x);
}
};
struct Round {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return {rint(cuCrealf(x)), rint(cuCimagf(x))};
} else {
return rint(x);
}
}
};
struct Rsqrt {
template <typename T>
__device__ T operator()(T x) {
return rsqrt(x);
}
};
struct Sigmoid {
template <typename T>
__device__ T operator()(T x) {
T y = 1 / (1 + exp(-abs(x)));
return (x < 0) ? 1 - y : y;
}
};
struct Sign {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_unsigned_v<T>) {
return x != 0;
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
if (cuCrealf(x) == 0 && cuCimagf(x) == 0) {
return x;
} else {
return x / Abs()(x);
}
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
return static_cast<float>((x > T(0.f)) - (x < T(0.f)));
} else {
return (x > T(0)) - (x < T(0));
}
}
};
struct Sin {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return {
sin(cuCrealf(x)) * cosh(cuCimagf(x)),
cos(cuCrealf(x)) * sinh(cuCimagf(x))};
} else {
return sin(x);
}
}
};
struct Sinh {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return {
sinh(cuCrealf(x)) * cos(cuCimagf(x)),
cosh(cuCrealf(x)) * sin(cuCimagf(x))};
} else {
return sinh(x);
}
}
};
struct Square {
template <typename T>
__device__ T operator()(T x) {
return x * x;
}
};
struct Sqrt {
template <typename T>
__device__ T operator()(T x) {
return sqrt(x);
}
};
struct Tan {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
float tan_a = tan(cuCrealf(x));
float tanh_b = tanh(cuCimagf(x));
float t1 = tan_a * tanh_b;
float denom = 1. + t1 * t1;
return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
} else {
return tan(x);
}
}
};
struct Tanh {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
float tanh_a = tanh(cuCrealf(x));
float tan_b = tan(cuCimagf(x));
float t1 = tanh_a * tan_b;
float denom = 1. + t1 * t1;
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
} else {
return tanh(x);
}
}
};
} // namespace mlx::core::cu
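The complex branch of Tan follows from tan(a+ib) = (tan a + i tanh b) / (1 - i tan a tanh b); multiplying top and bottom by the conjugate of the denominator, with t1 = tan a * tanh b and denom = 1 + t1^2, gives

  tan(a+ib) = (tan_a - tanh_b * t1) / denom + i * (tanh_b + tan_a * t1) / denom,

which is exactly the pair the functor returns. Tanh's complex branch is the same algebra with the roles of tan and tanh swapped.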

View File

@@ -0,0 +1,104 @@
// Copyright © 2025 Apple Inc.
// This file must not include any host-only code; utilities that work under
// both host and device can be put here.
//
// See more about the requirements at:
// https://docs.nvidia.com/cuda/nvrtc/#language
#pragma once
#include <cuComplex.h>
#include <cuda/std/array>
#include <cuda/std/limits>
#include <cuda/std/tuple>
namespace mlx::core::cu {
///////////////////////////////////////////////////////////////////////////////
// CUDA kernel utils
///////////////////////////////////////////////////////////////////////////////
// To pass shape/strides to kernels via constant memory, their size must be
// known at compile time.
#define MAX_NDIM 8
using Shape = cuda::std::array<int32_t, MAX_NDIM>;
using Strides = cuda::std::array<int64_t, MAX_NDIM>;
///////////////////////////////////////////////////////////////////////////////
// Indexing utils
///////////////////////////////////////////////////////////////////////////////
template <typename IdxT = int64_t>
inline __host__ __device__ IdxT
elem_to_loc(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
IdxT loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * IdxT(strides[i]);
elem /= shape[i];
}
return loc;
}
// Optimized version for when ndim is known at compile time.
template <int NDIM, typename IdxT = int64_t>
inline __host__ __device__ IdxT
elem_to_loc_nd(IdxT elem, const int* shape, const int64_t* strides) {
IdxT loc = 0;
#pragma unroll
for (int i = NDIM - 1; i >= 0; --i) {
loc += (elem % shape[i]) * IdxT(strides[i]);
elem /= shape[i];
}
return loc;
}
template <int NDIM, typename IdxT = int64_t>
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
IdxT elem,
const int* shape,
const int64_t* a_strides,
const int64_t* b_strides) {
IdxT a_loc = 0;
IdxT b_loc = 0;
#pragma unroll
for (int i = NDIM - 1; i >= 0; --i) {
int dim_idx = elem % shape[i];
a_loc += dim_idx * a_strides[i];
b_loc += dim_idx * b_strides[i];
elem /= shape[i];
}
return cuda::std::make_tuple(a_loc, b_loc);
}
// Optimized version for ndim >= 4: the first three dims are unrolled.
template <typename IdxT = int64_t>
inline __host__ __device__ IdxT
elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
IdxT loc = elem_to_loc_nd<3>(elem, shape, strides);
for (int i = ndim - 1; i >= 3; --i) {
loc += (elem % shape[i]) * IdxT(strides[i]);
elem /= shape[i];
}
return loc;
}
template <typename IdxT = int64_t>
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
IdxT elem,
const int* shape,
const int64_t* a_strides,
const int64_t* b_strides,
int ndim) {
auto [a_loc, b_loc] = elem_to_loc_nd<3>(elem, shape, a_strides, b_strides);
for (int i = ndim - 1; i >= 3; --i) {
int dim_idx = elem % shape[i];
a_loc += dim_idx * a_strides[i];
b_loc += dim_idx * b_strides[i];
elem /= shape[i];
}
return cuda::std::make_tuple(a_loc, b_loc);
}
} // namespace mlx::core::cu
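A worked example of elem_to_loc, using an illustrative transposed view:

#include <cassert>
#include <cstdint>

void elem_to_loc_example() {
  // shape {3, 2} with strides {1, 3} is the transpose of a contiguous
  // 2x3 buffer. Linear element 4 of the view is (row 2, col 0), which is
  // element (0, 2) of the underlying buffer, i.e. offset 2.
  int shape[] = {3, 2};
  int64_t strides[] = {1, 3};
  int64_t loc = mlx::core::cu::elem_to_loc<int64_t>(4, shape, strides, 2);
  assert(loc == 2);
}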

474
mlx/backend/cuda/matmul.cpp Normal file
View File

@@ -0,0 +1,474 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/matmul.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cublasLt.h>
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>
#include <numeric>
namespace mlx::core {
namespace cu {
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
void check_cublas_error(const char* name, cublasStatus_t err) {
if (err != CUBLAS_STATUS_SUCCESS) {
// TODO: Use cublasGetStatusString when it is widely available.
throw std::runtime_error(
fmt::format("{} failed with code: {}.", name, static_cast<int>(err)));
}
}
class MatMul {
public:
MatMul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
uint64_t a_cols,
int64_t lda,
bool b_transposed,
uint64_t b_rows,
uint64_t b_cols,
int64_t ldb,
int32_t batch_count,
int64_t a_batch_stride,
int64_t b_batch_stride) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
auto type = dtype_to_cuda_type(dtype);
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
&matmul_desc_, dtype_to_compute_type(dtype), type));
int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_POINTER_MODE,
&pointer_mode,
sizeof(int32_t)));
cublasOperation_t op = CUBLAS_OP_N;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSA,
&op,
sizeof(cublasOperation_t)));
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSB,
&op,
sizeof(cublasOperation_t)));
a_desc_ = create_matrix_layout(
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
b_desc_ = create_matrix_layout(
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
out_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
uint64_t MiB = 1024 * 1024;
uint64_t workspace_size =
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
pref_,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size,
sizeof(uint64_t)));
}
MatMul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
uint64_t a_cols,
int64_t lda,
bool b_transposed,
uint64_t b_rows,
uint64_t b_cols,
int64_t ldb,
bool c_transposed,
int64_t ldc,
int32_t batch_count,
int64_t a_batch_stride,
int64_t b_batch_stride,
int64_t c_batch_stride)
: MatMul(
device,
dtype,
a_transposed,
a_rows,
a_cols,
lda,
b_transposed,
b_rows,
b_cols,
ldb,
batch_count,
a_batch_stride,
b_batch_stride) {
auto type = dtype_to_cuda_type(dtype);
c_desc_ = create_matrix_layout(
type, a_rows, b_cols, c_transposed, ldc, batch_count, c_batch_stride);
}
~MatMul() {
cublasLtMatrixLayoutDestroy(a_desc_);
cublasLtMatrixLayoutDestroy(b_desc_);
cublasLtMatrixLayoutDestroy(c_desc_);
cublasLtMatrixLayoutDestroy(out_desc_);
cublasLtMatmulDescDestroy(matmul_desc_);
}
void run(
cu::CommandEncoder& encoder,
void* out,
void* a,
void* b,
void* c = nullptr,
float alpha = 1,
float beta = 0) {
if (heuristic_.state != CUBLAS_STATUS_SUCCESS) {
int ret = 0;
CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(
encoder.device().lt_handle(),
matmul_desc_,
a_desc_,
b_desc_,
out_desc_,
out_desc_,
pref_,
1,
&heuristic_,
&ret));
if (ret == 0) {
throw std::runtime_error("Can not find algorithm for matmul.");
}
}
array workspace(
allocator::malloc(heuristic_.workspaceSize),
{static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);
encoder.launch_kernel([&](cudaStream_t stream) {
CHECK_CUBLAS_ERROR(cublasLtMatmul(
encoder.device().lt_handle(),
matmul_desc_,
&alpha,
a,
a_desc_,
b,
b_desc_,
&beta,
c ? c : out,
c ? c_desc_ : out_desc_,
out,
out_desc_,
&heuristic_.algo,
workspace.data<void>(),
workspace.nbytes(),
stream));
});
}
private:
cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
switch (dtype) {
case uint8:
case uint16:
case int8:
case int16:
case int32:
return CUBLAS_COMPUTE_32I;
case float16:
case bfloat16:
return CUBLAS_COMPUTE_16F;
case float32:
return CUBLAS_COMPUTE_32F;
case float64:
case complex64:
return CUBLAS_COMPUTE_64F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in MatMul: {}.", dtype_to_string(dtype)));
}
}
cudaDataType_t dtype_to_cuda_type(Dtype dtype) {
switch (dtype) {
case uint8:
return CUDA_R_8U;
case uint16:
return CUDA_R_16U;
case int8:
return CUDA_R_8I;
case int16:
return CUDA_R_16I;
case int32:
return CUDA_R_32I;
case float16:
return CUDA_R_16F;
case bfloat16:
return CUDA_R_16BF;
case float32:
return CUDA_R_32F;
case float64:
return CUDA_R_64F;
case complex64:
return CUDA_C_32F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in MatMul: {}.", dtype_to_string(dtype)));
}
}
cublasLtMatrixLayout_t create_matrix_layout(
cudaDataType_t type,
uint64_t rows,
uint64_t cols,
bool transposed,
int64_t ld,
int32_t batch_count,
int64_t batch_stride) {
cublasLtMatrixLayout_t desc;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
cublasLtOrder_t order =
transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
if (batch_count > 1) {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
&batch_count,
sizeof(int32_t)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&batch_stride,
sizeof(int64_t)));
}
return desc;
}
cublasLtMatmulDesc_t matmul_desc_{nullptr};
cublasLtMatmulPreference_t pref_{nullptr};
cublasLtMatrixLayout_t a_desc_{nullptr};
cublasLtMatrixLayout_t b_desc_{nullptr};
cublasLtMatrixLayout_t c_desc_{nullptr};
cublasLtMatrixLayout_t out_desc_{nullptr};
cublasLtMatmulHeuristicResult_t heuristic_;
};
} // namespace cu
namespace {
std::tuple<bool, int64_t, array>
check_transpose(std::vector<array>& copies, const Stream& s, const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (sty == 1 && stx == arr.shape(-1)) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
return std::make_tuple(false, arr.shape(-1), arr_copy);
}
}
} // namespace
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Matmul::eval_gpu");
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
assert(inputs.size() == 2);
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
// Return 0s if either input is empty.
if (a_pre.size() == 0 || b_pre.size() == 0) {
array zero(0, a_pre.dtype());
encoder.add_temporary(zero);
fill_gpu(zero, out, s);
return;
}
out.set_data(allocator::malloc(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
int M = a_pre.shape(-2);
int N = b_pre.shape(-1);
int K = a_pre.shape(-1);
// Keep a vector of copies; they are added as encoder temporaries so the
// buffers are released once the command buffer completes.
std::vector<array> copies;
auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre);
for (auto& temp : copies) {
encoder.add_temporary(temp);
}
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
auto batch_count = out.size() / (M * N);
// Collapse batches into M if needed
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
b_batch_strides.back() == 0) {
M *= batch_shape.back();
batch_count = 1;
a_batch_strides = {0};
b_batch_strides = {0};
batch_shape = {1};
}
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
cu::MatMul matmul(
encoder.device(),
a.dtype(),
a_transposed,
M,
K,
lda,
b_transposed,
K,
N,
ldb,
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
matmul.run(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc);
a_it.step();
b_it.step();
}
}
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("AddMM::eval_gpu");
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
assert(inputs.size() == 3);
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
auto& c_pre = inputs[2];
out.set_data(allocator::malloc(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
int M = a_pre.shape(-2);
int N = b_pre.shape(-1);
int K = a_pre.shape(-1);
// Keep a vector of copies; they are added as encoder temporaries so the
// buffers are released once the command buffer completes.
std::vector<array> copies;
auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre);
auto [c_transposed, ldc, c] = check_transpose(copies, s, c_pre);
for (auto& temp : copies) {
encoder.add_temporary(temp);
}
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
auto [batch_shape, a_batch_strides, b_batch_strides, c_batch_strides] =
collapse_batches(a, b, c);
auto batch_count = out.size() / (M * N);
// Collapse batches into M if needed
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
c_batch_strides.back() == M * c.strides()[c.ndim() - 2] &&
b_batch_strides.back() == 0) {
M *= batch_shape.back();
batch_count = 1;
a_batch_strides = {0};
b_batch_strides = {0};
c_batch_strides = {0};
batch_shape = {1};
}
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
cu::MatMul matmul(
encoder.device(),
a.dtype(),
a_transposed,
M,
K,
lda,
b_transposed,
K,
N,
ldb,
c_transposed,
ldc,
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back(),
c_batch_strides.back());
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
matmul.run(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc,
c.data<int8_t>() + c.itemsize() * c_it.loc,
alpha_,
beta_);
a_it.step();
b_it.step();
c_it.step();
}
}
} // namespace mlx::core
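For reference, each cublasLtMatmul call in MatMul::run computes D = alpha * (A B) + beta * C. Matmul::eval_gpu leaves c null, so beta = 0 and the call reduces to D = A B, while AddMM::eval_gpu passes its alpha_ and beta_ so the bias addition is fused into the same GEMM.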

View File

@@ -71,90 +71,34 @@ bool fast::ScaledDotProductAttention::use_fallback(
throw std::runtime_error(#func " has no CUDA implementation."); \
}
NO_GPU(Abs)
NO_GPU(Add)
NO_GPU(AddMM)
NO_GPU(ArcCos)
NO_GPU(ArcCosh)
NO_GPU(ArcSin)
NO_GPU(ArcSinh)
NO_GPU(ArcTan)
NO_GPU(ArcTan2)
NO_GPU(ArcTanh)
NO_GPU(ArgPartition)
NO_GPU(ArgReduce)
NO_GPU(ArgSort)
NO_GPU(BitwiseBinary)
NO_GPU(BitwiseInvert)
NO_GPU(BlockMaskedMM)
NO_GPU(Ceil)
NO_GPU_MULTI(Compiled)
NO_GPU(Conjugate)
NO_GPU(Convolution)
NO_GPU(Cos)
NO_GPU(Cosh)
NO_GPU(Divide)
NO_GPU_MULTI(DivMod)
NO_GPU(DynamicSlice)
NO_GPU(DynamicSliceUpdate)
NO_GPU(Remainder)
NO_GPU(Equal)
NO_GPU(Erf)
NO_GPU(ErfInv)
NO_GPU(Exp)
NO_GPU(Expm1)
NO_GPU(FFT)
NO_GPU(Floor)
NO_GPU(Gather)
NO_GPU(GatherAxis)
NO_GPU(GatherMM)
NO_GPU(GatherQMM)
NO_GPU(Greater)
NO_GPU(GreaterEqual)
NO_GPU(Hadamard)
NO_GPU(Imag)
NO_GPU(Less)
NO_GPU(LessEqual)
NO_GPU(Load)
NO_GPU(Log)
NO_GPU(Log1p)
NO_GPU(LogicalNot)
NO_GPU(LogicalAnd)
NO_GPU(LogicalOr)
NO_GPU(LogAddExp)
NO_GPU(LogSumExp)
NO_GPU_MULTI(LUF)
NO_GPU(Matmul)
NO_GPU(Maximum)
NO_GPU(Minimum)
NO_GPU(Multiply)
NO_GPU(Negative)
NO_GPU(NotEqual)
NO_GPU(Partition)
NO_GPU(Power)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(RandomBits)
NO_GPU(Real)
NO_GPU(Reduce)
NO_GPU(Round)
NO_GPU(Scan)
NO_GPU(Scatter)
NO_GPU(ScatterAxis)
NO_GPU(Select)
NO_GPU(Sigmoid)
NO_GPU(Sign)
NO_GPU(Sin)
NO_GPU(Sinh)
NO_GPU(SliceUpdate)
NO_GPU(Softmax)
NO_GPU(Sort)
NO_GPU(Square)
NO_GPU(Sqrt)
NO_GPU(Subtract)
NO_GPU_MULTI(SVD)
NO_GPU(Tan)
NO_GPU(Tanh)
NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU_MULTI(Eig)

181
mlx/backend/cuda/random.cu Normal file
View File

@@ -0,0 +1,181 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
#include <cassert>
namespace mlx::core {
namespace cu {
__constant__ constexpr uint32_t rotations[2][4] = {
{13, 15, 26, 6},
{17, 29, 16, 24}};
union rbits {
uint2 val;
uint8_t bytes[2][4];
};
__device__ rbits threefry2x32_hash(uint2 key, uint2 count) {
uint32_t ks[] = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA};
rbits v;
v.val.x = count.x + ks[0];
v.val.y = count.y + ks[1];
for (int i = 0; i < 5; ++i) {
for (auto r : rotations[i % 2]) {
v.val.x += v.val.y;
v.val.y = (v.val.y << r) | (v.val.y >> (32 - r));
v.val.y ^= v.val.x;
}
v.val.x += ks[(i + 1) % 3];
v.val.y += ks[(i + 2) % 3] + i + 1;
}
return v;
}
__global__ void rbitsc(
const uint32_t* keys,
uint8_t* out,
dim3 grid_dims,
bool odd,
uint32_t bytes_per_key) {
uint2 index{
blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y};
if (index.x >= grid_dims.x || index.y >= grid_dims.y) {
return;
}
auto kidx = 2 * index.x;
auto key = uint2{keys[kidx], keys[kidx + 1]};
auto half_size = grid_dims.y - odd;
out += index.x * bytes_per_key;
bool drop_last = odd && (index.y == half_size);
auto bits = threefry2x32_hash(
key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y});
size_t idx = size_t(index.y) << 2;
for (int i = 0; i < 4; ++i) {
out[idx + i] = bits.bytes[0][i];
}
if (!drop_last) {
idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2;
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
int edge_bytes = (bytes_per_key % 4);
for (int i = 0; i < edge_bytes; ++i) {
out[idx + i] = bits.bytes[1][i];
}
} else {
for (int i = 0; i < 4; ++i) {
out[idx + i] = bits.bytes[1][i];
}
}
}
}
__global__ void rbits(
const uint32_t* keys,
uint8_t* out,
dim3 grid_dims,
bool odd,
uint32_t bytes_per_key,
int32_t ndim,
const __grid_constant__ Shape key_shape,
const __grid_constant__ Strides key_strides) {
uint2 index{
blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y};
if (index.x >= grid_dims.x || index.y >= grid_dims.y) {
return;
}
auto kidx = 2 * index.x;
auto k1_elem = elem_to_loc(kidx, key_shape.data(), key_strides.data(), ndim);
auto k2_elem =
elem_to_loc(kidx + 1, key_shape.data(), key_strides.data(), ndim);
auto key = uint2{keys[k1_elem], keys[k2_elem]};
auto half_size = grid_dims.y - odd;
out += size_t(index.x) * bytes_per_key;
bool drop_last = odd && (index.y == half_size);
auto bits = threefry2x32_hash(
key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y});
size_t idx = size_t(index.y) << 2;
for (int i = 0; i < 4; ++i) {
out[idx + i] = bits.bytes[0][i];
}
if (!drop_last) {
idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2;
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
int edge_bytes = (bytes_per_key % 4);
for (int i = 0; i < edge_bytes; ++i) {
out[idx + i] = bits.bytes[1][i];
}
} else {
for (int i = 0; i < 4; ++i) {
out[idx + i] = bits.bytes[1][i];
}
}
}
}
} // namespace cu
void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("RandomBits::eval_gpu");
assert(inputs.size() == 1);
// keys has shape (N1, ..., NK, 2)
// out has shape (N1, ..., NK, M1, M2, ...)
auto& keys = inputs[0];
uint32_t num_keys = keys.size() / 2;
uint32_t elems_per_key = out.size() / num_keys;
uint32_t bytes_per_key = out.itemsize() * elems_per_key;
out.set_data(allocator::malloc(out.nbytes()));
if (out.size() == 0) {
return;
}
uint32_t out_per_key = (bytes_per_key + 4 - 1) / 4;
uint32_t half_size = out_per_key / 2;
bool odd = out_per_key % 2;
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(keys);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
dim3 grid_dims{num_keys, half_size + odd};
dim3 block_dims = get_block_dims(grid_dims.x, grid_dims.y, 1);
dim3 num_blocks{
cuda::ceil_div(grid_dims.x, block_dims.x),
cuda::ceil_div(grid_dims.y, block_dims.y)};
if (keys.flags().row_contiguous) {
cu::rbitsc<<<num_blocks, block_dims, 0, stream>>>(
keys.data<uint32_t>(),
out.data<uint8_t>(),
grid_dims,
odd,
bytes_per_key);
} else {
cu::rbits<<<num_blocks, block_dims, 0, stream>>>(
keys.data<uint32_t>(),
out.data<uint8_t>(),
grid_dims,
odd,
bytes_per_key,
keys.ndim(),
const_param(keys.shape()),
const_param(keys.strides()));
}
});
}
} // namespace mlx::core
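A worked example of the tail handling in rbitsc/rbits, assuming bytes_per_key = 10: out_per_key = 3 words, so half_size = 1, odd = 1, and grid_dims.y = 2. Thread y = 0 hashes counters {0, 2}, writing bytes 0-3 plus the 2-byte edge at offset 8; thread y = 1 has drop_last set, hashes {1, 0}, and writes only bytes 4-7. Together the two threads cover exactly 10 bytes with no overlap.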

View File

@@ -1,7 +1,11 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
#include <numeric>
namespace mlx::core {
void concatenate_gpu(
@@ -9,7 +13,29 @@ void concatenate_gpu(
array& out,
int axis,
const Stream& s) {
throw std::runtime_error("concatenate_gpu not implemented in CUDA backend.");
std::vector<int> sizes;
sizes.push_back(0);
for (auto& p : inputs) {
sizes.push_back(p.shape(axis));
}
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
out.set_data(allocator::malloc(out.nbytes()));
auto strides = out.strides();
auto flags = out.flags();
flags.row_contiguous = false;
flags.col_contiguous = false;
flags.contiguous = false;
// TODO: Handle concurrent outputs:
// https://github.com/ml-explore/mlx/pull/2145#discussion_r2070753816
for (int i = 0; i < inputs.size(); i++) {
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
size_t data_offset = strides[axis] * sizes[i];
out_slice.copy_shared_buffer(
out, strides, flags, out_slice.size(), data_offset);
copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, s);
}
}
} // namespace mlx::core
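A worked example of the offset arithmetic above: concatenating shapes (2, 3) and (2, 5) along axis 1 yields an output of shape (2, 8) with strides {8, 1}; the partial sums give sizes = {0, 3, 8}, so slice 0 lands at element offset 1 * 0 = 0 and slice 1 at 1 * 3 = 3, each copied through a shared-buffer view that keeps the output's (2, 8) strides so the rows interleave correctly.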

180
mlx/backend/cuda/sort.cu Normal file
View File

@@ -0,0 +1,180 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
#include <cub/device/device_segmented_sort.cuh>
#include <cassert>
#include <numeric>
namespace mlx::core {
namespace {
template <typename T>
struct ModOp {
T divisor;
__device__ T operator()(T x) {
return x % divisor;
}
};
// We cannot call ops inside eval, so make a utility instead.
array swapaxes_in_eval(const array& in, int axis1, int axis2) {
std::vector<int> axes(in.ndim());
std::iota(axes.begin(), axes.end(), 0);
std::swap(axes[axis1], axes[axis2]);
// TODO: Share the code with Transpose::eval.
Shape shape(axes.size());
Strides strides(in.ndim());
for (size_t ax = 0; ax < axes.size(); ++ax) {
shape[ax] = in.shape()[axes[ax]];
strides[ax] = in.strides()[axes[ax]];
}
auto flags = in.flags();
if (flags.contiguous) {
auto [_, row_contiguous, col_contiguous] = check_contiguity(shape, strides);
flags.row_contiguous = row_contiguous;
flags.col_contiguous = col_contiguous;
}
array out(shape, in.dtype(), nullptr, {});
out.copy_shared_buffer(in, strides, flags, in.data_size());
return out;
}
template <typename... Args>
void segmented_sort_pairs(cu::CommandEncoder& encoder, Args&&... args) {
// Allocate temporary storage.
size_t size;
CHECK_CUDA_ERROR(
cub::DeviceSegmentedSort::StableSortPairs(nullptr, size, args...));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
encoder.add_temporary(temp);
// Run op.
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
temp.data<void>(), size, args...));
}
template <typename... Args>
void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
// Allocate temporary storage.
size_t size;
CHECK_CUDA_ERROR(
cub::DeviceSegmentedSort::StableSortKeys(nullptr, size, args...));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
encoder.add_temporary(temp);
// Run op.
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
temp.data<void>(), size, args...));
}
void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
array out = out_;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
if (axis < 0) {
axis += in.ndim();
}
int nsort = in.shape(axis);
int nsegments = in.data_size() / nsort;
int last_dim = in.ndim() - 1;
// If we are not sorting the innermost dimension of a contiguous array,
// transpose and make a copy.
bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1;
if (!is_segmented_sort) {
array trans = swapaxes_in_eval(in, axis, last_dim);
in = array(trans.shape(), trans.dtype(), nullptr, {});
copy_gpu(trans, in, CopyType::General, s);
encoder.add_temporary(in);
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
encoder.add_temporary(out);
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
using Type = cuda_type_t<CTYPE>;
auto offsets = thrust::make_transform_iterator(
thrust::make_counting_iterator(0),
[nsort] __device__(int i) { return i * nsort; });
if (argsort) {
// Indices in the sorted dimension.
array indices(
allocator::malloc(out.nbytes()), in.shape(), out.dtype());
encoder.add_temporary(indices);
thrust::transform(
cu::thrust_policy(stream),
thrust::counting_iterator<uint32_t>(0),
thrust::counting_iterator<uint32_t>(indices.data_size()),
thrust::device_pointer_cast(indices.data<uint32_t>()),
ModOp<uint32_t>{static_cast<uint32_t>(nsort)});
// Although argsort does not need the sorted values themselves, the API
// requires an array to store them.
array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype());
encoder.add_temporary(discard);
segmented_sort_pairs(
encoder,
in.data<Type>(),
discard.data<Type>(),
indices.data<uint32_t>(),
out.data<uint32_t>(),
in.data_size(),
nsegments,
offsets,
offsets + 1,
stream);
} else {
segmented_sort(
encoder,
in.data<Type>(),
out.data<Type>(),
in.data_size(),
nsegments,
offsets,
offsets + 1,
stream);
}
} else {
throw std::runtime_error(
"CUDA backend does not support sorting complex numbers");
}
});
});
if (!is_segmented_sort) {
// Swap the sorted axis back.
// TODO: Do in-place transpose instead of using a temporary out array.
copy_gpu(swapaxes_in_eval(out, axis, last_dim), out_, CopyType::General, s);
}
}
} // namespace
void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("ArgSort::eval_gpu");
assert(inputs.size() == 1);
gpu_sort(stream(), inputs[0], out, axis_, true);
}
void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Sort::eval_gpu");
assert(inputs.size() == 1);
gpu_sort(stream(), inputs[0], out, axis_, false);
}
} // namespace mlx::core
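For context, a standalone sketch of the CUB call pattern used above: DeviceSegmentedSort follows the usual two-phase protocol, where the first call with a null temp pointer only queries the required workspace size. This example sorts each row of an n x m row-major buffer (names are illustrative):

#include <cub/device/device_segmented_sort.cuh>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>

void sort_rows(float* keys_in, float* keys_out, int n, int m,
               cudaStream_t stream) {
  // Segment i spans [i * m, (i + 1) * m), same trick as the offsets above.
  auto offsets = thrust::make_transform_iterator(
      thrust::make_counting_iterator(0),
      [m] __device__(int i) { return i * m; });
  size_t temp_bytes = 0; // first call only queries the size
  cub::DeviceSegmentedSort::StableSortKeys(
      nullptr, temp_bytes, keys_in, keys_out, n * m, n,
      offsets, offsets + 1, stream);
  void* temp = nullptr;
  cudaMalloc(&temp, temp_bytes);
  cub::DeviceSegmentedSort::StableSortKeys(
      temp, temp_bytes, keys_in, keys_out, n * m, n,
      offsets, offsets + 1, stream);
  cudaFree(temp);
}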

196
mlx/backend/cuda/unary.cu Normal file
View File

@@ -0,0 +1,196 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/unary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/iterators/general_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
#include "mlx/backend/cuda/kernels/unary_ops.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
namespace mlx::core {
namespace cu {
template <typename Op, typename In, typename Out>
constexpr bool supports_unary_op() {
if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> ||
std::is_same_v<Op, Sign>) {
return std::is_same_v<In, Out>;
}
if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcCosh> ||
std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> ||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Log1p> ||
std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Sigmoid> ||
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
return std::is_same_v<In, Out> && is_floating_v<In>;
}
if (std::is_same_v<Op, BitwiseInvert>) {
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
!std::is_same_v<In, bool>;
}
if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor> ||
std::is_same_v<Op, Square>) {
return std::is_same_v<In, Out> && !std::is_same_v<In, complex64_t>;
}
if (std::is_same_v<Op, Conjugate>) {
return std::is_same_v<In, Out> && std::is_same_v<In, complex64_t>;
}
if (std::is_same_v<Op, Cos> || std::is_same_v<Op, Cosh> ||
std::is_same_v<Op, Exp> || std::is_same_v<Op, Round> ||
std::is_same_v<Op, Sin> || std::is_same_v<Op, Sinh> ||
std::is_same_v<Op, Tan> || std::is_same_v<Op, Tanh>) {
return std::is_same_v<In, Out> &&
(is_floating_v<In> || std::is_same_v<In, complex64_t>);
}
if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
return std::is_same_v<In, complex64_t> && std::is_same_v<Out, float>;
}
if (std::is_same_v<Op, LogicalNot>) {
return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
}
return false;
}
} // namespace cu
template <typename Op>
void unary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string& op,
const Stream& s) {
auto& in = inputs[0];
if (in.size() == 0) {
return;
}
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, {
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;
auto policy = cu::thrust_policy(stream);
auto in_ptr = thrust::device_pointer_cast(in.data<InType>());
auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
if (in.flags().contiguous) {
thrust::transform(
policy, in_ptr, in_ptr + in.data_size(), out_ptr, Op());
} else {
auto [shape, strides] = collapse_contiguous_dims(in);
auto [in_begin, in_end] = cu::make_general_iterators<int64_t>(
in_ptr, in.data_size(), shape, strides);
thrust::transform(policy, in_begin, in_end, out_ptr, Op());
}
} else {
throw std::runtime_error(fmt::format(
"Can not do unary op {} on input of {} with output of {}.",
op,
dtype_to_string(in.dtype()),
dtype_to_string(out.dtype())));
}
});
});
});
}
template <typename Op>
void unary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string& op,
const Stream& s) {
set_unary_output_data(inputs[0], out);
unary_op_gpu_inplace<Op>(inputs, out, op, s);
}
#define UNARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = out.primitive().stream(); \
unary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
}
UNARY_GPU(Abs)
UNARY_GPU(ArcCos)
UNARY_GPU(ArcCosh)
UNARY_GPU(ArcSin)
UNARY_GPU(ArcSinh)
UNARY_GPU(ArcTan)
UNARY_GPU(ArcTanh)
UNARY_GPU(BitwiseInvert)
UNARY_GPU(Ceil)
UNARY_GPU(Conjugate)
UNARY_GPU(Cos)
UNARY_GPU(Cosh)
UNARY_GPU(Erf)
UNARY_GPU(ErfInv)
UNARY_GPU(Exp)
UNARY_GPU(Expm1)
UNARY_GPU(Floor)
UNARY_GPU(Imag)
UNARY_GPU(Log1p)
UNARY_GPU(LogicalNot)
UNARY_GPU(Negative)
UNARY_GPU(Real)
UNARY_GPU(Sigmoid)
UNARY_GPU(Sign)
UNARY_GPU(Sin)
UNARY_GPU(Sinh)
UNARY_GPU(Square)
UNARY_GPU(Tan)
UNARY_GPU(Tanh)
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Log::eval_gpu");
auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
switch (base_) {
case Base::e:
unary_op_gpu<cu::Log>(inputs, out, op, s);
break;
case Base::two:
unary_op_gpu<cu::Log2>(inputs, out, op, s);
break;
case Base::ten:
unary_op_gpu<cu::Log10>(inputs, out, op, s);
break;
}
}
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Round::eval_gpu");
assert(inputs.size() == 1);
const auto& in = inputs[0];
auto& s = out.primitive().stream();
if (issubdtype(in.dtype(), inexact)) {
unary_op_gpu<cu::Round>(inputs, out, get_primitive_string(this), s);
} else {
// No-op for integer types
out.copy_shared_buffer(in);
}
}
void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Sort::eval_gpu");
auto& s = out.primitive().stream();
if (recip_) {
unary_op_gpu<cu::Rsqrt>(inputs, out, "Rsqrt", s);
} else {
unary_op_gpu<cu::Sqrt>(inputs, out, "Sqrt", s);
}
}
} // namespace mlx::core
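
The MLX_SWITCH_ALL_TYPES double dispatch above relies on a constexpr
filter so that unsupported (input, output) dtype pairs are never
instantiated and only surface as a runtime error. A stripped-down sketch
of the pattern with invented names (Negate, supports_op, apply):

#include <cstddef>
#include <stdexcept>
#include <string>
#include <type_traits>

struct Negate {
  template <typename T>
  T operator()(T x) const {
    return -x;
  }
};

// Compile-time filter: only same-type arithmetic instantiations are valid.
template <typename Op, typename In, typename Out>
constexpr bool supports_op() {
  return std::is_same_v<In, Out> && std::is_arithmetic_v<In>;
}

template <typename Op, typename In, typename Out>
void apply(const In* in, Out* out, size_t n, const std::string& op_name) {
  if constexpr (supports_op<Op, In, Out>()) {
    for (size_t i = 0; i < n; ++i) {
      out[i] = Op()(in[i]); // the loop body is only compiled for valid pairs
    }
  } else {
    throw std::runtime_error("unsupported dtypes for " + op_name);
  }
}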


@@ -5,9 +5,17 @@
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
#if defined(MLX_USE_CUDA)
#include <nvtx3/nvtx3.hpp>
#endif
#include <cassert>
#if defined(MLX_USE_CUDA)
#define MLX_PROFILER_RANGE(message) nvtx3::scoped_range r(message)
#else
#define MLX_PROFILER_RANGE(message)
#endif
namespace mlx::core {
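
A call-site sketch for the macro above, with a hypothetical primitive
name: on CUDA builds the statement opens an NVTX range for the enclosing
scope, and on other builds it compiles away to nothing.

void eval_some_primitive() {
  MLX_PROFILER_RANGE("SomePrimitive::eval_gpu");
  // ... work here shows up as a named range in Nsight Systems timelines ...
}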


@@ -31,13 +31,13 @@ std::string get_kernel_name(
kname = "ss";
break;
case BinaryOpType::ScalarVector:
kname = (large ? "sv2" : "sv");
kname = "sv";
break;
case BinaryOpType::VectorScalar:
kname = (large ? "vs2" : "vs");
kname = "vs";
break;
case BinaryOpType::VectorVector:
kname = (large ? "vv2" : "vv");
kname = "vv";
break;
case BinaryOpType::General:
kname = "g";
@@ -51,6 +51,13 @@ std::string get_kernel_name(
}
break;
}
if (bopt != BinaryOpType::General && bopt != BinaryOpType::ScalarScalar) {
if (large) {
kname += "2";
} else if (work_per_thread > 1) {
kname += "n";
}
}
concatenate(kname, "_", op, type_to_name(a));
return kname;
}
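
The hunk above factors the suffix logic out of the per-case strings: "2"
marks large (64-bit) indexing and "n" marks a work-per-thread loop,
applied uniformly to the sv/vs/vv cases. Assuming the scheme composes as
shown, the resulting names look like:

// "vv_Add_float32"  -> vector-vector, one element per thread
// "vvn_Add_float32" -> vector-vector, N elements per thread
// "vv2_Add_float32" -> vector-vector, large (64-bit) indexing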
@@ -90,7 +97,7 @@ void binary_op_gpu_inplace(
work_per_thread = large ? 4 : 2;
} else {
large = out.data_size() > UINT32_MAX;
work_per_thread = get_work_per_thread(a.dtype());
work_per_thread = get_work_per_thread(a.dtype(), out.data_size());
}
std::string kernel_name =
get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);


@@ -278,7 +278,21 @@ void Compiled::eval_gpu(
/* ndim = */ 0,
/* dynamic_dims = */ false,
/* use_big_index = */ false,
/* work_per_thread = */ work_per_thread);
/* work_per_thread = */ 1);
if (work_per_thread > 1) {
build_kernel(
kernel,
kernel_lib_ + "_contiguous_n",
inputs_,
outputs_,
tape_,
is_constant_,
/* contiguous = */ true,
/* ndim = */ 0,
/* dynamic_dims = */ false,
/* use_big_index = */ false,
/* work_per_thread = */ work_per_thread);
}
build_kernel(
kernel,
kernel_lib_ + "_contiguous_large",
@@ -358,12 +372,20 @@ void Compiled::eval_gpu(
int ndim = shape.size();
bool dynamic = ndim >= 8;
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
int work_per_thread = 1;
if (!contiguous) {
if (dynamic) {
kernel_name += "dynamic";
} else {
kernel_name += std::to_string(shape.size());
}
work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
} else {
work_per_thread =
get_work_per_thread(outputs[0].dtype(), outputs[0].data_size());
if (work_per_thread > 1 && !large) {
kernel_name += "_n";
}
}
if (large) {
kernel_name += "_large";
@@ -420,7 +442,6 @@ void Compiled::eval_gpu(
// Launch the kernel
if (contiguous) {
int work_per_thread = get_work_per_thread(outputs[0].dtype());
size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread);
MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
@@ -433,7 +454,6 @@ void Compiled::eval_gpu(
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = outputs[0].size() / (dim0 * dim1);
int work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
int pow2;
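
With the change above, the contiguous launch derives work_per_thread from
the output dtype and size, and the grid shrinks by the same factor. A
small check of the dispatch arithmetic, restating ceildiv as used at the
call sites (a sketch, not the MLX header):

#include <cstddef>

constexpr size_t ceildiv(size_t a, size_t b) {
  return (a + b - 1) / b;
}

// e.g. data_size = 10'000 with work_per_thread = 4 launches 2'500
// threads, each covering a contiguous chunk of 4 elements.
static_assert(ceildiv(10'000, 4) == 2'500);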


@@ -677,7 +677,7 @@ void depthwise_conv_2D_gpu(
std::string hash_name = kname.str();
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);


@@ -55,10 +55,10 @@ void copy_gpu_inplace(
std::string kernel_name;
switch (ctype) {
case CopyType::Scalar:
kernel_name = (large ? "s2" : "s");
kernel_name = large ? "s2" : "s";
break;
case CopyType::Vector:
kernel_name = (large ? "v2" : "v");
kernel_name = large ? "v2" : "v";
break;
case CopyType::General:
kernel_name = "g";
@@ -85,7 +85,10 @@ void copy_gpu_inplace(
}
}
} else {
work_per_thread = get_work_per_thread(in.dtype());
work_per_thread = get_work_per_thread(out.dtype(), out.data_size());
if (work_per_thread > 1) {
kernel_name += "n";
}
}
concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out)
@@ -170,9 +173,10 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
}
out.set_data(allocator::malloc(out.nbytes()));
bool large = out.data_size() > UINT32_MAX;
int work_per_thread = get_work_per_thread(out.dtype(), out.data_size());
auto& d = metal::device(s.device);
std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +
type_to_name(val) + type_to_name(out);
std::string kernel_name = large ? "s2" : (work_per_thread > 1 ? "sn" : "s");
concatenate(kernel_name, "_copy", type_to_name(val), type_to_name(out));
auto kernel = get_copy_kernel(d, kernel_name, val, out);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
@@ -180,7 +184,6 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
compute_encoder.set_input_array(val, 0);
compute_encoder.set_output_array(out, 1);
int work_per_thread = get_work_per_thread(val.dtype());
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
if (thread_group_size > nthreads) {


@@ -1,12 +1,326 @@
// Copyright © 2024 Apple Inc.
#include <iostream>
#include <regex>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast.h"
#include "mlx/fast_primitives.h"
#include "mlx/utils.h"
namespace mlx::core::fast {
struct CustomKernelCache {
std::unordered_map<std::string, std::string> libraries;
};
static CustomKernelCache& cache() {
static CustomKernelCache cache_;
return cache_;
};
std::string write_signature(
std::string func_name,
const std::string& header,
const std::string& source,
const std::vector<std::string>& input_names,
const std::vector<array>& inputs,
const std::vector<std::string>& output_names,
const std::vector<Dtype>& output_dtypes,
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
const std::vector<std::string>& attributes,
const std::vector<CustomKernelShapeInfo>& shape_infos,
bool atomic_outputs) {
std::string kernel_source;
kernel_source.reserve(header.size() + source.size() + 16384);
kernel_source += header;
// Auto-generate a function signature based on `template_args`
// and the dtype/shape of the arrays passed as `inputs`.
if (!template_args.empty()) {
kernel_source += "template <";
int i = 0;
for (const auto& [name, arg] : template_args) {
std::string param_type;
if (std::holds_alternative<int>(arg)) {
param_type = "int";
} else if (std::holds_alternative<bool>(arg)) {
param_type = "bool";
} else if (std::holds_alternative<Dtype>(arg)) {
param_type = "typename";
}
if (i > 0) {
kernel_source += ", ";
}
kernel_source += param_type;
kernel_source += " ";
kernel_source += name;
i++;
}
kernel_source += ">\n";
}
kernel_source += "[[kernel]] void ";
kernel_source += func_name;
kernel_source += "(\n";
int index = 0;
constexpr int max_constant_array_size = 8;
// Add inputs
for (int i = 0; i < inputs.size(); ++i) {
const auto& name = input_names[i];
const auto& arr = inputs[i];
auto dtype = get_type_string(arr.dtype());
std::string location =
arr.size() < max_constant_array_size ? "constant" : "device";
std::string ref = arr.ndim() == 0 ? "&" : "*";
kernel_source += " const ";
kernel_source += location;
kernel_source += " ";
kernel_source += dtype;
kernel_source += ref;
kernel_source += " ";
kernel_source += name;
kernel_source += " [[buffer(";
kernel_source += std::to_string(index);
kernel_source += ")]],\n";
index++;
// Add input shape, strides and ndim if present in the source
if (arr.ndim() > 0) {
if (shape_infos[i].shape) {
kernel_source +=
(" const constant int* " + name + "_shape [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
if (shape_infos[i].strides) {
kernel_source +=
(" const constant int64_t* " + name + "_strides [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
if (shape_infos[i].ndim) {
kernel_source +=
(" const constant int& " + name + "_ndim [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
}
}
// Add outputs
for (int i = 0; i < output_names.size(); ++i) {
const auto& name = output_names[i];
const auto& dtype = output_dtypes[i];
kernel_source += " device ";
auto type_string = get_type_string(dtype);
if (atomic_outputs) {
kernel_source += "atomic<";
}
kernel_source += type_string;
if (atomic_outputs) {
kernel_source += ">";
}
kernel_source += "* ";
kernel_source += name;
kernel_source += " [[buffer(";
kernel_source += std::to_string(index);
kernel_source += ")]]";
if (index < inputs.size() + output_names.size() - 1 ||
attributes.size() > 0) {
kernel_source += ",\n";
} else {
kernel_source += ") {\n";
}
index++;
}
index = 0;
for (const auto& attr : attributes) {
kernel_source += attr;
if (index < attributes.size() - 1) {
kernel_source += ",\n";
} else {
kernel_source += ") {\n";
}
index++;
}
kernel_source += source;
kernel_source += "\n}\n";
return kernel_source;
}
std::string write_template(
const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
std::ostringstream template_def;
template_def << "<";
int i = 0;
for (const auto& [name, arg] : template_args) {
if (i > 0) {
template_def << ", ";
}
if (std::holds_alternative<int>(arg)) {
template_def << std::get<int>(arg);
} else if (std::holds_alternative<bool>(arg)) {
template_def << std::get<bool>(arg);
} else if (std::holds_alternative<Dtype>(arg)) {
template_def << get_type_string(std::get<Dtype>(arg));
}
i++;
}
template_def << ">";
return template_def.str();
}
MetalKernelFunction metal_kernel(
const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::string& source,
const std::string& header /* = "" */,
bool ensure_row_contiguous /* = true */,
bool atomic_outputs /* = false */) {
if (output_names.empty()) {
throw std::invalid_argument(
"[metal_kernel] Must specify at least one output.");
}
std::vector<CustomKernelShapeInfo> shape_infos;
for (auto& n : input_names) {
CustomKernelShapeInfo shape_info;
shape_info.shape = source.find(n + "_shape") != std::string::npos;
shape_info.strides = source.find(n + "_strides") != std::string::npos;
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
shape_infos.push_back(shape_info);
}
const std::vector<std::pair<std::string, std::string>> metal_attributes = {
{"dispatch_quadgroups_per_threadgroup", "uint"},
{"dispatch_simdgroups_per_threadgroup", "uint"},
{"dispatch_threads_per_threadgroup", "uint3"},
{"grid_origin", "uint3"},
{"grid_size", "uint3"},
{"quadgroup_index_in_threadgroup", "uint"},
{"quadgroups_per_threadgroup", "uint"},
{"simdgroup_index_in_threadgroup", "uint"},
{"simdgroups_per_threadgroup", "uint"},
{"thread_execution_width", "uint"},
{"thread_index_in_quadgroup", "uint"},
{"thread_index_in_simdgroup", "uint"},
{"thread_index_in_threadgroup", "uint"},
{"thread_position_in_grid", "uint3"},
{"thread_position_in_threadgroup", "uint3"},
{"threadgroup_position_in_grid", "uint3"},
{"threadgroups_per_grid", "uint3"},
{"threads_per_grid", "uint3"},
{"threads_per_simdgroup", "uint"},
{"threads_per_threadgroup", "uint3"},
};
std::vector<std::string> attributes;
for (const auto& [attr, dtype] : metal_attributes) {
if (source.find(attr) != std::string::npos) {
attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]");
}
}
return [=,
shape_infos = std::move(shape_infos),
attributes = std::move(attributes)](
const std::vector<array>& inputs,
const std::vector<Shape>& output_shapes,
const std::vector<Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
const std::vector<std::pair<std::string, TemplateArg>>&
template_args = {},
std::optional<float> init_value = std::nullopt,
bool verbose = false,
StreamOrDevice s_ = {}) {
if (inputs.size() != input_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `inputs` to have size "
<< input_names.size() << " but got size " << inputs.size() << "."
<< std::endl;
throw std::invalid_argument(msg.str());
}
if (output_shapes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `output_shapes` to have size "
<< output_names.size() << " but got size " << output_shapes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
if (output_dtypes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `output_dtypes` to have size "
<< output_names.size() << " but got size " << output_dtypes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
auto s = to_stream(s_);
if (s.device != Device::gpu) {
throw std::invalid_argument("[metal_kernel] Only supports the GPU.");
}
std::string kernel_name = "custom_kernel_" + name;
std::string template_def = "";
if (!template_args.empty()) {
std::regex disallowed_chars("\\<|\\>|(, )");
template_def = write_template(template_args);
auto template_hash =
std::regex_replace(template_def, disallowed_chars, "_");
template_hash.pop_back();
kernel_name += "_";
kernel_name += template_hash;
}
std::string kernel_source = write_signature(
kernel_name,
header,
source,
input_names,
inputs,
output_names,
output_dtypes,
template_args,
attributes,
shape_infos,
atomic_outputs);
if (!template_args.empty()) {
template_def = kernel_name + template_def;
kernel_source += "\ntemplate [[host_name(\"";
kernel_source += kernel_name;
kernel_source += "\")]] [[kernel]] decltype(";
kernel_source += template_def;
kernel_source += ") ";
kernel_source += template_def;
kernel_source += ";\n";
}
if (verbose) {
std::cout << "Generated source code for `" << name << "`:" << std::endl
<< "```" << std::endl
<< kernel_source << std::endl
<< "```" << std::endl;
}
return array::make_arrays(
std::move(output_shapes),
std::move(output_dtypes),
std::make_shared<CustomKernel>(
s,
std::move(kernel_name),
std::move(kernel_source),
grid,
threadgroup,
shape_infos,
ensure_row_contiguous,
init_value),
std::move(inputs));
};
}
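
For context, a hedged usage sketch of the factory above, written against
the signature in this diff (x is assumed to be an existing float32
mlx::core::array; error handling omitted):

#include "mlx/fast.h"

auto exp_kernel = mlx::core::fast::metal_kernel(
    "my_exp", // name
    {"inp"},  // input_names
    {"out"},  // output_names
    R"(
      uint elem = thread_position_in_grid.x;
      out[elem] = metal::exp(inp[elem]);
    )");      // source; the signature is generated automatically

auto outputs = exp_kernel(
    {x},          // inputs
    {x.shape()},  // output_shapes
    {x.dtype()},  // output_dtypes
    {1024, 1, 1}, // grid
    {256, 1, 1}); // threadgroup

Because the source references thread_position_in_grid, the attribute scan
above appends that parameter to the generated [[kernel]] signature.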
void CustomKernel::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
@@ -39,9 +353,23 @@ void CustomKernel::eval_gpu(
}
auto& d = metal::device(s.device);
const auto& lib_name = name_;
auto lib =
d.get_library(lib_name, [this] { return metal::utils() + source_; });
{
// Clear kernels from the device library cache if needed
auto& kernel_cache = cache();
if (auto it = kernel_cache.libraries.find(name_);
it != kernel_cache.libraries.end()) {
if (it->second != source_) {
auto& d = metal::device(s.device);
d.clear_library(name_);
it->second = source_;
}
} else {
kernel_cache.libraries.emplace(name_, source_);
}
}
auto lib = d.get_library(name_, [this] { return metal::utils() + source_; });
auto kernel = d.get_kernel(name_, lib);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
@@ -73,6 +401,16 @@ void CustomKernel::eval_gpu(
}
const auto [tx, ty, tz] = threadgroup_;
auto tg_size = tx * ty * tz;
auto max_tg_size = kernel->maxTotalThreadsPerThreadgroup();
if (tg_size > max_tg_size) {
std::ostringstream msg;
msg << "Thread group size (" << tg_size << ") is greater than "
<< " the maximum allowed threads per threadgroup (" << max_tg_size
<< ").";
throw std::invalid_argument(msg.str());
}
const auto [gx, gy, gz] = grid_;
MTL::Size group_dims =
MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));


@@ -295,7 +295,7 @@ void CommandEncoder::barrier() {
Device::Device() {
auto pool = new_scoped_memory_pool();
device_ = load_device();
library_map_ = {{"mlx", load_default_library(device_)}};
default_library_ = load_default_library(device_);
arch_ = std::string(device_->architecture()->name()->utf8String());
auto arch = arch_.back();
switch (arch) {
@@ -326,11 +326,11 @@ Device::Device() {
Device::~Device() {
auto pool = new_scoped_memory_pool();
for (auto& k : kernel_map_) {
k.second->release();
}
for (auto& l : library_map_) {
l.second->release();
for (auto& [l, kernel_map] : library_kernels_) {
l->release();
for (auto& [_, k] : kernel_map) {
k->release();
}
}
stream_map_.clear();
device_->release();
@@ -474,13 +474,24 @@ CommandEncoder& Device::get_command_encoder(int index) {
return *stream.encoder;
}
void Device::register_library(
const std::string& lib_name,
const std::string& lib_path) {
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
auto new_lib = load_library(device_, lib_name, lib_path.c_str());
library_map_.insert({lib_name, new_lib});
MTL::Library* Device::get_library(
const std::string& name,
const std::string& path /* = "" */) {
{
std::shared_lock rlock(library_mtx_);
if (auto it = library_map_.find(name); it != library_map_.end()) {
return it->second;
}
}
std::unique_lock wlock(library_mtx_);
if (auto it = library_map_.find(name); it != library_map_.end()) {
return it->second;
}
auto new_lib = load_library(device_, name, path.c_str());
library_map_.insert({name, new_lib});
return new_lib;
}
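
The lookup above is the classic double-checked pattern over a
std::shared_mutex: a shared lock serves the common cache hit, and the
re-check under the unique lock keeps two threads from loading the same
library twice. A generic standalone sketch:

#include <shared_mutex>
#include <string>
#include <unordered_map>

template <typename V, typename Loader>
V* lookup_or_create(
    std::unordered_map<std::string, V*>& map,
    std::shared_mutex& mtx,
    const std::string& key,
    Loader load) {
  {
    std::shared_lock rlock(mtx); // many concurrent readers
    if (auto it = map.find(key); it != map.end()) {
      return it->second;
    }
  }
  std::unique_lock wlock(mtx); // exclusive writer
  if (auto it = map.find(key); it != map.end()) {
    return it->second; // another thread won the race; reuse its result
  }
  V* value = load(key);
  map.emplace(key, value);
  return value;
}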
MTL::Library* Device::build_library_(const std::string& source_string) {
@@ -649,6 +660,19 @@ MTL::Library* Device::get_library(
return mtl_lib;
}
void Device::clear_library(const std::string& name) {
std::unique_lock wlock(library_mtx_);
if (auto it = library_map_.find(name); it != library_map_.end()) {
auto kernel_map_it = library_kernels_.find(it->second);
for (auto& [_, kernel] : kernel_map_it->second) {
kernel->release();
}
library_kernels_.erase(kernel_map_it);
it->second->release();
library_map_.erase(it);
}
}
MTL::LinkedFunctions* Device::get_linked_functions_(
const std::vector<MTL::Function*>& funcs) {
if (funcs.empty()) {
@@ -679,6 +703,7 @@ MTL::ComputePipelineState* Device::get_kernel_(
std::unique_lock wlock(kernel_mtx_);
// Try loading again to avoid loading twice
auto& kernel_map_ = library_kernels_[mtl_lib];
if (auto it = kernel_map_.find(hash_name); it != kernel_map_.end()) {
return it->second;
}
@@ -713,6 +738,7 @@ MTL::ComputePipelineState* Device::get_kernel(
std::shared_lock lock(kernel_mtx_);
// Look for cached kernel
auto& kernel_map_ = library_kernels_[mtl_lib];
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
}
@@ -722,23 +748,11 @@ MTL::ComputePipelineState* Device::get_kernel(
MTL::ComputePipelineState* Device::get_kernel(
const std::string& base_name,
const std::string& lib_name /* = "mlx" */,
const std::string& hash_name /* = "" */,
const MTLFCList& func_consts /* = {} */,
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
const auto& kname = hash_name.size() == 0 ? base_name : hash_name;
{
// Multiple readers allowed
std::shared_lock lock(kernel_mtx_);
// Look for cached kernel
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
}
}
// Search for cached metal lib
MTL::Library* mtl_lib = get_library_(lib_name);
return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
return get_kernel(
base_name, default_library_, hash_name, func_consts, linked_functions);
}
void Device::set_residency_set(const MTL::ResidencySet* residency_set) {


@@ -187,14 +187,16 @@ class Device {
CommandEncoder& get_command_encoder(int index);
void end_encoding(int index);
void register_library(
const std::string& lib_name,
const std::string& lib_path = "");
MTL::Library* get_library(
const std::string& name,
const std::string& path = "");
MTL::Library* get_library(
const std::string& name,
const std::function<std::string(void)>& builder);
void clear_library(const std::string& name);
MTL::ComputePipelineState* get_kernel(
const std::string& base_name,
MTL::Library* mtl_lib,
@@ -204,7 +206,6 @@ class Device {
MTL::ComputePipelineState* get_kernel(
const std::string& base_name,
const std::string& lib_name = "mlx",
const std::string& hash_name = "",
const MTLFCList& func_consts = {},
const std::vector<MTL::Function*>& linked_functions = {});
@@ -258,10 +259,13 @@ class Device {
std::unordered_map<int32_t, DeviceStream> stream_map_;
std::shared_mutex kernel_mtx_;
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
std::shared_mutex library_mtx_;
std::unordered_map<std::string, MTL::Library*> library_map_;
MTL::Library* default_library_;
std::unordered_map<
MTL::Library*,
std::unordered_map<std::string, MTL::ComputePipelineState*>>
library_kernels_;
const MTL::ResidencySet* residency_set_{nullptr};
std::string arch_;
int max_ops_per_buffer_;


@@ -41,7 +41,11 @@ MTL::ComputePipelineState* get_unary_kernel(
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::unary_ops(), metal::unary());
kernel_source +=
get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op);
get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op, 1);
if (get_work_per_thread(in_type) > 1) {
kernel_source +=
get_template_definition("vn_" + lib_name, "unary_v", in_t, out_t, op);
}
kernel_source +=
get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op);
kernel_source += get_template_definition(
@@ -59,11 +63,8 @@ void append_binary_kernels(
Dtype out_type,
const std::string op,
std::string& kernel_source) {
const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{
const std::array<std::pair<std::string, std::string>, 7> kernel_types = {{
{"ss", "binary_ss"},
{"vs", "binary_vs"},
{"sv", "binary_sv"},
{"vv", "binary_vv"},
{"vs2", "binary_vs2"},
{"sv2", "binary_sv2"},
{"vv2", "binary_vv2"},
@@ -78,6 +79,22 @@ void append_binary_kernels(
kernel_source +=
get_template_definition(name + "_" + lib_name, func, in_t, out_t, op);
}
kernel_source += get_template_definition(
"vs_" + lib_name, "binary_vs", in_t, out_t, op, 1);
kernel_source += get_template_definition(
"sv_" + lib_name, "binary_sv", in_t, out_t, op, 1);
kernel_source += get_template_definition(
"vv_" + lib_name, "binary_vv", in_t, out_t, op, 1);
if (get_work_per_thread(in_type) > 1) {
kernel_source += get_template_definition(
"vsn_" + lib_name, "binary_vs", in_t, out_t, op);
kernel_source += get_template_definition(
"svn_" + lib_name, "binary_sv", in_t, out_t, op);
kernel_source += get_template_definition(
"vvn_" + lib_name, "binary_vv", in_t, out_t, op);
}
kernel_source += get_template_definition(
"g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int");
kernel_source += get_template_definition(
@@ -133,8 +150,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
auto t_str = get_type_string(type);
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
{"v", "ternary_v"},
const std::array<std::pair<std::string, std::string>, 4> kernel_types = {{
{"v2", "ternary_v2"},
{"g1large", "ternary_g_nd1"},
{"g2large", "ternary_g_nd2"},
@@ -144,6 +160,13 @@ MTL::ComputePipelineState* get_ternary_kernel(
kernel_source +=
get_template_definition(name + "_" + lib_name, func, t_str, op);
}
if (get_work_per_thread(type) > 1) {
kernel_source +=
get_template_definition("vn_" + lib_name, "ternary_v", t_str, op);
}
kernel_source +=
get_template_definition("v_" + lib_name, "ternary_v", t_str, op, 1);
kernel_source += get_template_definition(
"g1_" + lib_name, "ternary_g_nd1", t_str, op, "int");
kernel_source += get_template_definition(
@@ -170,15 +193,22 @@ MTL::ComputePipelineState* get_copy_kernel(
kernel_source += metal::copy();
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
kernel_source +=
get_template_definition("s_" + lib_name, "copy_s", in_type, out_type);
kernel_source += get_template_definition(
"s_" + lib_name, "copy_s", in_type, out_type, 1);
kernel_source +=
get_template_definition("s2_" + lib_name, "copy_s2", in_type, out_type);
kernel_source +=
get_template_definition("v_" + lib_name, "copy_v", in_type, out_type);
kernel_source += get_template_definition(
"v_" + lib_name, "copy_v", in_type, out_type, 1);
kernel_source +=
get_template_definition("v2_" + lib_name, "copy_v2", in_type, out_type);
if (get_work_per_thread(out.dtype()) > 1) {
kernel_source += get_template_definition(
"sn_" + lib_name, "copy_s", in_type, out_type);
kernel_source += get_template_definition(
"vn_" + lib_name, "copy_v", in_type, out_type);
}
kernel_source += get_template_definition(
"g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int");
kernel_source += get_template_definition(


@@ -17,8 +17,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
c[index + i] = Op()(a[0], b[index + i]);
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
c[index + i] = Op()(a[0], b[index + i]);
}
} else {
for (int i = 0; i < N; ++i) {
c[index + i] = Op()(a[0], b[index + i]);
}
}
}
@@ -30,8 +36,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
c[index + i] = Op()(a[index + i], b[0]);
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
c[index + i] = Op()(a[index + i], b[0]);
}
} else {
for (int i = 0; i < N; ++i) {
c[index + i] = Op()(a[index + i], b[0]);
}
}
}
@@ -43,8 +55,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
c[index + i] = Op()(a[index + i], b[index + i]);
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
c[index + i] = Op()(a[index + i], b[index + i]);
}
} else {
for (int i = 0; i < N; ++i) {
c[index + i] = Op()(a[index + i], b[index + i]);
}
}
}
@@ -57,8 +75,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
c[offset + i] = Op()(a[0], b[offset + i]);
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
c[offset + i] = Op()(a[0], b[offset + i]);
}
} else {
for (int i = 0; i < N; ++i) {
c[offset + i] = Op()(a[0], b[offset + i]);
}
}
}
@@ -71,8 +95,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
c[offset + i] = Op()(a[offset + i], b[0]);
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
c[offset + i] = Op()(a[offset + i], b[0]);
}
} else {
for (int i = 0; i < N; ++i) {
c[offset + i] = Op()(a[offset + i], b[0]);
}
}
}
@@ -85,8 +115,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
c[offset + i] = Op()(a[offset + i], b[offset + i]);
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
c[offset + i] = Op()(a[offset + i], b[offset + i]);
}
} else {
for (int i = 0; i < N; ++i) {
c[offset + i] = Op()(a[offset + i], b[offset + i]);
}
}
}
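
The rewritten kernels above all follow one pattern: the per-element
bounds check moves out of the hot loop, so only the final partial chunk
pays for it and the full-chunk loop has a fixed trip count the compiler
can unroll. Since N is a compile-time constant, the N == 1 instantiation
drops the branch entirely. A host-side analogue of the split
(illustrative C++, not Metal):

#include <cstddef>

template <int N>
void scale_chunk(const float* a, float* out, size_t index, size_t size) {
  if (N > 1 && index + N > size) {
    // Tail chunk: per-element guard, taken by at most one chunk per launch.
    for (size_t i = 0; index + i < size; ++i) {
      out[index + i] = 2.0f * a[index + i];
    }
  } else {
    // Full chunk: fixed trip count, trivially unrollable.
    for (int i = 0; i < N; ++i) {
      out[index + i] = 2.0f * a[index + i];
    }
  }
}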


@@ -9,11 +9,16 @@
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary.h"
#define instantiate_binary_all(op, tname, itype, otype) \
#define instantiate_binary_work_per_thread(op, tname, itype, otype) \
instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op) \
#define instantiate_binary_base(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
@@ -26,15 +31,19 @@
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
#define instantiate_binary_integer(op) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
instantiate_binary_all(op, uint64, uint64_t, uint64_t) \
instantiate_binary_all(op, int8, int8_t, int8_t) \
instantiate_binary_all(op, int16, int16_t, int16_t) \
instantiate_binary_all(op, int32, int32_t, int32_t) \
instantiate_binary_all(op, int64, int64_t, int64_t)
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_binary_base(op, tname, itype, otype) \
instantiate_binary_work_per_thread(op, tname, itype, otype)
#define instantiate_binary_integer(op) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
instantiate_binary_base(op, uint64, uint64_t, uint64_t) \
instantiate_binary_all(op, int8, int8_t, int8_t) \
instantiate_binary_all(op, int16, int16_t, int16_t) \
instantiate_binary_all(op, int32, int32_t, int32_t) \
instantiate_binary_base(op, int64, int64_t, int64_t)
#define instantiate_binary_float(op) \
instantiate_binary_all(op, float16, half, half) \
@@ -44,7 +53,7 @@
#define instantiate_binary_types(op) \
instantiate_binary_all(op, bool_, bool, bool) \
instantiate_binary_integer(op) \
instantiate_binary_all(op, complex64, complex64_t, complex64_t) \
instantiate_binary_base(op, complex64, complex64_t, complex64_t)\
instantiate_binary_float(op)
#define instantiate_binary_types_bool(op) \
@@ -52,15 +61,15 @@
instantiate_binary_all(op, uint8, uint8_t, bool) \
instantiate_binary_all(op, uint16, uint16_t, bool) \
instantiate_binary_all(op, uint32, uint32_t, bool) \
instantiate_binary_all(op, uint64, uint64_t, bool) \
instantiate_binary_base(op, uint64, uint64_t, bool) \
instantiate_binary_all(op, int8, int8_t, bool) \
instantiate_binary_all(op, int16, int16_t, bool) \
instantiate_binary_all(op, int32, int32_t, bool) \
instantiate_binary_all(op, int64, int64_t, bool) \
instantiate_binary_base(op, int64, int64_t, bool) \
instantiate_binary_all(op, float16, half, bool) \
instantiate_binary_all(op, float32, float, bool) \
instantiate_binary_all(op, bfloat16, bfloat16_t, bool) \
instantiate_binary_all(op, complex64, complex64_t, bool)
instantiate_binary_base(op, complex64, complex64_t, bool)
instantiate_binary_types(Add)
instantiate_binary_types(Divide)
@@ -71,7 +80,7 @@ instantiate_binary_types_bool(Less)
instantiate_binary_types_bool(LessEqual)
instantiate_binary_types_bool(NotEqual)
instantiate_binary_float(LogAddExp)
instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t)
instantiate_binary_base(LogAddExp, complex64, complex64_t, complex64_t)
instantiate_binary_types(Maximum)
instantiate_binary_types(Minimum)
instantiate_binary_types(Multiply)
@@ -84,7 +93,7 @@ instantiate_binary_float(ArcTan2)
instantiate_binary_all(NaNEqual, float16, half, bool)
instantiate_binary_all(NaNEqual, float32, float, bool)
instantiate_binary_all(NaNEqual, bfloat16, bfloat16_t, bool)
instantiate_binary_all(NaNEqual, complex64, complex64_t, bool)
instantiate_binary_base(NaNEqual, complex64, complex64_t, bool)
instantiate_binary_all(LogicalOr, bool_, bool, bool)
instantiate_binary_all(LogicalAnd, bool_, bool, bool)
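
A note on the instantiation split above: the "n" (work-per-thread)
variants come from instantiate_binary_all, while uint64, int64, and
complex64 go through instantiate_binary_base only. This mirrors the
get_work_per_thread(...) > 1 guards in the JIT paths earlier in this
diff; presumably (not shown here) that helper reports one element per
thread for 8-byte and wider element types, so an "n" kernel would never
be requested for them.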


@@ -21,10 +21,18 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
auto out = Op()(a[0], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
auto out = Op()(a[0], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
}
} else {
for (int i = 0; i < N; ++i) {
auto out = Op()(a[0], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
}
}
}
@@ -37,10 +45,18 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
auto out = Op()(a[index + i], b[0]);
c[index + i] = out[0];
d[index + i] = out[1];
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
auto out = Op()(a[index + i], b[0]);
c[index + i] = out[0];
d[index + i] = out[1];
}
} else {
for (int i = 0; i < N; ++i) {
auto out = Op()(a[index + i], b[0]);
c[index + i] = out[0];
d[index + i] = out[1];
}
}
}
@@ -53,10 +69,18 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
auto out = Op()(a[index + i], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
auto out = Op()(a[index + i], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
}
} else {
for (int i = 0; i < N; ++i) {
auto out = Op()(a[index + i], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
}
}
}
@@ -69,11 +93,19 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
auto out = Op()(a[0], b[offset + i]);
c[offset + i] = out[0];
d[offset + i] = out[1];
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
auto out = Op()(a[0], b[offset + i]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
} else {
for (int i = 0; i < N; ++i) {
auto out = Op()(a[0], b[offset + i]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
}
}
@@ -86,11 +118,19 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
auto out = Op()(a[offset + i], b[0]);
c[offset + i] = out[0];
d[offset + i] = out[1];
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
auto out = Op()(a[offset + i], b[0]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
} else {
for (int i = 0; i < N; ++i) {
auto out = Op()(a[offset + i], b[0]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
}
}
@@ -103,11 +143,19 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
auto out = Op()(a[offset + i], b[offset + i]);
c[offset + i] = out[0];
d[offset + i] = out[1];
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
auto out = Op()(a[offset + i], b[offset + i]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
} else {
for (int i = 0; i < N; ++i) {
auto out = Op()(a[offset + i], b[offset + i]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
}
}


@@ -7,11 +7,16 @@
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary_two.h"
#define instantiate_binary_all(op, tname, itype, otype) \
#define instantiate_binary_work_per_thread(op, tname, itype, otype) \
instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op)
#define instantiate_binary_base(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
@@ -24,22 +29,26 @@
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_binary_base(op, tname, itype, otype) \
instantiate_binary_work_per_thread(op, tname, itype, otype)
#define instantiate_binary_float(op) \
instantiate_binary_all(op, float16, half, half) \
instantiate_binary_all(op, float32, float, float) \
instantiate_binary_all(op, bfloat16, bfloat16_t, bfloat16_t)
#define instantiate_binary_types(op) \
instantiate_binary_all(op, bool_, bool, bool) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
instantiate_binary_all(op, uint64, uint64_t, uint64_t) \
instantiate_binary_all(op, int8, int8_t, int8_t) \
instantiate_binary_all(op, int16, int16_t, int16_t) \
instantiate_binary_all(op, int32, int32_t, int32_t) \
instantiate_binary_all(op, int64, int64_t, int64_t) \
instantiate_binary_all(op, complex64, complex64_t, complex64_t) \
#define instantiate_binary_types(op) \
instantiate_binary_all(op, bool_, bool, bool) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
instantiate_binary_base(op, uint64, uint64_t, uint64_t) \
instantiate_binary_all(op, int8, int8_t, int8_t) \
instantiate_binary_all(op, int16, int16_t, int16_t) \
instantiate_binary_all(op, int32, int32_t, int32_t) \
instantiate_binary_base(op, int64, int64_t, int64_t) \
instantiate_binary_base(op, complex64, complex64_t, complex64_t) \
instantiate_binary_float(op)
instantiate_binary_types(DivMod) // clang-format on


@@ -1,52 +1,76 @@
// Copyright © 2024 Apple Inc.
template <typename T, typename U, int N = WorkPerThread<T>::n>
template <typename T, typename U, int N = WorkPerThread<U>::n>
[[kernel]] void copy_s(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
dst[index + i] = static_cast<U>(src[0]);
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
dst[index + i] = static_cast<U>(src[0]);
}
} else {
for (int i = 0; i < N; ++i) {
dst[index + i] = static_cast<U>(src[0]);
}
}
}
template <typename T, typename U, int N = WorkPerThread<T>::n>
template <typename T, typename U, int N = WorkPerThread<U>::n>
[[kernel]] void copy_v(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
dst[index + i] = static_cast<U>(src[index + i]);
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
dst[index + i] = static_cast<U>(src[index + i]);
}
} else {
for (int i = 0; i < N; ++i) {
dst[index + i] = static_cast<U>(src[index + i]);
}
}
}
template <typename T, typename U, int N = WorkPerThread<T>::n>
template <typename T, typename U, int N = WorkPerThread<U>::n>
[[kernel]] void copy_s2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
dst[offset + i] = static_cast<U>(src[0]);
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
dst[offset + i] = static_cast<U>(src[0]);
}
} else {
for (int i = 0; i < N; ++i) {
dst[offset + i] = static_cast<U>(src[0]);
}
}
}
template <typename T, typename U, int N = WorkPerThread<T>::n>
template <typename T, typename U, int N = WorkPerThread<U>::n>
[[kernel]] void copy_v2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
dst[offset + i] = static_cast<U>(src[offset + i]);
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
dst[offset + i] = static_cast<U>(src[offset + i]);
}
} else {
for (int i = 0; i < N; ++i) {
dst[offset + i] = static_cast<U>(src[offset + i]);
}
}
}


@@ -4,9 +4,13 @@
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/copy.h"
#define instantiate_copy_all(tname, itype, otype) \
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
#define instantiate_copy_work_per_thread(tname, itype, otype) \
instantiate_kernel("sn_copy" #tname, copy_s, itype, otype) \
instantiate_kernel("vn_copy" #tname, copy_v, itype, otype)
#define instantiate_copy_base(tname, itype, otype) \
instantiate_kernel("s_copy" #tname, copy_s, itype, otype, 1) \
instantiate_kernel("v_copy" #tname, copy_v, itype, otype, 1) \
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \
@@ -18,6 +22,10 @@
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4)
#define instantiate_copy_all(tname, itype, otype) \
instantiate_copy_base(tname, itype, otype) \
instantiate_copy_work_per_thread(tname, itype, otype)
#define instantiate_copy_same(tname, type) \
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int) \
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, type, type, int) \
@@ -42,15 +50,15 @@
instantiate_copy_all(itname ##uint8, itype, uint8_t) \
instantiate_copy_all(itname ##uint16, itype, uint16_t) \
instantiate_copy_all(itname ##uint32, itype, uint32_t) \
instantiate_copy_all(itname ##uint64, itype, uint64_t) \
instantiate_copy_base(itname ##uint64, itype, uint64_t) \
instantiate_copy_all(itname ##int8, itype, int8_t) \
instantiate_copy_all(itname ##int16, itype, int16_t) \
instantiate_copy_all(itname ##int32, itype, int32_t) \
instantiate_copy_all(itname ##int64, itype, int64_t) \
instantiate_copy_base(itname ##int64, itype, int64_t) \
instantiate_copy_all(itname ##float16, itype, half) \
instantiate_copy_all(itname ##float32, itype, float) \
instantiate_copy_all(itname ##bfloat16, itype, bfloat16_t) \
instantiate_copy_all(itname ##complex64, itype, complex64_t)
instantiate_copy_base(itname ##complex64, itype, complex64_t)
instantiate_copy_itype(bool_, bool)
instantiate_copy_itype(uint8, uint8_t)


@@ -9,7 +9,41 @@ using namespace metal;
constant bool has_w [[function_constant(20)]];
template <typename T, int N_READS = RMS_N_READS>
template <int N = 1>
inline void initialize_buffer(
threadgroup float* xs,
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
if (simd_group_id == 0) {
for (int i = 0; i < N; i++) {
xs[N * simd_lane_id + i] = 0;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
template <int N = 1>
inline void threadgroup_sum(
thread float* x,
threadgroup float* xs,
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
for (int i = 0; i < N; i++) {
x[i] = simd_sum(x[i]);
}
if (simd_lane_id == 0) {
for (int i = 0; i < N; i++) {
xs[N * simd_group_id + i] = x[i];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int i = 0; i < N; i++) {
x[i] = xs[N * simd_lane_id + i];
x[i] = simd_sum(x[i]);
}
}
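
threadgroup_sum is a two-level reduction: each simdgroup reduces its 32
lanes, lane 0 parks the partial sum in shared memory, and after the
barrier every simdgroup re-reduces the partials so all threads end up
holding the total. A worked trace assuming 256 threads (8 simdgroups) and
N = 1, each thread contributing 1.0f:

// pass 1: each simdgroup reduces 32 x 1.0f -> lane 0 writes 32.0f to xs[g]
// barrier: xs = {32, 32, 32, 32, 32, 32, 32, 32, 0, ..., 0}
//          (the tail stays 0 thanks to initialize_buffer)
// pass 2: every lane loads xs[simd_lane_id]; simd_sum -> 256.0f everywhere
// The scheme assumes at most SIMD_SIZE (32) simdgroups per threadgroup,
// which Metal's 1024-thread threadgroup limit guarantees.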
template <typename T, int N_READS = 8>
[[kernel]] void layer_norm_single_row(
const device T* x,
const device T* w,
@@ -23,90 +57,71 @@ template <typename T, int N_READS = RMS_N_READS>
uint lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
float sumx = 0;
float sumx2 = 0;
float thread_x[N_READS];
constexpr int SIMD_SIZE = 32;
threadgroup float local_sumx[SIMD_SIZE];
threadgroup float local_sumx2[SIMD_SIZE];
threadgroup float local_mean[1];
threadgroup float local_normalizer[1];
// Initialize the registers and threadgroup memory
float thread_x[N_READS] = {0};
threadgroup float local_buffer[SIMD_SIZE];
initialize_buffer(local_buffer, simd_lane_id, simd_group_id);
// Advance the pointers
x += gid * size_t(axis_size) + lid * N_READS;
w += w_stride * lid * N_READS;
b += b_stride * lid * N_READS;
out += gid * size_t(axis_size) + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
// Compute some variables for reading and writing
const bool safe = lid * N_READS + N_READS <= axis_size;
const int n = axis_size - lid * N_READS;
// Read the inputs
if (safe) {
for (int i = 0; i < N_READS; i++) {
thread_x[i] = x[i];
sumx2 += thread_x[i] * thread_x[i];
sumx += thread_x[i];
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
thread_x[i] = x[i];
sumx2 += thread_x[i] * thread_x[i];
sumx += thread_x[i];
}
for (int i = 0; i < n; i++) {
thread_x[i] = x[i];
}
}
sumx = simd_sum(sumx);
sumx2 = simd_sum(sumx2);
// Initialize shared memory
if (simd_group_id == 0) {
local_sumx[simd_lane_id] = 0;
local_sumx2[simd_lane_id] = 0;
// Compute the mean
float mean = 0;
for (int i = 0; i < N_READS; i++) {
mean += thread_x[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id);
mean /= axis_size;
// Write simd accumulations into shared memory
if (simd_lane_id == 0) {
local_sumx[simd_group_id] = sumx;
local_sumx2[simd_group_id] = sumx2;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Accumulate over simd groups
if (simd_group_id == 0) {
sumx = simd_sum(local_sumx[simd_lane_id]);
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
if (simd_lane_id == 0) {
float mean = sumx / axis_size;
float variance = sumx2 / axis_size - mean * mean;
local_mean[0] = mean;
local_normalizer[0] = metal::precise::rsqrt(variance + eps);
// Compute the normalizer
float normalizer = 0;
if (!safe) {
for (int i = n; i < N_READS; i++) {
thread_x[i] = mean;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float mean = local_mean[0];
float normalizer = local_normalizer[0];
for (int i = 0; i < N_READS; i++) {
thread_x[i] -= mean;
normalizer += thread_x[i] * thread_x[i];
}
threadgroup_sum(&normalizer, local_buffer, simd_lane_id, simd_group_id);
normalizer = metal::precise::rsqrt(normalizer / axis_size + eps);
// Write the outputs
out += gid * size_t(axis_size) + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
if (safe) {
for (int i = 0; i < N_READS; i++) {
thread_x[i] = (thread_x[i] - mean) * normalizer;
thread_x[i] *= normalizer;
out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
thread_x[i] = (thread_x[i] - mean) * normalizer;
out[i] =
w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
}
for (int i = 0; i < n; i++) {
thread_x[i] *= normalizer;
out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
}
}
}
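
The rewrite above replaces the single-pass variance with a two-pass
computation: one reduction for the mean, then a reduction of squared
deviations, with out-of-range lanes padded to the mean so they contribute
zero. The numerical motivation:

\text{one pass:}\quad
\sigma^2 = \frac{1}{n}\sum_i x_i^2 - \Big(\frac{1}{n}\sum_i x_i\Big)^2
\qquad
\text{two pass:}\quad
\mu = \frac{1}{n}\sum_i x_i,\qquad
\sigma^2 = \frac{1}{n}\sum_i (x_i - \mu)^2

When |\mu| is large relative to \sigma, the one-pass form subtracts two
nearly equal float32 quantities and can cancel catastrophically (even
dipping below zero before the rsqrt); the two-pass form sums non-negative
terms and stays stable.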
template <typename T, int N_READS = RMS_N_READS>
template <typename T, int N_READS = 4>
[[kernel]] void layer_norm_looped(
const device T* x,
const device T* w,
@@ -121,71 +136,52 @@ template <typename T, int N_READS = RMS_N_READS>
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
float sumx = 0;
float sumx2 = 0;
constexpr int SIMD_SIZE = 32;
threadgroup float local_sumx[SIMD_SIZE];
threadgroup float local_sumx2[SIMD_SIZE];
threadgroup float local_mean[1];
threadgroup float local_normalizer[1];
threadgroup float local_buffer[SIMD_SIZE];
initialize_buffer(local_buffer, simd_lane_id, simd_group_id);
x += gid * size_t(axis_size) + lid * N_READS;
w += w_stride * lid * N_READS;
b += b_stride * lid * N_READS;
// Compute the mean
float mean = 0;
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
float xi = x[i + r];
sumx2 += xi * xi;
sumx += xi;
mean += x[i + r];
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((r + lid * N_READS + i) < axis_size) {
float xi = x[i + r];
sumx2 += xi * xi;
sumx += xi;
mean += x[i + r];
}
}
}
}
threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id);
mean /= axis_size;
sumx = simd_sum(sumx);
sumx2 = simd_sum(sumx2);
// Initialize shared memory
if (simd_group_id == 0) {
local_sumx[simd_lane_id] = 0;
local_sumx2[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write simd accumulations into shared memory
if (simd_lane_id == 0) {
local_sumx[simd_group_id] = sumx;
local_sumx2[simd_group_id] = sumx2;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Accumulate over simd groups
if (simd_group_id == 0) {
sumx = simd_sum(local_sumx[simd_lane_id]);
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
if (simd_lane_id == 0) {
float mean = sumx / axis_size;
float variance = sumx2 / axis_size - mean * mean;
local_mean[0] = mean;
local_normalizer[0] = metal::precise::rsqrt(variance + eps);
// Compute the normalizer
float normalizer = 0;
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
float t = x[i + r] - mean;
normalizer += t * t;
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((r + lid * N_READS + i) < axis_size) {
float t = x[i + r] - mean;
normalizer += t * t;
}
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float mean = local_mean[0];
float normalizer = local_normalizer[0];
threadgroup_sum(&normalizer, local_buffer, simd_lane_id, simd_group_id);
normalizer = metal::precise::rsqrt(normalizer / axis_size + eps);
// Write the outputs
out += gid * size_t(axis_size) + lid * N_READS;
@@ -208,7 +204,7 @@ template <typename T, int N_READS = RMS_N_READS>
}
}
template <typename T, int N_READS = RMS_N_READS>
template <typename T, int N_READS = 8>
[[kernel]] void vjp_layer_norm_single_row(
const device T* x,
const device T* w,
@@ -222,133 +218,96 @@ template <typename T, int N_READS = RMS_N_READS>
uint lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
constexpr int SIMD_SIZE = 32;
// Advance the input pointers
x += gid * size_t(axis_size) + lid * N_READS;
g += gid * size_t(axis_size) + lid * N_READS;
w += w_stride * lid * N_READS;
// Allocate registers for the computation and accumulators
float thread_x[N_READS];
float thread_w[N_READS];
float thread_g[N_READS];
float sumx = 0;
float sumx2 = 0;
float sumwg = 0;
float sumwgx = 0;
// Initialize the registers and threadgroup memory
float thread_x[N_READS] = {0};
float thread_w[N_READS] = {0};
float thread_g[N_READS] = {0};
threadgroup float local_buffer[3 * SIMD_SIZE];
initialize_buffer<3>(local_buffer, simd_lane_id, simd_group_id);
constexpr int SIMD_SIZE = 32;
// Compute some variables for reading, writing, etc.
const bool safe = lid * N_READS + N_READS <= axis_size;
const int n = axis_size - lid * N_READS;
threadgroup float local_sumx[SIMD_SIZE];
threadgroup float local_sumx2[SIMD_SIZE];
threadgroup float local_sumwg[SIMD_SIZE];
threadgroup float local_sumwgx[SIMD_SIZE];
threadgroup float local_mean[1];
threadgroup float local_normalizer[1];
threadgroup float local_meanwg[1];
threadgroup float local_meanwgx[1];
if (lid * N_READS + N_READS <= axis_size) {
// Read the inputs
if (safe) {
for (int i = 0; i < N_READS; i++) {
thread_x[i] = x[i];
thread_w[i] = w[i * w_stride];
thread_g[i] = g[i];
float wg = thread_w[i] * thread_g[i];
sumx += thread_x[i];
sumx2 += thread_x[i] * thread_x[i];
sumwg += wg;
sumwgx += wg * thread_x[i];
thread_w[i] = w[i * w_stride];
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
thread_x[i] = x[i];
thread_w[i] = w[i * w_stride];
thread_g[i] = g[i];
float wg = thread_w[i] * thread_g[i];
sumx += thread_x[i];
sumx2 += thread_x[i] * thread_x[i];
sumwg += wg;
sumwgx += wg * thread_x[i];
}
for (int i = 0; i < n; i++) {
thread_x[i] = x[i];
thread_g[i] = g[i];
thread_w[i] = w[i * w_stride];
}
}
sumx = simd_sum(sumx);
sumx2 = simd_sum(sumx2);
sumwg = simd_sum(sumwg);
sumwgx = simd_sum(sumwgx);
// Initialize shared memory
if (simd_group_id == 0) {
local_sumx[simd_lane_id] = 0;
local_sumx2[simd_lane_id] = 0;
local_sumwg[simd_lane_id] = 0;
local_sumwgx[simd_lane_id] = 0;
// Compute the mean
float mean = 0;
for (int i = 0; i < N_READS; i++) {
mean += thread_x[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id);
mean /= axis_size;
// Write simd accumulations into shared memory
if (simd_lane_id == 0) {
local_sumx[simd_group_id] = sumx;
local_sumx2[simd_group_id] = sumx2;
local_sumwg[simd_group_id] = sumwg;
local_sumwgx[simd_group_id] = sumwgx;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Accumulate over simd groups
if (simd_group_id == 0) {
sumx = simd_sum(local_sumx[simd_lane_id]);
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
sumwg = simd_sum(local_sumwg[simd_lane_id]);
sumwgx = simd_sum(local_sumwgx[simd_lane_id]);
if (simd_lane_id == 0) {
float mean = sumx / axis_size;
float variance = sumx2 / axis_size - mean * mean;
local_mean[0] = mean;
local_normalizer[0] = metal::precise::rsqrt(variance + eps);
local_meanwg[0] = sumwg / axis_size;
local_meanwgx[0] = sumwgx / axis_size;
// Compute the necessary scaling factors using the mean
if (!safe) {
for (int i = n; i < N_READS; i++) {
thread_x[i] = mean;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float mean = local_mean[0];
float normalizer = local_normalizer[0];
float meanwg = local_meanwg[0];
float meanwgxc = local_meanwgx[0] - meanwg * mean;
float normalizer2 = normalizer * normalizer;
float factors[3] = {0};
constexpr int meanwg = 0;
constexpr int meanwgxc = 1;
constexpr int normalizer2 = 2;
for (int i = 0; i < N_READS; i++) {
thread_x[i] -= mean;
factors[meanwg] += thread_w[i] * thread_g[i];
factors[meanwgxc] += thread_w[i] * thread_g[i] * thread_x[i];
factors[normalizer2] += thread_x[i] * thread_x[i];
}
threadgroup_sum<3>(factors, local_buffer, simd_lane_id, simd_group_id);
factors[meanwg] /= axis_size;
factors[meanwgxc] /= axis_size;
factors[normalizer2] = 1 / (factors[normalizer2] / axis_size + eps);
float normalizer = metal::precise::sqrt(factors[normalizer2]);
// Write the outputs
gx += gid * size_t(axis_size) + lid * N_READS;
gw += gid * size_t(axis_size) + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
if (safe) {
for (int i = 0; i < N_READS; i++) {
thread_x[i] = (thread_x[i] - mean) * normalizer;
thread_x[i] *= normalizer;
gx[i] = static_cast<T>(
normalizer * (thread_w[i] * thread_g[i] - meanwg) -
thread_x[i] * meanwgxc * normalizer2);
normalizer * (thread_w[i] * thread_g[i] - factors[meanwg]) -
thread_x[i] * factors[meanwgxc] * factors[normalizer2]);
if (has_w) {
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
}
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
thread_x[i] = (thread_x[i] - mean) * normalizer;
gx[i] = static_cast<T>(
normalizer * (thread_w[i] * thread_g[i] - meanwg) -
thread_x[i] * meanwgxc * normalizer2);
if (has_w) {
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
}
for (int i = 0; i < n; i++) {
thread_x[i] *= normalizer;
gx[i] = static_cast<T>(
normalizer * (thread_w[i] * thread_g[i] - factors[meanwg]) -
thread_x[i] * factors[meanwgxc] * factors[normalizer2]);
if (has_w) {
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
}
}
}
}
template <typename T, int N_READS = RMS_N_READS>
template <typename T, int N_READS = 4>
[[kernel]] void vjp_layer_norm_looped(
const device T* x,
const device T* w,
@@ -363,102 +322,69 @@ template <typename T, int N_READS = RMS_N_READS>
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
constexpr int SIMD_SIZE = 32;
// Advance the input pointers
x += gid * size_t(axis_size) + lid * N_READS;
g += gid * size_t(axis_size) + lid * N_READS;
w += w_stride * lid * N_READS;
// Allocate registers for the accumulators
float sumx = 0;
float sumx2 = 0;
float sumwg = 0;
float sumwgx = 0;
constexpr int SIMD_SIZE = 32;
threadgroup float local_sumx[SIMD_SIZE];
threadgroup float local_sumx2[SIMD_SIZE];
threadgroup float local_sumwg[SIMD_SIZE];
threadgroup float local_sumwgx[SIMD_SIZE];
threadgroup float local_mean[1];
threadgroup float local_normalizer[1];
threadgroup float local_meanwg[1];
threadgroup float local_meanwgx[1];
threadgroup float local_buffer[3 * SIMD_SIZE];
initialize_buffer<3>(local_buffer, simd_lane_id, simd_group_id);
// Compute the mean
float mean = 0;
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
float xi = x[i + r];
float wi = w[(i + r) * w_stride];
float gi = g[i + r];
float wg = wi * gi;
sumx += xi;
sumx2 += xi * xi;
sumwg += wg;
sumwgx += wg * xi;
mean += x[i + r];
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((r + lid * N_READS + i) < axis_size) {
float xi = x[i + r];
float wi = w[(i + r) * w_stride];
float gi = g[i + r];
float wg = wi * gi;
sumx += xi;
sumx2 += xi * xi;
sumwg += wg;
sumwgx += wg * xi;
mean += x[i + r];
}
}
}
}
threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id);
mean /= axis_size;
sumx = simd_sum(sumx);
sumx2 = simd_sum(sumx2);
sumwg = simd_sum(sumwg);
sumwgx = simd_sum(sumwgx);
// Initialize shared memory
if (simd_group_id == 0) {
local_sumx[simd_lane_id] = 0;
local_sumx2[simd_lane_id] = 0;
local_sumwg[simd_lane_id] = 0;
local_sumwgx[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write simd accumulations into shared memory
if (simd_lane_id == 0) {
local_sumx[simd_group_id] = sumx;
local_sumx2[simd_group_id] = sumx2;
local_sumwg[simd_group_id] = sumwg;
local_sumwgx[simd_group_id] = sumwgx;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Accumulate over simd groups
if (simd_group_id == 0) {
sumx = simd_sum(local_sumx[simd_lane_id]);
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
sumwg = simd_sum(local_sumwg[simd_lane_id]);
sumwgx = simd_sum(local_sumwgx[simd_lane_id]);
if (simd_lane_id == 0) {
float mean = sumx / axis_size;
float variance = sumx2 / axis_size - mean * mean;
local_mean[0] = mean;
local_normalizer[0] = metal::precise::rsqrt(variance + eps);
local_meanwg[0] = sumwg / axis_size;
local_meanwgx[0] = sumwgx / axis_size;
// Compute the necessary scaling factors using the mean
float factors[3] = {0};
constexpr int meanwg = 0;
constexpr int meanwgxc = 1;
constexpr int normalizer2 = 2;
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
float t = x[i + r] - mean;
float wi = w[(i + r) * w_stride];
float gi = g[i + r];
float wg = wi * gi;
factors[meanwg] += wg;
factors[meanwgxc] += wg * t;
factors[normalizer2] += t * t;
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((r + lid * N_READS + i) < axis_size) {
float t = x[i + r] - mean;
float wi = w[(i + r) * w_stride];
float gi = g[i + r];
float wg = wi * gi;
factors[meanwg] += wg;
factors[meanwgxc] += wg * t;
factors[normalizer2] += t * t;
}
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float mean = local_mean[0];
float normalizer = local_normalizer[0];
float meanwg = local_meanwg[0];
float meanwgxc = local_meanwgx[0] - meanwg * mean;
float normalizer2 = normalizer * normalizer;
threadgroup_sum<3>(factors, local_buffer, simd_lane_id, simd_group_id);
factors[meanwg] /= axis_size;
factors[meanwgxc] /= axis_size;
factors[normalizer2] = 1 / (factors[normalizer2] / axis_size + eps);
float normalizer = metal::precise::sqrt(factors[normalizer2]);
// Write the outputs
gx += gid * size_t(axis_size) + lid * N_READS;
@@ -470,7 +396,8 @@ template <typename T, int N_READS = RMS_N_READS>
float wi = w[(i + r) * w_stride];
float gi = g[i + r];
gx[i + r] = static_cast<T>(
normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
normalizer * (wi * gi - factors[meanwg]) -
xi * factors[meanwgxc] * factors[normalizer2]);
if (has_w) {
gw[i + r] = static_cast<T>(gi * xi);
}
@@ -482,7 +409,8 @@ template <typename T, int N_READS = RMS_N_READS>
float wi = w[(i + r) * w_stride];
float gi = g[i + r];
gx[i + r] = static_cast<T>(
normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
normalizer * (wi * gi - factors[meanwg]) -
xi * factors[meanwgxc] * factors[normalizer2]);
if (has_w) {
gw[i + r] = static_cast<T>(gi * xi);
}
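The kernels above replace the fused E[x^2] - mu^2 moment trick with two passes per row (mean first, then the variance of the centered values), and the VJP kernels collapse their reductions into three factors summed by a single threadgroup_sum<3>. A minimal Python sketch of the same math, assuming an MLX install; the function names are illustrative:

```python
import mlx.core as mx

def layer_norm_fwd(x, w, b, eps=1e-5):
    # Pass 1: row mean. Pass 2: variance of the centered values,
    # which avoids the cancellation in E[x^2] - mean^2.
    mu = mx.mean(x, axis=-1, keepdims=True)
    xc = x - mu
    v = mx.mean(mx.square(xc), axis=-1, keepdims=True)
    return (xc * mx.rsqrt(v + eps)) * w + b

def layer_norm_vjp_x(x, w, g, eps=1e-5):
    # The three reduced "factors" from the VJP kernels:
    # mean(w*g), mean(w*g*(x - mu)), and 1 / (var + eps).
    mu = mx.mean(x, axis=-1, keepdims=True)
    xc = x - mu
    n2 = 1.0 / (mx.mean(mx.square(xc), axis=-1, keepdims=True) + eps)
    n = mx.sqrt(n2)
    meanwg = mx.mean(w * g, axis=-1, keepdims=True)
    meanwgxc = mx.mean(w * g * xc, axis=-1, keepdims=True)
    return n * (w * g - meanwg) - (xc * n) * meanwgxc * n2

x = mx.random.uniform(shape=(2, 64))
w = mx.random.uniform(shape=(64,))
b = mx.zeros((64,))
g = mx.ones((2, 64))
print(mx.allclose(layer_norm_fwd(x, w, b), mx.fast.layer_norm(x, w, b, 1e-5), atol=1e-5))
_, vjps = mx.vjp(lambda x: mx.fast.layer_norm(x, w, b, 1e-5), [x], [g])
print(mx.allclose(vjps[0], layer_norm_vjp_x(x, w, g), atol=1e-4))
```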

View File

@@ -9,8 +9,14 @@ template <typename T, typename Op, int N = WorkPerThread<T>::n>
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
}
} else {
for (int i = 0; i < N; ++i) {
d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
}
}
}
@@ -23,9 +29,15 @@ template <typename T, typename Op, int N = WorkPerThread<T>::n>
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
}
} else {
for (int i = 0; i < N; ++i) {
d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
}
}
}
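The rewritten loops hoist the bounds check out of the hot path: only the thread that owns the final partial chunk takes the guarded loop, and every full chunk runs a fixed-trip-count loop the compiler can unroll (the unary kernels below get the same treatment). A Python sketch of the control flow; the names are illustrative, and Select picks b when a is true:

```python
def ternary_chunk(a, b, c, d, index, N, size):
    # Each thread owns N consecutive elements starting at index * N.
    base = index * N
    if N > 1 and base + N > size:
        # Tail chunk only: per-element bounds check.
        for i in range(size - base):
            d[base + i] = b[base + i] if a[base + i] else c[base + i]
    else:
        # Full chunk: fixed trip count, no per-element check.
        for i in range(N):
            d[base + i] = b[base + i] if a[base + i] else c[base + i]

a = [1, 0, 1, 0, 1]; b = [10] * 5; c = [20] * 5; d = [0] * 5
for index in range(3):               # ceil(5 / 2) threads, N = 2
    ternary_chunk(a, b, c, d, index, N=2, size=5)
print(d)                             # [10, 20, 10, 20, 10]
```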

View File

@@ -8,8 +8,8 @@
#include "mlx/backend/metal/kernels/ternary_ops.h"
#include "mlx/backend/metal/kernels/ternary.h"
#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
#define instantiate_ternary_base(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op, 1) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 2, int) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int) \
@@ -20,19 +20,23 @@
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \
#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("vn_" #op #tname, ternary_v, type, op) \
instantiate_ternary_base(op, tname, type)
#define instantiate_ternary_types(op) \
instantiate_ternary_all(op, bool_, bool) \
instantiate_ternary_all(op, uint8, uint8_t) \
instantiate_ternary_all(op, uint16, uint16_t) \
instantiate_ternary_all(op, uint32, uint32_t) \
instantiate_ternary_all(op, uint64, uint64_t) \
instantiate_ternary_base(op, uint64, uint64_t) \
instantiate_ternary_all(op, int8, int8_t) \
instantiate_ternary_all(op, int16, int16_t) \
instantiate_ternary_all(op, int32, int32_t) \
instantiate_ternary_all(op, int64, int64_t) \
instantiate_ternary_base(op, int64, int64_t) \
instantiate_ternary_all(op, float16, half) \
instantiate_ternary_all(op, float32, float) \
instantiate_ternary_all(op, bfloat16, bfloat16_t) \
instantiate_ternary_all(op, complex64, complex64_t) // clang-format on
instantiate_ternary_base(op, complex64, complex64_t) // clang-format on
instantiate_ternary_types(Select)
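Only types narrower than 8 bytes get the extra vn work-per-thread variant; with the 8-bytes-per-thread budget in get_work_per_thread (shown in utils.h further down), 8-byte types already work out to one element per thread, so uint64, int64, and complex64 keep just the base instantiations. A quick check of that arithmetic:

```python
# 8-byte budget per thread => elements per thread by itemsize.
for name, itemsize in [("bool_", 1), ("float16", 2), ("float32", 4),
                       ("int64", 8), ("complex64", 8)]:
    print(name, max(1, 8 // itemsize))
# bool_ 8, float16 4, float32 2, int64 1, complex64 1
```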

View File

@@ -7,8 +7,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
out[index + i] = Op()(in[index + i]);
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
out[index + i] = Op()(in[index + i]);
}
} else {
for (int i = 0; i < N; ++i) {
out[index + i] = Op()(in[index + i]);
}
}
}
@@ -19,9 +25,15 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
out[offset + i] = Op()(in[offset + i]);
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
out[offset + i] = Op()(in[offset + i]);
}
} else {
for (int i = 0; i < N; ++i) {
out[offset + i] = Op()(in[offset + i]);
}
}
}

View File

@@ -5,31 +5,41 @@
#include "mlx/backend/metal/kernels/unary_ops.h"
#include "mlx/backend/metal/kernels/unary.h"
#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
instantiate_kernel( \
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \
instantiate_kernel( \
#define instantiate_unary_work_per_thread(op, in_tname, out_tname, in_type, out_type) \
instantiate_kernel("vn_" #op #in_tname #out_tname, unary_v, in_type, out_type, op)
#define instantiate_unary_base(op, in_tname, out_tname, in_type, out_type) \
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op, 1) \
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
instantiate_kernel( \
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \
instantiate_kernel( \
"gn4large_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
instantiate_unary_base(op, in_tname, out_tname, in_type, out_type) \
instantiate_unary_work_per_thread(op, in_tname, out_tname, in_type, out_type)
#define instantiate_unary_all_same(op, tname, type) \
instantiate_unary_all(op, tname, tname, type, type)
#define instantiate_unary_base_same(op, tname, type) \
instantiate_unary_base(op, tname, tname, type, type)
#define instantiate_unary_float(op) \
instantiate_unary_all_same(op, float16, half) \
instantiate_unary_all_same(op, float32, float) \
instantiate_unary_all_same(op, bfloat16, bfloat16_t)
#define instantiate_unary_int(op) \
instantiate_unary_all_same(op, uint8, uint8_t) \
instantiate_unary_all_same(op, uint16, uint16_t) \
instantiate_unary_all_same(op, uint32, uint32_t) \
instantiate_unary_all_same(op, uint64, uint64_t) \
instantiate_unary_all_same(op, int8, int8_t) \
instantiate_unary_all_same(op, int16, int16_t) \
instantiate_unary_all_same(op, int32, int32_t) \
instantiate_unary_all_same(op, int64, int64_t)
#define instantiate_unary_int(op) \
instantiate_unary_all_same(op, uint8, uint8_t) \
instantiate_unary_all_same(op, uint16, uint16_t) \
instantiate_unary_all_same(op, uint32, uint32_t) \
instantiate_unary_base_same(op, uint64, uint64_t) \
instantiate_unary_all_same(op, int8, int8_t) \
instantiate_unary_all_same(op, int16, int16_t) \
instantiate_unary_all_same(op, int32, int32_t) \
instantiate_unary_base_same(op, int64, int64_t)
#define instantiate_unary_types(op) \
instantiate_unary_all_same(op, bool_, bool) \
@@ -68,29 +78,29 @@ instantiate_unary_float(Tanh)
instantiate_unary_float(Round)
instantiate_unary_int(BitwiseInvert)
instantiate_unary_all_same(Abs, complex64, complex64_t)
instantiate_unary_all_same(ArcCos, complex64, complex64_t)
instantiate_unary_all_same(ArcSin, complex64, complex64_t)
instantiate_unary_all_same(ArcTan, complex64, complex64_t)
instantiate_unary_all_same(Conjugate, complex64, complex64_t)
instantiate_unary_all_same(Cos, complex64, complex64_t)
instantiate_unary_all_same(Cosh, complex64, complex64_t)
instantiate_unary_all_same(Exp, complex64, complex64_t)
instantiate_unary_all_same(Log, complex64, complex64_t)
instantiate_unary_all_same(Log1p, complex64, complex64_t)
instantiate_unary_all_same(Log2, complex64, complex64_t)
instantiate_unary_all_same(Log10, complex64, complex64_t)
instantiate_unary_all_same(Negative, complex64, complex64_t)
instantiate_unary_all_same(Sign, complex64, complex64_t)
instantiate_unary_all_same(Sin, complex64, complex64_t)
instantiate_unary_all_same(Sinh, complex64, complex64_t)
instantiate_unary_all_same(Square, complex64, complex64_t)
instantiate_unary_all_same(Sqrt, complex64, complex64_t)
instantiate_unary_all_same(Rsqrt, complex64, complex64_t)
instantiate_unary_all_same(Tan, complex64, complex64_t)
instantiate_unary_all_same(Tanh, complex64, complex64_t)
instantiate_unary_all_same(Round, complex64, complex64_t)
instantiate_unary_all(Real, complex64, float32, complex64_t, float)
instantiate_unary_all(Imag, complex64, float32, complex64_t, float)
instantiate_unary_base_same(Abs, complex64, complex64_t)
instantiate_unary_base_same(ArcCos, complex64, complex64_t)
instantiate_unary_base_same(ArcSin, complex64, complex64_t)
instantiate_unary_base_same(ArcTan, complex64, complex64_t)
instantiate_unary_base_same(Conjugate, complex64, complex64_t)
instantiate_unary_base_same(Cos, complex64, complex64_t)
instantiate_unary_base_same(Cosh, complex64, complex64_t)
instantiate_unary_base_same(Exp, complex64, complex64_t)
instantiate_unary_base_same(Log, complex64, complex64_t)
instantiate_unary_base_same(Log1p, complex64, complex64_t)
instantiate_unary_base_same(Log2, complex64, complex64_t)
instantiate_unary_base_same(Log10, complex64, complex64_t)
instantiate_unary_base_same(Negative, complex64, complex64_t)
instantiate_unary_base_same(Sign, complex64, complex64_t)
instantiate_unary_base_same(Sin, complex64, complex64_t)
instantiate_unary_base_same(Sinh, complex64, complex64_t)
instantiate_unary_base_same(Square, complex64, complex64_t)
instantiate_unary_base_same(Sqrt, complex64, complex64_t)
instantiate_unary_base_same(Rsqrt, complex64, complex64_t)
instantiate_unary_base_same(Tan, complex64, complex64_t)
instantiate_unary_base_same(Tanh, complex64, complex64_t)
instantiate_unary_base_same(Round, complex64, complex64_t)
instantiate_unary_base(Real, complex64, float32, complex64_t, float)
instantiate_unary_base(Imag, complex64, float32, complex64_t, float)
instantiate_unary_all_same(LogicalNot, bool_, bool) // clang-format on

View File

@@ -6,7 +6,7 @@
#include <sstream>
#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/common/matmul.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
@@ -21,69 +21,6 @@ namespace mlx::core {
namespace {
inline auto collapse_batches(const array& a, const array& b) {
// Get and check the shape for the batched dims
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
if (A_bshape != B_bshape) {
std::ostringstream msg;
msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A "
<< a.shape() << ", B " << b.shape() << ".";
throw std::runtime_error(msg.str());
}
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
auto [batch_shape, batch_strides] =
collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
auto A_batch_stride = batch_strides[0];
auto B_batch_stride = batch_strides[1];
if (batch_shape.empty()) {
batch_shape.push_back(1);
A_batch_stride.push_back(0);
B_batch_stride.push_back(0);
}
return std::make_tuple(batch_shape, A_batch_stride, B_batch_stride);
}
inline auto collapse_batches(const array& a, const array& b, const array& c) {
// Get and check the shape for the batched dims
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
Shape C_bshape{c.shape().begin(), c.shape().end() - 2};
if (A_bshape != B_bshape || A_bshape != C_bshape) {
std::ostringstream msg;
msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A "
<< a.shape() << ", B " << b.shape() << ", C " << c.shape() << ".";
throw std::runtime_error(msg.str());
}
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
Strides C_bstride{c.strides().begin(), c.strides().end() - 2};
auto [batch_shape, batch_strides] = collapse_contiguous_dims(
A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
auto A_batch_stride = batch_strides[0];
auto B_batch_stride = batch_strides[1];
auto C_batch_stride = batch_strides[2];
if (batch_shape.empty()) {
batch_shape.push_back(1);
A_batch_stride.push_back(0);
B_batch_stride.push_back(0);
C_batch_stride.push_back(0);
}
return std::make_tuple(
batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
}
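These helpers, deleted here in favor of the shared mlx/backend/common/matmul.h pulled in above, strip the trailing matrix dims and merge adjacent batch dims that are contiguous in every operand. A simplified Python sketch of the merge for two operands, ignoring the size-1 special cases the real collapse_contiguous_dims handles:

```python
def collapse_batches(shape, strides_a, strides_b):
    # Batch dims only (the caller has already stripped the matrix dims).
    if not shape:
        return [1], [0], [0]
    out_shape = [shape[-1]]
    out_a, out_b = [strides_a[-1]], [strides_b[-1]]
    for dim, sa, sb in zip(shape[-2::-1], strides_a[-2::-1], strides_b[-2::-1]):
        if sa == out_a[0] * out_shape[0] and sb == out_b[0] * out_shape[0]:
            out_shape[0] *= dim          # contiguous in both operands: merge
        else:
            out_shape.insert(0, dim)
            out_a.insert(0, sa)
            out_b.insert(0, sb)
    return out_shape, out_a, out_b

# (2, 3) batch of 4x5 @ 5x6: fully contiguous batches collapse to one dim.
print(collapse_batches([2, 3], [60, 20], [90, 30]))   # ([6], [20], [30])
```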
std::tuple<bool, int64_t, array> check_transpose(
std::vector<array>& copies,
const Stream& s,

View File

@@ -146,7 +146,7 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
int,
int,
int) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
return d.get_kernel(kernel_name, hash_name, func_consts);
}
MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
@@ -207,7 +207,7 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
int,
int,
bool) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
return d.get_kernel(kernel_name, hash_name, func_consts);
}
MTL::ComputePipelineState* get_gemv_masked_kernel(
@@ -259,7 +259,7 @@ MTL::ComputePipelineState* get_fft_kernel(
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const std::string&) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
return d.get_kernel(kernel_name, hash_name, func_consts);
}
MTL::ComputePipelineState* get_quantized_kernel(
@@ -283,7 +283,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
int,
int,
bool) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
return d.get_kernel(kernel_name, hash_name, func_consts);
}
} // namespace mlx::core

View File

@@ -172,7 +172,7 @@ void RMSNormVJP::eval_gpu(
auto& compute_encoder = d.get_command_encoder(s.index);
{
auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts);
auto kernel = d.get_kernel(op_name, hash_name, func_consts);
MTL::Size grid_dims, group_dims;
if (axis_size <= looped_limit) {
@@ -255,12 +255,13 @@ void LayerNorm::eval_gpu(
auto axis_size = static_cast<uint32_t>(x.shape().back());
int n_rows = x.data_size() / axis_size;
const int simd_size = 32;
const int n_reads = RMS_N_READS;
const int looped_limit = RMS_LOOPED_LIMIT;
int simd_size = 32;
int n_reads = 8;
int looped_limit = 6656;
std::string op_name = "layer_norm";
if (axis_size > looped_limit) {
op_name += "_looped";
n_reads = 4;
}
op_name += type_to_name(out);
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -272,7 +273,13 @@ void LayerNorm::eval_gpu(
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
size_t threadgroup_size = simd_size * simds_needed;
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
if (threadgroup_size > kernel->maxTotalThreadsPerThreadgroup()) {
std::ostringstream msg;
msg << "[layer_norm] Threadgroup size " << threadgroup_size
<< " is larger than the maximum allowed threadgroup size "
<< kernel->maxTotalThreadsPerThreadgroup();
throw std::runtime_error(msg.str());
}
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
@@ -372,12 +379,13 @@ void LayerNormVJP::eval_gpu(
g, gb, "sum", plan, {0}, compute_encoder, d, s);
}
const int simd_size = 32;
const int n_reads = RMS_N_READS;
const int looped_limit = RMS_LOOPED_LIMIT;
int simd_size = 32;
int n_reads = 8;
int looped_limit = 8192;
std::string op_name = "vjp_layer_norm";
if (axis_size > looped_limit) {
op_name += "_looped";
n_reads = 4;
}
op_name += type_to_name(gx);
@@ -387,14 +395,20 @@ void LayerNormVJP::eval_gpu(
};
{
auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts);
auto kernel = d.get_kernel(op_name, hash_name, func_consts);
MTL::Size grid_dims, group_dims;
if (axis_size <= looped_limit) {
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
size_t threadgroup_size = simd_size * simds_needed;
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
if (threadgroup_size > kernel->maxTotalThreadsPerThreadgroup()) {
std::ostringstream msg;
msg << "[vjp_layer_norm] Threadgroup size " << threadgroup_size
<< " is larger than the maximum allowed threadgroup size "
<< kernel->maxTotalThreadsPerThreadgroup();
throw std::runtime_error(msg.str());
}
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
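For rows that fit the single-row kernel, the threadgroup covers the whole row at n_reads elements per thread, rounded up to whole simdgroups; anything longer than looped_limit switches to the looped kernel with n_reads = 4. The sizing arithmetic, checked at the new forward limit:

```python
def single_row_group_size(axis_size, n_reads=8, simd_size=32):
    threads_needed = -(-axis_size // n_reads)        # ceildiv
    simds_needed = -(-threads_needed // simd_size)   # round up to whole simdgroups
    return simd_size * simds_needed

print(single_row_group_size(4096))  # 512
print(single_row_group_size(6656))  # 832: largest forward row before "_looped"
```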

View File

@@ -73,7 +73,7 @@ void sdpa_full_self_attention_metal(
std::string hash_name = kname.str();
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);
const int NQ = (qL + bq - 1) / bq;
@@ -180,7 +180,7 @@ void sdpa_vector(
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts);
auto kernel = d.get_kernel(kname, hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);
// Set its arguments
@@ -281,7 +281,7 @@ void sdpa_vector_2pass(
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts);
auto kernel = d.get_kernel(kname, hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);
@@ -369,7 +369,7 @@ bool ScaledDotProductAttention::use_fallback(
const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
(query_sequence_length <= key_sequence_length && do_causal);
const bool supports_sdpa_full =
const bool supports_sdpa_full = query_sequence_length > 8 &&
sdpa_full_supported_mask && sdpa_full_supported_head_dim;
const bool supports_sdpa_vector = (query_sequence_length <= 8) &&

View File

@@ -45,7 +45,7 @@ void ternary_op_gpu_inplace(
work_per_thread = large ? 4 : 2;
} else {
large = out.data_size() > INT32_MAX;
work_per_thread = get_work_per_thread(b.dtype());
work_per_thread = get_work_per_thread(b.dtype(), out.data_size());
}
std::string kernel_name;
if (topt == TernaryOpType::General) {
@@ -60,6 +60,8 @@ void ternary_op_gpu_inplace(
}
} else if (large) {
kernel_name = "v2";
} else if (work_per_thread > 1) {
kernel_name = "vn";
} else {
kernel_name = "v";
}

View File

@@ -1,5 +1,6 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/common/unary.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
@@ -43,8 +44,8 @@ void unary_op_gpu_inplace(
int work_per_thread;
std::string kernel_name;
if (contig) {
work_per_thread = get_work_per_thread(in.dtype());
kernel_name = (large ? "v2" : "v");
work_per_thread = get_work_per_thread(in.dtype(), in.data_size());
kernel_name = (large ? "v2" : (work_per_thread > 1 ? "vn" : "v"));
} else {
work_per_thread = large ? 4 : 1;
kernel_name = "gn" + std::to_string(work_per_thread);
@@ -99,21 +100,7 @@ void unary_op_gpu(
array& out,
const std::string op,
const Stream& s) {
auto& in = inputs[0];
bool contig = in.flags().contiguous;
if (contig) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
set_unary_output_data(inputs[0], out);
unary_op_gpu_inplace(inputs, out, op, s);
}

View File

@@ -72,6 +72,10 @@ void concatenate(std::string& acc, T first, Args... args) {
inline int get_work_per_thread(Dtype dtype) {
return std::max(1, 8 / dtype.size());
}
inline int get_work_per_thread(Dtype dtype, size_t size) {
constexpr size_t wpt_threshold = 1 << 16;
return size < wpt_threshold ? 1 : std::max(1, 8 / dtype.size());
}
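Together with the contiguous-path changes in ternary.cpp and unary.cpp above, the new overload keeps small arrays on the plain one-element-per-thread v kernel and only routes large enough arrays to vn. A sketch of the resulting selection, simplifying the large (64-bit indexing) case:

```python
INT32_MAX = 2**31 - 1

def pick_contiguous_kernel(itemsize, size):
    # Mirrors get_work_per_thread(dtype, size) with the 1 << 16 threshold.
    work_per_thread = 1 if size < (1 << 16) else max(1, 8 // itemsize)
    if size > INT32_MAX:
        return "v2", work_per_thread   # 64-bit indexing path
    return ("vn" if work_per_thread > 1 else "v"), work_per_thread

print(pick_contiguous_kernel(4, 1024))      # ('v', 1): below the threshold
print(pick_contiguous_kernel(4, 1 << 20))   # ('vn', 2): float32, 2 per thread
print(pick_contiguous_kernel(8, 1 << 20))   # ('v', 1): 8-byte types stay at 1
```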
inline size_t ceildiv(size_t n, size_t m) {
return (n + m - 1) / m;

View File

@@ -2,6 +2,7 @@
#include "mlx/primitives.h"
#include "mlx/distributed/primitives.h"
#include "mlx/fast.h"
#include "mlx/fast_primitives.h"
#define NO_GPU_MULTI(func) \
@@ -155,6 +156,18 @@ NO_GPU_USE_FALLBACK(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
NO_GPU_MULTI(CustomKernel)
MetalKernelFunction metal_kernel(
const std::string&,
const std::vector<std::string>&,
const std::vector<std::string>&,
const std::string&,
const std::string&,
bool ensure_row_contiguous,
bool atomic_outputs) {
throw std::runtime_error("[metal_kernel] No GPU back-end.");
}
} // namespace fast
namespace distributed {

View File

@@ -266,6 +266,7 @@ struct PrimitiveFactory {
SERIALIZE_PRIMITIVE(Floor),
SERIALIZE_PRIMITIVE(Full),
SERIALIZE_PRIMITIVE(Gather),
SERIALIZE_PRIMITIVE(GatherAxis),
SERIALIZE_PRIMITIVE(GatherMM),
SERIALIZE_PRIMITIVE(Greater),
SERIALIZE_PRIMITIVE(GreaterEqual),
@@ -307,6 +308,7 @@ struct PrimitiveFactory {
"CumMax",
"CumLogaddexp"),
SERIALIZE_PRIMITIVE(Scatter),
SERIALIZE_PRIMITIVE(ScatterAxis),
SERIALIZE_PRIMITIVE(Select),
SERIALIZE_PRIMITIVE(Sigmoid),
SERIALIZE_PRIMITIVE(Sign),

View File

@@ -1,10 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <iostream>
#include <numeric>
#include <regex>
#include "mlx/backend/common/compiled.h"
#include "mlx/fast.h"
#include "mlx/fast_primitives.h"
#include "mlx/ops.h"
@@ -231,13 +228,11 @@ array layer_norm(
const std::vector<array>& inputs) {
auto x = astype(inputs[0], float32, s);
// Should I not be smart here and leave the double mean to simplify()?
auto mu = mean(x, /* axis= */ -1, /* keepdims= */ true, s);
auto mu2 = square(mu, s);
auto x2 = mean(square(x, s), /* axis= */ -1, /* keepdims= */ true, s);
auto v = subtract(x2, mu2, s);
auto xc = subtract(x, mu, s);
auto v = mean(square(xc, s), /* axis= */ -1, /* keepdims= */ true, s);
x = multiply(subtract(x, mu, s), rsqrt(add(v, array(eps, float32), s), s));
x = multiply(xc, rsqrt(add(v, array(eps, float32), s), s));
x = astype(x, out_type, s);
// If the LN is affine then transform x according to the weight and bias
@@ -1029,303 +1024,4 @@ std::vector<Shape> AffineQuantize::output_shapes(
}
}
std::string write_signature(
std::string func_name,
const std::string& header,
const std::string& source,
const std::vector<std::string>& input_names,
const std::vector<array>& inputs,
const std::vector<std::string>& output_names,
const std::vector<Dtype>& output_dtypes,
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
const std::vector<std::string>& attributes,
const std::vector<CustomKernelShapeInfo>& shape_infos,
bool atomic_outputs) {
std::string kernel_source;
kernel_source.reserve(header.size() + source.size() + 16384);
kernel_source += header;
// Auto-generate a function signature based on `template_args`
// and the dtype/shape of the arrays passed as `inputs`.
if (!template_args.empty()) {
kernel_source += "template <";
int i = 0;
for (const auto& [name, arg] : template_args) {
std::string param_type;
if (std::holds_alternative<int>(arg)) {
param_type = "int";
} else if (std::holds_alternative<bool>(arg)) {
param_type = "bool";
} else if (std::holds_alternative<Dtype>(arg)) {
param_type = "typename";
}
if (i > 0) {
kernel_source += ", ";
}
kernel_source += param_type;
kernel_source += " ";
kernel_source += name;
i++;
}
kernel_source += ">\n";
}
kernel_source += "[[kernel]] void ";
kernel_source += func_name;
kernel_source += "(\n";
int index = 0;
constexpr int max_constant_array_size = 8;
// Add inputs
for (int i = 0; i < inputs.size(); ++i) {
const auto& name = input_names[i];
const auto& arr = inputs[i];
auto dtype = get_type_string(arr.dtype());
std::string location =
arr.size() < max_constant_array_size ? "constant" : "device";
std::string ref = arr.ndim() == 0 ? "&" : "*";
kernel_source += " const ";
kernel_source += location;
kernel_source += " ";
kernel_source += dtype;
kernel_source += ref;
kernel_source += " ";
kernel_source += name;
kernel_source += " [[buffer(";
kernel_source += std::to_string(index);
kernel_source += ")]],\n";
index++;
// Add input shape, strides and ndim if present in the source
if (arr.ndim() > 0) {
if (shape_infos[i].shape) {
kernel_source +=
(" const constant int* " + name + "_shape [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
if (shape_infos[i].strides) {
kernel_source +=
(" const constant int64_t* " + name + "_strides [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
if (shape_infos[i].ndim) {
kernel_source +=
(" const constant int& " + name + "_ndim [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
}
}
// Add outputs
for (int i = 0; i < output_names.size(); ++i) {
const auto& name = output_names[i];
const auto& dtype = output_dtypes[i];
kernel_source += " device ";
auto type_string = get_type_string(dtype);
if (atomic_outputs) {
kernel_source += "atomic<";
}
kernel_source += type_string;
if (atomic_outputs) {
kernel_source += ">";
}
kernel_source += "* ";
kernel_source += name;
kernel_source += " [[buffer(";
kernel_source += std::to_string(index);
kernel_source += ")]]";
if (index < inputs.size() + output_names.size() - 1 ||
attributes.size() > 0) {
kernel_source += ",\n";
} else {
kernel_source += ") {\n";
}
index++;
}
index = 0;
for (const auto& attr : attributes) {
kernel_source += attr;
if (index < attributes.size() - 1) {
kernel_source += ",\n";
} else {
kernel_source += ") {\n";
}
index++;
}
kernel_source += source;
kernel_source += "\n}\n";
return kernel_source;
}
std::string write_template(
const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
std::ostringstream template_def;
template_def << "<";
int i = 0;
for (const auto& [name, arg] : template_args) {
if (i > 0) {
template_def << ", ";
}
if (std::holds_alternative<int>(arg)) {
template_def << std::get<int>(arg);
} else if (std::holds_alternative<bool>(arg)) {
template_def << std::get<bool>(arg);
} else if (std::holds_alternative<Dtype>(arg)) {
template_def << get_type_string(std::get<Dtype>(arg));
}
i++;
}
template_def << ">";
return template_def.str();
}
MetalKernelFunction metal_kernel(
const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::string& source,
const std::string& header /* = "" */,
bool ensure_row_contiguous /* = true */,
bool atomic_outputs /* = false */) {
if (output_names.empty()) {
throw std::invalid_argument(
"[metal_kernel] Must specify at least one output.");
}
std::vector<CustomKernelShapeInfo> shape_infos;
for (auto& n : input_names) {
CustomKernelShapeInfo shape_info;
shape_info.shape = source.find(n + "_shape") != std::string::npos;
shape_info.strides = source.find(n + "_strides") != std::string::npos;
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
shape_infos.push_back(shape_info);
}
const std::vector<std::pair<std::string, std::string>> metal_attributes = {
{"dispatch_quadgroups_per_threadgroup", "uint"},
{"dispatch_simdgroups_per_threadgroup", "uint"},
{"dispatch_threads_per_threadgroup", "uint3"},
{"grid_origin", "uint3"},
{"grid_size", "uint3"},
{"quadgroup_index_in_threadgroup", "uint"},
{"quadgroups_per_threadgroup", "uint"},
{"simdgroup_index_in_threadgroup", "uint"},
{"simdgroups_per_threadgroup", "uint"},
{"thread_execution_width", "uint"},
{"thread_index_in_quadgroup", "uint"},
{"thread_index_in_simdgroup", "uint"},
{"thread_index_in_threadgroup", "uint"},
{"thread_position_in_grid", "uint3"},
{"thread_position_in_threadgroup", "uint3"},
{"threadgroup_position_in_grid", "uint3"},
{"threadgroups_per_grid", "uint3"},
{"threads_per_grid", "uint3"},
{"threads_per_simdgroup", "uint"},
{"threads_per_threadgroup", "uint3"},
};
std::vector<std::string> attributes;
for (const auto& [attr, dtype] : metal_attributes) {
if (source.find(attr) != std::string::npos) {
attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]");
}
}
return [=,
shape_infos = std::move(shape_infos),
attributes = std::move(attributes)](
const std::vector<array>& inputs,
const std::vector<Shape>& output_shapes,
const std::vector<Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
const std::vector<std::pair<std::string, TemplateArg>>&
template_args = {},
std::optional<float> init_value = std::nullopt,
bool verbose = false,
StreamOrDevice s_ = {}) {
if (inputs.size() != input_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `inputs` to have size "
<< input_names.size() << " but got size " << inputs.size() << "."
<< std::endl;
throw std::invalid_argument(msg.str());
}
if (output_shapes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `output_shapes` to have size "
<< output_names.size() << " but got size " << output_shapes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
if (output_dtypes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `output_dtypes` to have size "
<< output_names.size() << " but got size " << output_dtypes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
auto s = to_stream(s_);
if (s.device != Device::gpu) {
throw std::invalid_argument("[metal_kernel] Only supports the GPU.");
}
std::ostringstream func_name;
std::string template_def = "";
std::string hash_key = "";
if (!template_args.empty()) {
std::regex disallowed_chars("\\<|\\>|(, )");
template_def = write_template(template_args);
hash_key = std::regex_replace(template_def, disallowed_chars, "_");
hash_key.pop_back();
}
func_name << "custom_kernel_" << name << hash_key;
std::string kernel_name = func_name.str();
std::string kernel_source = write_signature(
kernel_name,
header,
source,
input_names,
inputs,
output_names,
output_dtypes,
template_args,
attributes,
shape_infos,
atomic_outputs);
if (!template_args.empty()) {
template_def = kernel_name + template_def;
kernel_source += "\ntemplate [[host_name(\"";
kernel_source += kernel_name;
kernel_source += "\")]] [[kernel]] decltype(";
kernel_source += template_def;
kernel_source += ") ";
kernel_source += template_def;
kernel_source += ";\n";
}
if (verbose) {
std::cout << "Generated source code for `" << name << "`:" << std::endl
<< "```" << std::endl
<< kernel_source << std::endl
<< "```" << std::endl;
}
return array::make_arrays(
std::move(output_shapes),
std::move(output_dtypes),
std::make_shared<CustomKernel>(
s,
std::move(kernel_name),
std::move(kernel_source),
grid,
threadgroup,
shape_infos,
ensure_row_contiguous,
init_value),
std::move(inputs));
};
}
} // namespace mlx::core::fast
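For reference, a hedged example of exercising the templated path that write_signature and write_template generate; the template= keyword is assumed here to be the Python-side counterpart of template_args, and the host name follows the custom_kernel_<name><hash> scheme built above:

```python
import mlx.core as mx

source = """
    uint elem = thread_position_in_grid.x;
    out[elem] = static_cast<T>(2) * inp[elem];
"""
kernel = mx.fast.metal_kernel(
    name="double",
    input_names=["inp"],
    output_names=["out"],
    source=source,
)
a = mx.arange(8, dtype=mx.float32)
(out,) = kernel(
    inputs=[a],
    template=[("T", mx.float32)],     # instantiated as custom_kernel_double_float
    grid=(a.size, 1, 1),
    threadgroup=(a.size, 1, 1),
    output_shapes=[a.shape],
    output_dtypes=[a.dtype],
    verbose=True,                     # print the generated Metal source
    stream=mx.gpu,
)
print(out)                            # array([0, 2, 4, ..., 14], dtype=float32)
```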

View File

@@ -719,9 +719,9 @@ class Convolution : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_tuple(
kernel_strides_,
padding_lo_,
padding_hi_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
groups_,

View File

@@ -174,11 +174,15 @@ class Module(dict):
new_weights = dict(weights)
curr_weights = dict(tree_flatten(self.parameters()))
if extras := (new_weights.keys() - curr_weights.keys()):
extras = " ".join(extras)
raise ValueError(f"Received parameters not in model: {extras}.")
num_extra = len(extras)
extras = ",\n".join(sorted(extras))
raise ValueError(
f"Received {num_extra} parameters not in model: \n{extras}."
)
if missing := (curr_weights.keys() - new_weights.keys()):
missing = " ".join(missing)
raise ValueError(f"Missing parameters: {missing}.")
num_missing = len(missing)
missing = ",\n".join(sorted(missing))
raise ValueError(f"Missing {num_missing} parameters: \n{missing}.")
for k, v in curr_weights.items():
v_new = new_weights[k]
if not isinstance(v_new, mx.array):
@@ -193,7 +197,7 @@ class Module(dict):
)
if len(weights) != 0:
self.update(tree_unflatten(weights))
self.update(tree_unflatten(weights), strict=False)
return self
def save_weights(self, file: str):
@@ -291,7 +295,7 @@ class Module(dict):
return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
def update(self, parameters: dict) -> Module:
def update(self, parameters: dict, strict: bool = True) -> Module:
"""Replace the parameters of this Module with the provided ones in the
dict of dicts and lists.
@@ -305,7 +309,9 @@ class Module(dict):
Args:
parameters (dict): A complete or partial dictionary of the modules
parameters.
strict (bool): If ``True`` checks that ``parameters`` is a
subset of the module's parameters. Default: ``True``.
Returns:
The module instance after updating the parameters.
"""
@@ -317,21 +323,29 @@ class Module(dict):
current_value = dst[k]
new_value = parameters[k]
if isinstance(current_value, mx.array):
if strict and not isinstance(new_value, mx.array):
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
dst[k] = new_value
elif isinstance(current_value, Module):
current_value.update(new_value)
elif isinstance(current_value, (dict, list)):
else:
apply(current_value, new_value)
elif strict:
raise ValueError(f'Module does not have parameter named "{k}".')
elif isinstance(parameters, list):
for i in range(len(parameters)):
current_value = dst[i]
new_value = parameters[i]
if isinstance(current_value, mx.array):
if strict and not isinstance(new_value, mx.array):
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
dst[i] = new_value
elif isinstance(current_value, Module):
current_value.update(new_value)
elif isinstance(current_value, (dict, list)):
else:
apply(current_value, new_value)
elif strict:
raise ValueError(f"Received invalid type: {type(parameters).__name__}.")
apply(self, parameters)
return self
@@ -359,7 +373,7 @@ class Module(dict):
self.update(self.filter_and_map(filter_fn, map_fn))
return self
def update_modules(self, modules: dict) -> Module:
def update_modules(self, modules: dict, strict: bool = True) -> Module:
"""Replace the child modules of this :class:`Module` instance with the
provided ones in the dict of dicts and lists.
@@ -368,12 +382,14 @@ class Module(dict):
programmatically swapping layers.
The passed in parameters dictionary need not be a full dictionary
similar to :meth:`parameters`. Only the provided locations will be
similar to :meth:`modules`. Only the provided locations will be
updated.
Args:
modules (dict): A complete or partial dictionary of the modules
modules (dict): A complete or partial dictionary of the module's
submodules.
strict (bool): If ``True`` checks that ``modules`` is a
subset of the child modules of this instance. Default: ``True``.
Returns:
The module instance after updating the submodules.
"""
@@ -388,6 +404,14 @@ class Module(dict):
dst[k] = new_value
elif isinstance(current_value, (dict, list)):
apply(current_value, new_value)
elif strict:
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
elif strict:
raise ValueError(
f'Module does not have sub-module named "{k}".'
)
elif isinstance(modules, list):
for i in range(len(dst)):
current_value = dst[i]
@@ -396,6 +420,12 @@ class Module(dict):
dst[i] = new_value
elif isinstance(current_value, (dict, list)):
apply(current_value, new_value)
elif strict:
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
elif strict:
raise ValueError(f"Received invalid type: {type(modules).__name__}.")
apply(self, modules)
return self
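A short usage sketch of the strict default, with the behavior inferred from the diff above (load_weights keeps its own key checks and now calls update(..., strict=False)):

```python
import mlx.core as mx
import mlx.nn as nn

m = nn.Linear(4, 4)
m.update({"weight": mx.zeros((4, 4))})        # ok: known parameter

try:
    m.update({"wieght": mx.zeros((4, 4))})    # typo now raises by default
except ValueError as e:
    print(e)  # Module does not have parameter named "wieght".

m.update({"wieght": mx.zeros((4, 4))}, strict=False)  # silently skipped

try:
    m.update({"weight": "not an array"})      # wrong type also raises
except ValueError as e:
    print(e)  # Received invalid type: str.
```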

View File

@@ -54,5 +54,9 @@ target_link_libraries(core PRIVATE mlx)
target_compile_definitions(core PRIVATE _VERSION_=${MLX_VERSION})
if(BUILD_SHARED_LIBS)
target_link_options(core PRIVATE -Wl,-rpath,@loader_path/lib)
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
target_link_options(core PRIVATE -Wl,-rpath,@loader_path/lib)
else()
target_link_options(core PRIVATE -Wl,-rpath,\$ORIGIN/lib)
endif()
endif()

View File

@@ -6,6 +6,7 @@ import tempfile
import unittest
import mlx.core as mx
import mlx.nn as nn
import mlx_tests
@@ -286,6 +287,65 @@ class TestExportImport(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
f2(mx.array(10), mx.array([5, 10, 20]))
def test_export_scatter_gather(self):
path = os.path.join(self.test_dir, "fn.mlxfn")
def fun(a, b):
return mx.take_along_axis(a, b, axis=0)
x = mx.random.uniform(shape=(4, 4))
y = mx.array([[0, 1, 2, 3], [1, 2, 0, 3]])
mx.export_function(path, fun, (x, y))
imported_fun = mx.import_function(path)
expected = fun(x, y)
out = imported_fun(x, y)[0]
self.assertTrue(mx.array_equal(expected, out))
def fun(a, b, c):
return mx.put_along_axis(a, b, c, axis=0)
x = mx.random.uniform(shape=(4, 4))
y = mx.array([[0, 1, 2, 3], [1, 2, 0, 3]])
z = mx.random.uniform(shape=(2, 4))
mx.export_function(path, fun, (x, y, z))
imported_fun = mx.import_function(path)
expected = fun(x, y, z)
out = imported_fun(x, y, z)[0]
self.assertTrue(mx.array_equal(expected, out))
def test_export_conv(self):
path = os.path.join(self.test_dir, "fn.mlxfn")
class Model(nn.Module):
def __init__(self):
super().__init__()
self.c1 = nn.Conv2d(
3, 16, kernel_size=3, stride=1, padding=1, bias=False
)
self.c2 = nn.Conv2d(
16, 16, kernel_size=3, stride=2, padding=1, bias=False
)
self.c3 = nn.Conv2d(
16, 16, kernel_size=3, stride=1, padding=2, bias=False
)
def __call__(self, x):
return self.c3(self.c2(self.c1(x)))
model = Model()
mx.eval(model.parameters())
def forward(x):
return model(x)
input_data = mx.random.normal(shape=(4, 32, 32, 3))
mx.export_function(path, forward, input_data)
imported_fn = mx.import_function(path)
out = imported_fn(input_data)[0]
expected = forward(input_data)
self.assertTrue(mx.allclose(expected, out))
if __name__ == "__main__":
unittest.main()

View File

@@ -735,6 +735,41 @@ class TestFast(mlx_tests.MLXTestCase):
)[0]
self.assertEqual(out.item(), 2)
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_custom_kernel_caching(self):
def call_kernel(a: mx.array, source):
kernel = mx.fast.metal_kernel(
name="my_kernel",
input_names=["inp"],
output_names=["out"],
source=source,
)
return kernel(
inputs=[a],
grid=(a.size, 1, 1),
threadgroup=(a.size, 1, 1),
output_shapes=[a.shape],
output_dtypes=[a.dtype],
stream=mx.gpu,
)[0]
a = mx.random.normal(shape=(32,))
source = """
uint elem = thread_position_in_grid.x;
out[elem] = 0.0;
"""
out = call_kernel(a, source)
self.assertTrue(mx.array_equal(out, mx.zeros_like(out)))
source = """
uint elem = thread_position_in_grid.x;
out[elem] = 1.0;
"""
out = call_kernel(a, source)
self.assertTrue(mx.array_equal(out, mx.ones_like(out)))
if __name__ == "__main__":
unittest.main()

View File

@@ -359,36 +359,6 @@ class TestLinalg(mlx_tests.MLXTestCase):
mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
) # Non-square matrix
def test_lu(self):
with self.assertRaises(ValueError):
mx.linalg.lu(mx.array(0.0), stream=mx.cpu)
with self.assertRaises(ValueError):
mx.linalg.lu(mx.array([0.0, 1.0]), stream=mx.cpu)
with self.assertRaises(ValueError):
mx.linalg.lu(mx.array([[0, 1], [1, 0]]), stream=mx.cpu)
# Test 3x3 matrix
a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]])
P, L, U = mx.linalg.lu(a, stream=mx.cpu)
self.assertTrue(mx.allclose(L[P, :] @ U, a))
# Test batch dimension
a = mx.broadcast_to(a, (5, 5, 3, 3))
P, L, U = mx.linalg.lu(a, stream=mx.cpu)
L = mx.take_along_axis(L, P[..., None], axis=-2)
self.assertTrue(mx.allclose(L @ U, a))
# Test non-square matrix
a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0]])
P, L, U = mx.linalg.lu(a, stream=mx.cpu)
self.assertTrue(mx.allclose(L[P, :] @ U, a))
a = mx.array([[3.0, 1.0], [1.0, 8.0], [9.0, 2.0]])
P, L, U = mx.linalg.lu(a, stream=mx.cpu)
self.assertTrue(mx.allclose(L[P, :] @ U, a))
def test_eigh(self):
tols = {"atol": 1e-5, "rtol": 1e-5}

View File

@@ -219,6 +219,46 @@ class TestBase(mlx_tests.MLXTestCase):
x = mx.zeros((3,))
mx.grad(loss_fn)(model)
def test_update(self):
m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
# Updating non-existent parameters
with self.assertRaises(ValueError):
updates = {"layers": [{"value": 0}]}
m.update(updates)
with self.assertRaises(ValueError):
updates = {"layers": ["hello"]}
m.update(updates)
# Wrong type
with self.assertRaises(ValueError):
updates = {"layers": [{"weight": "hi"}]}
m.update(updates)
def test_update_modules(self):
m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
# Updating non-existent modules should not be allowed by default
with self.assertRaises(ValueError):
m = m.update_modules({"values": [0, 1]})
# Update wrong types
with self.assertRaises(ValueError):
m = m.update_modules({"layers": [0, 1]})
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.test = mx.array(1.0)
self.list = [mx.array(1.0), mx.array(2.0)]
m = MyModule()
with self.assertRaises(ValueError):
m = m.update_modules({"test": "hi"})
with self.assertRaises(ValueError):
m = m.update_modules({"list": ["hi"]})
class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self):