mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Compare commits
49 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
5adf185f86 | ||
![]() |
c9a9180584 | ||
![]() |
76831ed83d | ||
![]() |
b3d7b85376 | ||
![]() |
cad5c0241c | ||
![]() |
b8022c578a | ||
![]() |
bc53f8293f | ||
![]() |
c552ff2451 | ||
![]() |
4fda5fbdf9 | ||
![]() |
580776559b | ||
![]() |
a14aaa7c9d | ||
![]() |
a6d780154f | ||
![]() |
6871e2eeb7 | ||
![]() |
8402a2acf4 | ||
![]() |
fddb6933e1 | ||
![]() |
c8b4787e4e | ||
![]() |
2188199ff8 | ||
![]() |
aa07429bad | ||
![]() |
918761a25a | ||
![]() |
a4fc671d3e | ||
![]() |
f5f65ef48c | ||
![]() |
c2dd81a8aa | ||
![]() |
d7e680ffe4 | ||
![]() |
c371baf53a | ||
![]() |
ccf78f566c | ||
![]() |
c9fa68664a | ||
![]() |
c35f4d089a | ||
![]() |
8590c0941e | ||
![]() |
095163b8d1 | ||
![]() |
99c33d011d | ||
![]() |
62fecf3e13 | ||
![]() |
7c4eb5d03e | ||
![]() |
bae9a6b404 | ||
![]() |
004c1d8ef2 | ||
![]() |
7ebb2e0193 | ||
![]() |
9ce77798b1 | ||
![]() |
f8bad60609 | ||
![]() |
5866b3857b | ||
![]() |
1ca616844b | ||
![]() |
2e8cf0b450 | ||
![]() |
24f89173d1 | ||
![]() |
c6a20b427a | ||
![]() |
a5ac9244c4 | ||
![]() |
c763fe1be0 | ||
![]() |
52dc8c8cd5 | ||
![]() |
aede70e81d | ||
![]() |
85a8beb5e4 | ||
![]() |
0bb89e9e5f | ||
![]() |
5685ceb3c7 |
@ -16,6 +16,9 @@ parameters:
|
||||
linux_release:
|
||||
type: boolean
|
||||
default: false
|
||||
cuda_release:
|
||||
type: boolean
|
||||
default: false
|
||||
|
||||
jobs:
|
||||
build_documentation:
|
||||
@ -104,7 +107,7 @@ jobs:
|
||||
command: |
|
||||
echo "stubs"
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
@ -162,7 +165,7 @@ jobs:
|
||||
command: |
|
||||
source env/bin/activate
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
@ -212,6 +215,29 @@ jobs:
|
||||
METAL_DEBUG_ERROR_MODE=0 \
|
||||
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
|
||||
|
||||
cuda_build_and_test:
|
||||
machine:
|
||||
image: linux-cuda-12:default
|
||||
resource_class: gpu.nvidia.small.gen2
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
python -m venv env
|
||||
source env/bin/activate
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||
pip install -e ".[dev]"
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
source env/bin/activate
|
||||
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
|
||||
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
|
||||
|
||||
build_release:
|
||||
parameters:
|
||||
python_version:
|
||||
@ -259,7 +285,7 @@ jobs:
|
||||
command: |
|
||||
source env/bin/activate
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Build Python package
|
||||
command: |
|
||||
@ -318,7 +344,7 @@ jobs:
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
pip install . -v
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
python setup.py generate_stubs
|
||||
<< parameters.extra_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
python -m build --wheel
|
||||
@ -332,6 +358,48 @@ jobs:
|
||||
- store_artifacts:
|
||||
path: wheelhouse/
|
||||
|
||||
build_cuda_release:
|
||||
parameters:
|
||||
python_version:
|
||||
type: string
|
||||
default: "3.9"
|
||||
extra_env:
|
||||
type: string
|
||||
default: "DEV_RELEASE=1"
|
||||
machine:
|
||||
image: linux-cuda-12:default
|
||||
resource_class: gpu.nvidia.small.gen2
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Build wheel
|
||||
command: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
python -m venv env
|
||||
source env/bin/activate
|
||||
pip install auditwheel
|
||||
pip install patchelf
|
||||
pip install build
|
||||
pip install twine
|
||||
<< parameters.extra_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||
pip install ".[dev]" -v
|
||||
python setup.py generate_stubs
|
||||
<< parameters.extra_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||
python -m build --wheel
|
||||
bash python/scripts/repair_cuda.sh
|
||||
- run:
|
||||
name: Upload package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
twine upload wheelhouse/*.whl
|
||||
- store_artifacts:
|
||||
path: wheelhouse/
|
||||
|
||||
workflows:
|
||||
build_and_test:
|
||||
when:
|
||||
@ -348,6 +416,7 @@ workflows:
|
||||
parameters:
|
||||
macosx_deployment_target: ["13.5", "14.0"]
|
||||
- linux_build_and_test
|
||||
- cuda_build_and_test
|
||||
- build_documentation
|
||||
|
||||
build_pypi_release:
|
||||
@ -455,6 +524,8 @@ workflows:
|
||||
macosx_deployment_target: ["13.5", "14.0"]
|
||||
- linux_build_and_test:
|
||||
requires: [ hold ]
|
||||
- cuda_build_and_test:
|
||||
requires: [ hold ]
|
||||
nightly_build:
|
||||
when:
|
||||
and:
|
||||
@ -598,3 +669,14 @@ workflows:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
extra_env: ["PYPI_RELEASE=1"]
|
||||
cuda_test_release:
|
||||
when:
|
||||
and:
|
||||
- equal: [ main, << pipeline.git.branch >> ]
|
||||
- << pipeline.parameters.cuda_release >>
|
||||
jobs:
|
||||
- build_cuda_release:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
extra_env: ["PYPI_RELEASE=1"]
|
||||
|
@ -1,5 +1,6 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
|
107
benchmarks/python/conv_unaligned_bench.py
Normal file
107
benchmarks/python/conv_unaligned_bench.py
Normal file
@ -0,0 +1,107 @@
|
||||
import math
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
N_warmup = 10
|
||||
N_iter_bench = 100
|
||||
N_iter_func = 5
|
||||
|
||||
|
||||
def bench(f, a, b):
|
||||
for i in range(N_warmup):
|
||||
f(a, b)
|
||||
torch.mps.synchronize()
|
||||
|
||||
s = time.perf_counter_ns()
|
||||
for i in range(N_iter_bench):
|
||||
f(a, b)
|
||||
e = time.perf_counter_ns()
|
||||
return (e - s) * 1e-9
|
||||
|
||||
|
||||
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
def mx_conv_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
mx.eval(ys)
|
||||
return ys
|
||||
|
||||
return mx_conv_2D
|
||||
|
||||
|
||||
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
@torch.no_grad()
|
||||
def pt_conv_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
torch.mps.synchronize()
|
||||
return ys
|
||||
|
||||
return pt_conv_2D
|
||||
|
||||
|
||||
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
|
||||
scale = 1.0 / math.sqrt(kH * kH * C)
|
||||
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
||||
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
|
||||
np_dtype
|
||||
)
|
||||
|
||||
a_mx = mx.array(a_np)
|
||||
b_mx = mx.array(b_np)
|
||||
|
||||
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
|
||||
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
|
||||
|
||||
torch.mps.synchronize()
|
||||
|
||||
f_mx = make_mx_conv_2D(strides, padding, groups)
|
||||
f_pt = make_pt_conv_2D(strides, padding, groups)
|
||||
|
||||
time_torch = bench(f_pt, a_pt, b_pt)
|
||||
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||
|
||||
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
|
||||
out_pt = torch.conv2d(
|
||||
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
||||
out_pt = out_pt.numpy(force=True)
|
||||
|
||||
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
||||
|
||||
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||
print(
|
||||
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||
)
|
||||
|
||||
return time_mlx, time_torch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dtype = "float32"
|
||||
shapes = (
|
||||
(4, 32, 32, 21, 3, 3, 128),
|
||||
(4, 32, 32, 21, 3, 3, 37),
|
||||
(4, 32, 32, 370, 3, 3, 370),
|
||||
(4, 32, 32, 370, 7, 7, 128),
|
||||
(2, 320, 640, 21, 7, 7, 21),
|
||||
)
|
||||
for N, H, W, C, kh, kw, O in shapes:
|
||||
time_mlx, time_torch = bench_shape(
|
||||
N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
|
||||
)
|
||||
diff = time_torch / time_mlx - 1.0
|
||||
|
||||
print(
|
||||
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
if time_mlx >= 2.0 * time_torch:
|
||||
print("ATTENTION ^^^^^^^")
|
@ -1,5 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from functools import partial
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from time_utils import time_fn
|
||||
@ -18,51 +20,63 @@ def layer_norm(x, w, b, eps):
|
||||
return y
|
||||
|
||||
|
||||
def time_layer_norm():
|
||||
def time_layer_norm(N, dt):
|
||||
L = 1024
|
||||
f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
|
||||
f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
|
||||
g1 = mx.grad(f1, argnums=(0, 1, 2))
|
||||
g2 = mx.grad(f2, argnums=(0, 1, 2))
|
||||
|
||||
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||
w = mx.random.uniform(shape=(N,)).astype(dt)
|
||||
b = mx.random.uniform(shape=(N,)).astype(dt)
|
||||
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||
mx.eval(x, w, b, y)
|
||||
|
||||
def layer_norm_loop(g, x, w, b):
|
||||
def layer_norm_loop(f, x, w, b):
|
||||
for _ in range(32):
|
||||
x = f(x, w, b)
|
||||
return x
|
||||
|
||||
time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b)
|
||||
time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b)
|
||||
|
||||
def layer_norm_grad_loop(g, x, w, b):
|
||||
gx, gw, gb = x, w, b
|
||||
for _ in range(32):
|
||||
gx, gw, gb = g(gx, gw, gb, y)
|
||||
return gx, gw, gb
|
||||
|
||||
time_fn(layer_norm_loop, g1, x, w, b)
|
||||
time_fn(layer_norm_loop, g2, x, w, b)
|
||||
time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
|
||||
time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
|
||||
time_fn(layer_norm_grad_loop, g1, x, w, b)
|
||||
time_fn(layer_norm_grad_loop, g2, x, w, b)
|
||||
time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b)
|
||||
time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b)
|
||||
|
||||
f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
|
||||
f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
|
||||
g1 = mx.grad(f1, argnums=(0,))
|
||||
g2 = mx.grad(f2, argnums=(0,))
|
||||
|
||||
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||
w = mx.random.uniform(shape=(N,)).astype(dt)
|
||||
b = mx.random.uniform(shape=(N,)).astype(dt)
|
||||
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||
mx.eval(x, w, b, y)
|
||||
|
||||
def layer_norm_loop(g, x):
|
||||
def layer_norm_grad_x_loop(g, x):
|
||||
gx = x
|
||||
for _ in range(32):
|
||||
gx = g(gx, y)
|
||||
return gx
|
||||
|
||||
time_fn(layer_norm_loop, g1, x)
|
||||
time_fn(layer_norm_loop, g2, x)
|
||||
time_fn(layer_norm_loop, mx.compile(g1), x)
|
||||
time_fn(layer_norm_loop, mx.compile(g2), x)
|
||||
time_fn(layer_norm_grad_x_loop, g1, x)
|
||||
time_fn(layer_norm_grad_x_loop, g2, x)
|
||||
time_fn(layer_norm_grad_x_loop, mx.compile(g1), x)
|
||||
time_fn(layer_norm_grad_x_loop, mx.compile(g2), x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_layer_norm()
|
||||
for dt in [mx.float32, mx.float16, mx.bfloat16]:
|
||||
for n in [1024, 2048, 4096, 8192, 8192 + 1024]:
|
||||
print(dt, n)
|
||||
time_layer_norm(n, dt)
|
||||
|
@ -8,23 +8,26 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs.
|
||||
Simple Example
|
||||
--------------
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
Let's write a custom kernel that computes ``exp`` elementwise:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = inp[elem];
|
||||
out[elem] = metal::exp(tmp);
|
||||
"""
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = inp[elem];
|
||||
out[elem] = metal::exp(tmp);
|
||||
"""
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
outputs = kernel(
|
||||
inputs=[a],
|
||||
template=[("T", mx.float32)],
|
||||
@ -39,8 +42,13 @@ Let's write a custom kernel that computes ``exp`` elementwise:
|
||||
b = exp_elementwise(a)
|
||||
assert mx.allclose(b, mx.exp(a))
|
||||
|
||||
Every time you make a kernel, a new Metal library is created and possibly
|
||||
JIT compiled. To reduce the overhead from that, build the kernel once with
|
||||
:func:`fast.metal_kernel` and then use it many times.
|
||||
|
||||
.. note::
|
||||
We are only required to pass the body of the Metal kernel in ``source``.
|
||||
Only pass the body of the Metal kernel in ``source``. The function
|
||||
signature is generated automatically.
|
||||
|
||||
The full function signature will be generated using:
|
||||
|
||||
@ -78,44 +86,51 @@ Putting this all together, the generated function signature for ``myexp`` is as
|
||||
|
||||
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
|
||||
|
||||
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
|
||||
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
|
||||
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
|
||||
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
|
||||
<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
|
||||
function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
|
||||
``threadgroup`` size threadgroups. For optimal performance, each thread group
|
||||
dimension should be less than or equal to the corresponding grid dimension.
|
||||
|
||||
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
|
||||
Passing ``verbose=True`` to :func:`ast.metal_kernel.__call__` will print the
|
||||
generated code for debugging purposes.
|
||||
|
||||
Using Shape/Strides
|
||||
-------------------
|
||||
|
||||
``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
|
||||
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
|
||||
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
|
||||
when indexing.
|
||||
:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
|
||||
is ``True`` by default. This will copy the array inputs if needed
|
||||
before the kernel is launched to ensure that the memory layout is row
|
||||
contiguous. Generally this makes writing the kernel easier, since we don't
|
||||
have to worry about gaps or the ordering of the dims when indexing.
|
||||
|
||||
If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
|
||||
input array ``a`` if any are present in ``source``.
|
||||
We can then use MLX's built in indexing utils to fetch the right elements for each thread.
|
||||
If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
|
||||
``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
|
||||
present in ``source``. We can then use MLX's built in indexing utils to fetch
|
||||
the right elements for each thread.
|
||||
|
||||
Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:
|
||||
Let's convert ``myexp`` above to support arbitrarily strided arrays without
|
||||
relying on a copy from ``ensure_row_contiguous``:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
|
||||
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
|
||||
T tmp = inp[loc];
|
||||
// Output arrays are always row contiguous
|
||||
out[elem] = metal::exp(tmp);
|
||||
"""
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp_strided",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source
|
||||
)
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
|
||||
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
|
||||
T tmp = inp[loc];
|
||||
// Output arrays are always row contiguous
|
||||
out[elem] = metal::exp(tmp);
|
||||
"""
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp_strided",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs=[a],
|
||||
template=[("T", mx.float32)],
|
||||
@ -142,137 +157,139 @@ We'll start with the following MLX implementation using standard ops:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def grid_sample_ref(x, grid):
|
||||
N, H_in, W_in, _ = x.shape
|
||||
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
|
||||
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
|
||||
def grid_sample_ref(x, grid):
|
||||
N, H_in, W_in, _ = x.shape
|
||||
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
|
||||
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
|
||||
|
||||
ix_nw = mx.floor(ix).astype(mx.int32)
|
||||
iy_nw = mx.floor(iy).astype(mx.int32)
|
||||
ix_nw = mx.floor(ix).astype(mx.int32)
|
||||
iy_nw = mx.floor(iy).astype(mx.int32)
|
||||
|
||||
ix_ne = ix_nw + 1
|
||||
iy_ne = iy_nw
|
||||
ix_ne = ix_nw + 1
|
||||
iy_ne = iy_nw
|
||||
|
||||
ix_sw = ix_nw
|
||||
iy_sw = iy_nw + 1
|
||||
ix_sw = ix_nw
|
||||
iy_sw = iy_nw + 1
|
||||
|
||||
ix_se = ix_nw + 1
|
||||
iy_se = iy_nw + 1
|
||||
ix_se = ix_nw + 1
|
||||
iy_se = iy_nw + 1
|
||||
|
||||
nw = (ix_se - ix) * (iy_se - iy)
|
||||
ne = (ix - ix_sw) * (iy_sw - iy)
|
||||
sw = (ix_ne - ix) * (iy - iy_ne)
|
||||
se = (ix - ix_nw) * (iy - iy_nw)
|
||||
nw = (ix_se - ix) * (iy_se - iy)
|
||||
ne = (ix - ix_sw) * (iy_sw - iy)
|
||||
sw = (ix_ne - ix) * (iy - iy_ne)
|
||||
se = (ix - ix_nw) * (iy - iy_nw)
|
||||
|
||||
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
|
||||
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
|
||||
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
|
||||
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
|
||||
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
|
||||
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
|
||||
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
|
||||
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
|
||||
|
||||
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
|
||||
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
|
||||
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
|
||||
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
|
||||
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
|
||||
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
|
||||
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
|
||||
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
|
||||
|
||||
I_nw *= mask_nw[..., None]
|
||||
I_ne *= mask_ne[..., None]
|
||||
I_sw *= mask_sw[..., None]
|
||||
I_se *= mask_se[..., None]
|
||||
I_nw *= mask_nw[..., None]
|
||||
I_ne *= mask_ne[..., None]
|
||||
I_sw *= mask_sw[..., None]
|
||||
I_se *= mask_se[..., None]
|
||||
|
||||
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
|
||||
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
|
||||
|
||||
return output
|
||||
return output
|
||||
|
||||
Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
|
||||
Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
|
||||
to write a fast GPU kernel for both the forward and backward passes.
|
||||
|
||||
First we'll implement the forward pass as a fused kernel:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@mx.custom_function
|
||||
def grid_sample(x, grid):
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
int W = x_shape[2];
|
||||
int C = x_shape[3];
|
||||
int gH = grid_shape[1];
|
||||
int gW = grid_shape[2];
|
||||
|
||||
assert x.ndim == 4, "`x` must be 4D."
|
||||
assert grid.ndim == 4, "`grid` must be 4D."
|
||||
int w_stride = C;
|
||||
int h_stride = W * w_stride;
|
||||
int b_stride = H * h_stride;
|
||||
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
out_shape = (B, gN, gM, C)
|
||||
uint grid_idx = elem / C * 2;
|
||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
int ix_nw = floor(ix);
|
||||
int iy_nw = floor(iy);
|
||||
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
int W = x_shape[2];
|
||||
int C = x_shape[3];
|
||||
int gH = grid_shape[1];
|
||||
int gW = grid_shape[2];
|
||||
int ix_ne = ix_nw + 1;
|
||||
int iy_ne = iy_nw;
|
||||
|
||||
int w_stride = C;
|
||||
int h_stride = W * w_stride;
|
||||
int b_stride = H * h_stride;
|
||||
int ix_sw = ix_nw;
|
||||
int iy_sw = iy_nw + 1;
|
||||
|
||||
uint grid_idx = elem / C * 2;
|
||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||
int ix_se = ix_nw + 1;
|
||||
int iy_se = iy_nw + 1;
|
||||
|
||||
int ix_nw = floor(ix);
|
||||
int iy_nw = floor(iy);
|
||||
T nw = (ix_se - ix) * (iy_se - iy);
|
||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||
|
||||
int ix_ne = ix_nw + 1;
|
||||
int iy_ne = iy_nw;
|
||||
int batch_idx = elem / C / gH / gW * b_stride;
|
||||
int channel_idx = elem % C;
|
||||
int base_idx = batch_idx + channel_idx;
|
||||
|
||||
int ix_sw = ix_nw;
|
||||
int iy_sw = iy_nw + 1;
|
||||
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
|
||||
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
|
||||
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
|
||||
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
|
||||
|
||||
int ix_se = ix_nw + 1;
|
||||
int iy_se = iy_nw + 1;
|
||||
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
|
||||
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
|
||||
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
|
||||
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
|
||||
|
||||
T nw = (ix_se - ix) * (iy_se - iy);
|
||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
|
||||
"""
|
||||
|
||||
int batch_idx = elem / C / gH / gW * b_stride;
|
||||
int channel_idx = elem % C;
|
||||
int base_idx = batch_idx + channel_idx;
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample",
|
||||
input_names=["x", "grid"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
|
||||
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
|
||||
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
|
||||
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
|
||||
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
|
||||
@mx.custom_function
|
||||
def grid_sample(x, grid):
|
||||
|
||||
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
|
||||
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
|
||||
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
|
||||
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
|
||||
assert x.ndim == 4, "`x` must be 4D."
|
||||
assert grid.ndim == 4, "`grid` must be 4D."
|
||||
|
||||
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
|
||||
"""
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample",
|
||||
input_names=["x", "grid"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs=[x, grid],
|
||||
template=[("T", x.dtype)],
|
||||
output_shapes=[out_shape],
|
||||
output_dtypes=[x.dtype],
|
||||
grid=(np.prod(out_shape), 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
)
|
||||
return outputs[0]
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
out_shape = (B, gN, gM, C)
|
||||
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
|
||||
outputs = kernel(
|
||||
inputs=[x, grid],
|
||||
template=[("T", x.dtype)],
|
||||
output_shapes=[out_shape],
|
||||
output_dtypes=[x.dtype],
|
||||
grid=(np.prod(out_shape), 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
)
|
||||
return outputs[0]
|
||||
|
||||
For a reasonably sized input such as:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
x.shape = (8, 1024, 1024, 64)
|
||||
grid.shape = (8, 256, 256, 2)
|
||||
x.shape = (8, 1024, 1024, 64)
|
||||
grid.shape = (8, 256, 256, 2)
|
||||
|
||||
On an M1 Max, we see a big performance improvement:
|
||||
|
||||
@ -281,11 +298,11 @@ On an M1 Max, we see a big performance improvement:
|
||||
Grid Sample VJP
|
||||
---------------
|
||||
|
||||
Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
|
||||
its custom vjp transform so MLX can differentiate it.
|
||||
Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
|
||||
define its custom vjp transform so MLX can differentiate it.
|
||||
|
||||
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
|
||||
requires a few extra ``mx.fast.metal_kernel`` features:
|
||||
requires a few extra :func:`fast.metal_kernel` features:
|
||||
|
||||
* ``init_value=0``
|
||||
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
|
||||
@ -299,128 +316,129 @@ We can then implement the backwards pass as follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@grid_sample.vjp
|
||||
def grid_sample_vjp(primals, cotangent, _):
|
||||
x, grid = primals
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
int W = x_shape[2];
|
||||
int C = x_shape[3];
|
||||
// Pad C to the nearest larger simdgroup size multiple
|
||||
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
|
||||
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
int gH = grid_shape[1];
|
||||
int gW = grid_shape[2];
|
||||
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
int W = x_shape[2];
|
||||
int C = x_shape[3];
|
||||
// Pad C to the nearest larger simdgroup size multiple
|
||||
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
|
||||
int w_stride = C;
|
||||
int h_stride = W * w_stride;
|
||||
int b_stride = H * h_stride;
|
||||
|
||||
int gH = grid_shape[1];
|
||||
int gW = grid_shape[2];
|
||||
uint grid_idx = elem / C_padded * 2;
|
||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||
|
||||
int w_stride = C;
|
||||
int h_stride = W * w_stride;
|
||||
int b_stride = H * h_stride;
|
||||
int ix_nw = floor(ix);
|
||||
int iy_nw = floor(iy);
|
||||
|
||||
uint grid_idx = elem / C_padded * 2;
|
||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||
int ix_ne = ix_nw + 1;
|
||||
int iy_ne = iy_nw;
|
||||
|
||||
int ix_nw = floor(ix);
|
||||
int iy_nw = floor(iy);
|
||||
int ix_sw = ix_nw;
|
||||
int iy_sw = iy_nw + 1;
|
||||
|
||||
int ix_ne = ix_nw + 1;
|
||||
int iy_ne = iy_nw;
|
||||
int ix_se = ix_nw + 1;
|
||||
int iy_se = iy_nw + 1;
|
||||
|
||||
int ix_sw = ix_nw;
|
||||
int iy_sw = iy_nw + 1;
|
||||
T nw = (ix_se - ix) * (iy_se - iy);
|
||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||
|
||||
int ix_se = ix_nw + 1;
|
||||
int iy_se = iy_nw + 1;
|
||||
int batch_idx = elem / C_padded / gH / gW * b_stride;
|
||||
int channel_idx = elem % C_padded;
|
||||
int base_idx = batch_idx + channel_idx;
|
||||
|
||||
T nw = (ix_se - ix) * (iy_se - iy);
|
||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||
T gix = T(0);
|
||||
T giy = T(0);
|
||||
if (channel_idx < C) {
|
||||
int cot_index = elem / C_padded * C + channel_idx;
|
||||
T cot = cotangent[cot_index];
|
||||
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
|
||||
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
|
||||
|
||||
int batch_idx = elem / C_padded / gH / gW * b_stride;
|
||||
int channel_idx = elem % C_padded;
|
||||
int base_idx = batch_idx + channel_idx;
|
||||
T I_nw = x[offset];
|
||||
gix -= I_nw * (iy_se - iy) * cot;
|
||||
giy -= I_nw * (ix_se - ix) * cot;
|
||||
}
|
||||
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
|
||||
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
|
||||
|
||||
T gix = T(0);
|
||||
T giy = T(0);
|
||||
if (channel_idx < C) {
|
||||
int cot_index = elem / C_padded * C + channel_idx;
|
||||
T cot = cotangent[cot_index];
|
||||
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
|
||||
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
|
||||
T I_ne = x[offset];
|
||||
gix += I_ne * (iy_sw - iy) * cot;
|
||||
giy -= I_ne * (ix - ix_sw) * cot;
|
||||
}
|
||||
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
|
||||
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
|
||||
|
||||
T I_nw = x[offset];
|
||||
gix -= I_nw * (iy_se - iy) * cot;
|
||||
giy -= I_nw * (ix_se - ix) * cot;
|
||||
}
|
||||
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
|
||||
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
|
||||
T I_sw = x[offset];
|
||||
gix -= I_sw * (iy - iy_ne) * cot;
|
||||
giy += I_sw * (ix_ne - ix) * cot;
|
||||
}
|
||||
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
|
||||
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
|
||||
|
||||
T I_ne = x[offset];
|
||||
gix += I_ne * (iy_sw - iy) * cot;
|
||||
giy -= I_ne * (ix - ix_sw) * cot;
|
||||
}
|
||||
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
|
||||
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
|
||||
T I_se = x[offset];
|
||||
gix += I_se * (iy - iy_nw) * cot;
|
||||
giy += I_se * (ix - ix_nw) * cot;
|
||||
}
|
||||
}
|
||||
|
||||
T I_sw = x[offset];
|
||||
gix -= I_sw * (iy - iy_ne) * cot;
|
||||
giy += I_sw * (ix_ne - ix) * cot;
|
||||
}
|
||||
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
|
||||
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
|
||||
T gix_mult = W / 2;
|
||||
T giy_mult = H / 2;
|
||||
|
||||
T I_se = x[offset];
|
||||
gix += I_se * (iy - iy_nw) * cot;
|
||||
giy += I_se * (ix - ix_nw) * cot;
|
||||
}
|
||||
}
|
||||
// Reduce across each simdgroup first.
|
||||
// This is much faster than relying purely on atomics.
|
||||
gix = simd_sum(gix);
|
||||
giy = simd_sum(giy);
|
||||
|
||||
T gix_mult = W / 2;
|
||||
T giy_mult = H / 2;
|
||||
if (thread_index_in_simdgroup == 0) {
|
||||
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
|
||||
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
|
||||
}
|
||||
"""
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample_grad",
|
||||
input_names=["x", "grid", "cotangent"],
|
||||
output_names=["x_grad", "grid_grad"],
|
||||
source=source,
|
||||
atomic_outputs=True,
|
||||
)
|
||||
|
||||
// Reduce across each simdgroup first.
|
||||
// This is much faster than relying purely on atomics.
|
||||
gix = simd_sum(gix);
|
||||
giy = simd_sum(giy);
|
||||
@grid_sample.vjp
|
||||
def grid_sample_vjp(primals, cotangent, _):
|
||||
x, grid = primals
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
|
||||
if (thread_index_in_simdgroup == 0) {
|
||||
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
|
||||
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
|
||||
}
|
||||
"""
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample_grad",
|
||||
input_names=["x", "grid", "cotangent"],
|
||||
output_names=["x_grad", "grid_grad"],
|
||||
source=source,
|
||||
atomic_outputs=True,
|
||||
)
|
||||
# pad the output channels to simd group size
|
||||
# so that our `simd_sum`s don't overlap.
|
||||
simdgroup_size = 32
|
||||
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
|
||||
grid_size = B * gN * gM * C_padded
|
||||
outputs = kernel(
|
||||
inputs=[x, grid, cotangent],
|
||||
template=[("T", x.dtype)],
|
||||
output_shapes=[x.shape, grid.shape],
|
||||
output_dtypes=[x.dtype, x.dtype],
|
||||
grid=(grid_size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
init_value=0,
|
||||
)
|
||||
return outputs[0], outputs[1]
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
|
||||
# pad the output channels to simd group size
|
||||
# so that our `simd_sum`s don't overlap.
|
||||
simdgroup_size = 32
|
||||
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
|
||||
grid_size = B * gN * gM * C_padded
|
||||
outputs = kernel(
|
||||
inputs=[x, grid, cotangent],
|
||||
template=[("T", x.dtype)],
|
||||
output_shapes=[x.shape, grid.shape],
|
||||
output_dtypes=[x.dtype, x.dtype],
|
||||
grid=(grid_size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
init_value=0,
|
||||
)
|
||||
return outputs[0], outputs[1]
|
||||
|
||||
There's an even larger speed up for the vjp:
|
||||
|
||||
|
@ -397,11 +397,11 @@ below.
|
||||
std::ostringstream kname;
|
||||
kname << "axpby_" << "general_" << type_to_name(out);
|
||||
|
||||
// Make sure the metal library is available
|
||||
d.register_library("mlx_ext");
|
||||
// Load the metal library
|
||||
auto lib = d.get_library("mlx_ext");
|
||||
|
||||
// Make a kernel from this metal library
|
||||
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||
auto kernel = d.get_kernel(kname.str(), lib);
|
||||
|
||||
// Prepare to encode kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
@ -30,6 +30,16 @@ MLX is also available on conda-forge. To install MLX with conda do:
|
||||
|
||||
conda install conda-forge::mlx
|
||||
|
||||
CUDA
|
||||
^^^^
|
||||
|
||||
MLX has a CUDA backend which you can use on any Linux platform with CUDA 12
|
||||
and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install mlx-cuda
|
||||
|
||||
|
||||
Troubleshooting
|
||||
^^^^^^^^^^^^^^^
|
||||
@ -65,6 +75,8 @@ Build Requirements
|
||||
Python API
|
||||
^^^^^^^^^^
|
||||
|
||||
.. _python install:
|
||||
|
||||
To build and install the MLX python library from source, first, clone MLX from
|
||||
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
|
||||
|
||||
@ -107,6 +119,8 @@ IDE:
|
||||
C++ API
|
||||
^^^^^^^
|
||||
|
||||
.. _cpp install:
|
||||
|
||||
Currently, MLX must be built and installed from source.
|
||||
|
||||
Similarly to the python library, to build and install the MLX C++ library start
|
||||
@ -185,6 +199,7 @@ should point to the path to the built metal library.
|
||||
|
||||
xcrun -sdk macosx --show-sdk-version
|
||||
|
||||
|
||||
Binary Size Minimization
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@ -213,6 +228,50 @@ be anwywhere from a few hundred millisecond to a few seconds depending on the
|
||||
application. Once a kernel is compiled, it will be cached by the system. The
|
||||
Metal kernel cache persists across reboots.
|
||||
|
||||
Linux
|
||||
^^^^^
|
||||
|
||||
To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
|
||||
For example on Ubuntu, run the following:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
apt-get update -y
|
||||
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
||||
|
||||
From here follow the instructions to install either the :ref:`Python <python
|
||||
install>` or :ref:`C++ <cpp install>` APIs.
|
||||
|
||||
CUDA
|
||||
^^^^
|
||||
|
||||
To build from source on Linux with CUDA, install the BLAS and LAPACK headers
|
||||
and the CUDA toolkit. For example on Ubuntu, run the following:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
|
||||
dpkg -i cuda-keyring_1.1-1_all.deb
|
||||
apt-get update -y
|
||||
apt-get -y install cuda-toolkit-12-9
|
||||
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
||||
|
||||
|
||||
When building either the Python or C++ APIs make sure to pass the cmake flag
|
||||
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=8 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
|
||||
|
||||
To build the C++ package run:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
mkdir -p build && cd build
|
||||
cmake .. -DMLX_BUILD_CUDA=ON && make -j
|
||||
|
||||
|
||||
Troubleshooting
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
|
@ -107,6 +107,16 @@ same array:
|
||||
>>> a
|
||||
array([1, 2, 0], dtype=int32)
|
||||
|
||||
|
||||
Note, unlike NumPy, updates to the same location are nondeterministic:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> a = mx.array([1, 2, 3])
|
||||
>>> a[[0, 0]] = mx.array([4, 5])
|
||||
|
||||
The first element of ``a`` could be ``4`` or ``5``.
|
||||
|
||||
Transformations of functions which use in-place updates are allowed and work as
|
||||
expected. For example:
|
||||
|
||||
|
@ -172,11 +172,11 @@ void Axpby::eval_gpu(
|
||||
kname << (contiguous_kernel ? "contiguous_" : "general_");
|
||||
kname << type_to_name(out);
|
||||
|
||||
// Make sure the metal library is available
|
||||
d.register_library("mlx_ext");
|
||||
// Load the metal library
|
||||
auto lib = d.get_library("mlx_ext");
|
||||
|
||||
// Make a kernel from this metal library
|
||||
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||
auto kernel = d.get_kernel(kname.str(), lib);
|
||||
|
||||
// Prepare to encode kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
@ -55,6 +55,9 @@ endif()
|
||||
|
||||
if(MLX_BUILD_CUDA)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
|
||||
else()
|
||||
target_sources(mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
|
||||
|
@ -1,8 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@ -79,55 +78,6 @@ std::string get_type_string(Dtype d) {
|
||||
}
|
||||
}
|
||||
|
||||
std::string build_lib_name(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::unordered_set<uintptr_t>& constant_ids) {
|
||||
NodeNamer namer;
|
||||
std::ostringstream os;
|
||||
std::ostringstream constant_hasher;
|
||||
|
||||
// Fill the input names. This is not really necessary, I just like having A,
|
||||
// B, C, ... as the inputs.
|
||||
for (auto& x : inputs) {
|
||||
namer.get_name(x);
|
||||
}
|
||||
|
||||
// The primitives describing the tape. For unary and binary primitives this
|
||||
// must be enough to describe the full computation.
|
||||
for (auto& a : tape) {
|
||||
// name and type of output
|
||||
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
|
||||
// computation performed
|
||||
a.primitive().print(os);
|
||||
// name of inputs to the function
|
||||
for (auto& inp : a.inputs()) {
|
||||
os << namer.get_name(inp);
|
||||
}
|
||||
}
|
||||
os << "_";
|
||||
|
||||
for (auto& x : inputs) {
|
||||
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
||||
os << "C";
|
||||
print_constant(constant_hasher, x);
|
||||
} else {
|
||||
os << (is_scalar(x) ? "S" : "V");
|
||||
}
|
||||
}
|
||||
os << "_";
|
||||
for (auto& x : inputs) {
|
||||
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
||||
continue;
|
||||
}
|
||||
os << kindof(x.dtype()) << x.itemsize();
|
||||
}
|
||||
os << "_" << std::hash<std::string>{}(constant_hasher.str());
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
bool compiled_check_contiguity(
|
||||
const std::vector<array>& inputs,
|
||||
const Shape& shape) {
|
||||
@ -159,8 +109,7 @@ bool compiled_check_contiguity(
|
||||
void compiled_allocate_outputs(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::vector<array>& inputs_,
|
||||
const std::unordered_set<uintptr_t>& constant_ids_,
|
||||
const std::function<bool(size_t)>& is_constant,
|
||||
bool contiguous) {
|
||||
if (contiguous) {
|
||||
int o = 0;
|
||||
@ -175,8 +124,7 @@ void compiled_allocate_outputs(
|
||||
// - Donatable
|
||||
// - Not a constant
|
||||
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
|
||||
in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
in.is_donatable() && is_constant(i)) {
|
||||
outputs[o++].copy_shared_buffer(in);
|
||||
}
|
||||
// Get representative input flags to properly set non-donated outputs
|
||||
@ -204,7 +152,7 @@ void compiled_allocate_outputs(
|
||||
// - Not a constant
|
||||
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
|
||||
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
is_constant(i)) {
|
||||
outputs[o].copy_shared_buffer(
|
||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||
o++;
|
||||
@ -216,4 +164,74 @@ void compiled_allocate_outputs(
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
|
||||
const std::vector<array>& inputs,
|
||||
const array& out,
|
||||
const std::function<bool(size_t)>& is_constant) {
|
||||
const Shape& shape = out.shape();
|
||||
bool contiguous = compiled_check_contiguity(inputs, shape);
|
||||
if (contiguous) {
|
||||
return {true, shape, {}};
|
||||
}
|
||||
|
||||
std::vector<Strides> strides_vec{out.strides()};
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
// Skip constants.
|
||||
if (is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip scalar inputs.
|
||||
const auto& x = inputs[i];
|
||||
if (is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Broadcast the inputs to the output shape.
|
||||
Strides xstrides;
|
||||
size_t j = 0;
|
||||
for (; j < shape.size() - x.ndim(); ++j) {
|
||||
if (shape[j] == 1) {
|
||||
xstrides.push_back(out.strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < x.ndim(); ++i, ++j) {
|
||||
if (x.shape(i) == 1) {
|
||||
if (shape[j] == 1) {
|
||||
xstrides.push_back(out.strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
} else {
|
||||
xstrides.push_back(x.strides()[i]);
|
||||
}
|
||||
}
|
||||
strides_vec.push_back(std::move(xstrides));
|
||||
}
|
||||
|
||||
auto tup = collapse_contiguous_dims(shape, strides_vec, INT32_MAX);
|
||||
return {false, std::move(std::get<0>(tup)), std::move(std::get<1>(tup))};
|
||||
}
|
||||
|
||||
bool compiled_use_large_index(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
bool contiguous) {
|
||||
if (contiguous) {
|
||||
size_t max_size = 0;
|
||||
for (const auto& in : inputs) {
|
||||
max_size = std::max(max_size, in.data_size());
|
||||
}
|
||||
return max_size > UINT32_MAX;
|
||||
} else {
|
||||
size_t max_size = 0;
|
||||
for (const auto& o : outputs) {
|
||||
max_size = std::max(max_size, o.size());
|
||||
}
|
||||
return max_size > UINT32_MAX;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -1,9 +1,8 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <iomanip>
|
||||
#include <sstream>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/primitives.h"
|
||||
@ -14,12 +13,6 @@ inline bool is_static_cast(const Primitive& p) {
|
||||
return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
|
||||
}
|
||||
|
||||
std::string build_lib_name(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::unordered_set<uintptr_t>& constant_ids);
|
||||
|
||||
std::string get_type_string(Dtype d);
|
||||
|
||||
template <typename T>
|
||||
@ -60,8 +53,19 @@ bool compiled_check_contiguity(
|
||||
void compiled_allocate_outputs(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::vector<array>& inputs_,
|
||||
const std::unordered_set<uintptr_t>& constant_ids_,
|
||||
const std::function<bool(size_t)>& is_constant,
|
||||
bool contiguous);
|
||||
|
||||
// Collapse contiguous dims ignoring scalars and constants.
|
||||
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
|
||||
const std::vector<array>& inputs,
|
||||
const array& out,
|
||||
const std::function<bool(size_t)>& is_constant);
|
||||
|
||||
// Return whether the kernel should use large index.
|
||||
bool compiled_use_large_index(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
bool contiguous);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@ -26,7 +26,7 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
|
||||
if (ctype == CopyType::Vector) {
|
||||
// If the input is donateable, we are doing a vector copy and the types
|
||||
// have the same size, then the input buffer can hold the output.
|
||||
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
|
||||
if (is_donatable(in, out)) {
|
||||
out.copy_shared_buffer(in);
|
||||
return true;
|
||||
} else {
|
||||
|
78
mlx/backend/common/matmul.h
Normal file
78
mlx/backend/common/matmul.h
Normal file
@ -0,0 +1,78 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
inline std::tuple<Shape, Strides, Strides> collapse_batches(
|
||||
const array& a,
|
||||
const array& b) {
|
||||
// Get and check the shape for the batched dims
|
||||
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
||||
Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
|
||||
if (A_bshape != B_bshape) {
|
||||
std::ostringstream msg;
|
||||
msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A "
|
||||
<< a.shape() << ", B " << b.shape() << ".";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
|
||||
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
|
||||
|
||||
auto [batch_shape, batch_strides] =
|
||||
collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
|
||||
|
||||
auto a_batch_strides = batch_strides[0];
|
||||
auto b_batch_strides = batch_strides[1];
|
||||
|
||||
if (batch_shape.empty()) {
|
||||
batch_shape.push_back(1);
|
||||
a_batch_strides.push_back(0);
|
||||
b_batch_strides.push_back(0);
|
||||
}
|
||||
|
||||
return std::make_tuple(batch_shape, a_batch_strides, b_batch_strides);
|
||||
}
|
||||
|
||||
inline std::tuple<Shape, Strides, Strides, Strides>
|
||||
collapse_batches(const array& a, const array& b, const array& c) {
|
||||
// Get and check the shape for the batched dims
|
||||
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
||||
Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
|
||||
Shape C_bshape{c.shape().begin(), c.shape().end() - 2};
|
||||
if (A_bshape != B_bshape || A_bshape != C_bshape) {
|
||||
std::ostringstream msg;
|
||||
msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A "
|
||||
<< a.shape() << ", B " << b.shape() << ", B " << c.shape() << ".";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
|
||||
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
|
||||
Strides C_bstride{c.strides().begin(), c.strides().end() - 2};
|
||||
|
||||
auto [batch_shape, batch_strides] = collapse_contiguous_dims(
|
||||
A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
|
||||
|
||||
auto A_batch_stride = batch_strides[0];
|
||||
auto B_batch_stride = batch_strides[1];
|
||||
auto C_batch_stride = batch_strides[2];
|
||||
|
||||
if (batch_shape.empty()) {
|
||||
batch_shape.push_back(1);
|
||||
A_batch_stride.push_back(0);
|
||||
B_batch_stride.push_back(0);
|
||||
C_batch_stride.push_back(0);
|
||||
}
|
||||
|
||||
return std::make_tuple(
|
||||
batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
26
mlx/backend/common/unary.h
Normal file
26
mlx/backend/common/unary.h
Normal file
@ -0,0 +1,26 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
inline void set_unary_output_data(const array& in, array& out) {
|
||||
if (in.flags().contiguous) {
|
||||
if (is_donatable(in, out)) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -209,4 +209,14 @@ Dims get_2d_grid_dims_common(
|
||||
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
|
||||
}
|
||||
|
||||
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
|
||||
auto [bx, by, bz] = get_block_dims_common(dim0, dim1, dim2);
|
||||
auto gx = (dim0 + bx - 1) / bx;
|
||||
auto gy = (dim1 + by - 1) / by;
|
||||
auto gz = (dim2 + bz - 1) / bz;
|
||||
|
||||
return std::make_pair(
|
||||
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -95,6 +95,9 @@ Dims get_2d_grid_dims_common(
|
||||
const Strides& strides,
|
||||
size_t divisor);
|
||||
|
||||
// Get both the block and a grid of blocks that covers dim0, dim1 and dim2.
|
||||
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2);
|
||||
|
||||
struct ContiguousIterator {
|
||||
inline void step() {
|
||||
int dims = shape_.size();
|
||||
|
@ -146,18 +146,9 @@ inline void build_kernel(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::unordered_set<uintptr_t>& constant_ids,
|
||||
const std::function<bool(size_t)>& is_constant,
|
||||
bool contiguous,
|
||||
int ndim) {
|
||||
// All outputs should have the exact same shape and will be row contiguous
|
||||
auto output_shape = outputs[0].shape();
|
||||
auto output_strides = outputs[0].strides();
|
||||
|
||||
// Constants are scalars that are captured by value and cannot change
|
||||
auto is_constant = [&constant_ids](const array& x) {
|
||||
return constant_ids.find(x.id()) != constant_ids.end();
|
||||
};
|
||||
|
||||
NodeNamer namer;
|
||||
|
||||
#ifdef _MSC_VER
|
||||
@ -170,14 +161,15 @@ inline void build_kernel(
|
||||
|
||||
// Add the input arguments
|
||||
int cnt = 0;
|
||||
for (auto& x : inputs) {
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
// Skip constants from the input list
|
||||
if (is_constant(x)) {
|
||||
if (is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto& x = inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
auto tstr = get_type_string(x.dtype());
|
||||
os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
|
||||
<< "];" << std::endl;
|
||||
@ -211,10 +203,11 @@ inline void build_kernel(
|
||||
}
|
||||
|
||||
// Read the inputs in tmps
|
||||
for (auto& x : inputs) {
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto& x = inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
if (is_constant(x)) {
|
||||
if (is_constant(i)) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
|
||||
print_constant(os, x);
|
||||
os << ";" << std::endl;
|
||||
@ -264,8 +257,9 @@ inline void build_kernel(
|
||||
} else {
|
||||
for (int d = ndim - 1; d >= 0; --d) {
|
||||
// Update pointers
|
||||
for (auto& x : inputs) {
|
||||
if (is_constant(x) || is_scalar(x)) {
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto& x = inputs[i];
|
||||
if (is_constant(i) || is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
auto& xname = namer.get_name(x);
|
||||
@ -287,65 +281,37 @@ inline void build_kernel(
|
||||
void Compiled::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
if (kernel_lib_.empty()) {
|
||||
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
|
||||
}
|
||||
|
||||
// Figure out which kernel we are using
|
||||
auto& shape = outputs[0].shape();
|
||||
auto contiguous = compiled_check_contiguity(inputs, shape);
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
|
||||
// Handle all broadcasting and collect function input arguments
|
||||
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
||||
// handle all broadcasting.
|
||||
auto [contiguous, shape, strides] =
|
||||
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
|
||||
|
||||
// Collect function input arguments.
|
||||
std::vector<void*> args;
|
||||
std::vector<std::vector<size_t>> strides;
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
// Skip constants.
|
||||
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
|
||||
int strides_index = 1;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
if (is_constant_(i)) {
|
||||
continue;
|
||||
}
|
||||
auto& x = inputs[i];
|
||||
const auto& x = inputs[i];
|
||||
encoder.set_input_array(x);
|
||||
args.push_back((void*)x.data<void>());
|
||||
|
||||
if (contiguous || is_scalar(x)) {
|
||||
continue;
|
||||
if (!contiguous && !is_scalar(x)) {
|
||||
args.push_back(strides[strides_index++].data());
|
||||
}
|
||||
|
||||
// Broadcast the input to the output shape.
|
||||
std::vector<size_t> xstrides;
|
||||
int j = 0;
|
||||
for (; j < shape.size() - x.ndim(); j++) {
|
||||
if (shape[j] == 1) {
|
||||
xstrides.push_back(outputs[0].strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < x.ndim(); i++, j++) {
|
||||
if (x.shape(i) == 1) {
|
||||
if (shape[j] == 1) {
|
||||
xstrides.push_back(outputs[0].strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
} else {
|
||||
xstrides.push_back(x.strides()[i]);
|
||||
}
|
||||
}
|
||||
strides.push_back(std::move(xstrides));
|
||||
args.push_back(strides.back().data());
|
||||
}
|
||||
|
||||
// Get the kernel name from the lib
|
||||
int ndim = shape.size();
|
||||
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
|
||||
if (!contiguous) {
|
||||
kernel_name += std::to_string(shape.size());
|
||||
kernel_name += std::to_string(ndim);
|
||||
}
|
||||
|
||||
// Get the function
|
||||
auto fn_ptr = compile(kernel_name, [&]() {
|
||||
auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() {
|
||||
std::ostringstream kernel;
|
||||
kernel << get_kernel_preamble() << std::endl;
|
||||
kernel << "extern \"C\" {" << std::endl;
|
||||
@ -355,7 +321,7 @@ void Compiled::eval_cpu(
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
is_constant_,
|
||||
contiguous,
|
||||
ndim);
|
||||
// Close extern "C"
|
||||
@ -363,26 +329,22 @@ void Compiled::eval_cpu(
|
||||
return kernel.str();
|
||||
});
|
||||
|
||||
compiled_allocate_outputs(
|
||||
inputs, outputs, inputs_, constant_ids_, contiguous);
|
||||
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
|
||||
|
||||
for (auto& x : outputs) {
|
||||
args.push_back(x.data<void>());
|
||||
encoder.set_output_array(x);
|
||||
}
|
||||
Shape out_shape;
|
||||
if (!contiguous) {
|
||||
out_shape = outputs[0].shape();
|
||||
args.push_back((void*)out_shape.data());
|
||||
args.push_back((void*)shape.data());
|
||||
} else {
|
||||
args.push_back((void*)outputs[0].data_size());
|
||||
}
|
||||
auto fun = (void (*)(void**))fn_ptr;
|
||||
encoder.dispatch(
|
||||
[fun,
|
||||
args = std::move(args),
|
||||
strides = std::move(strides),
|
||||
out_shape = std::move(out_shape)]() mutable { fun(args.data()); });
|
||||
encoder.dispatch([fun,
|
||||
args = std::move(args),
|
||||
strides = std::move(strides),
|
||||
shape = std::move(shape)]() mutable { fun(args.data()); });
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -2,32 +2,13 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void set_unary_output_data(const array& in, array& out) {
|
||||
if (in.flags().contiguous) {
|
||||
if (is_donatable(in, out)) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
auto size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U = T, typename Op>
|
||||
void unary_op(const T* a, U* out, size_t shape, size_t stride) {
|
||||
for (size_t i = 0; i < shape; i += 1) {
|
||||
|
@ -1,26 +1,84 @@
|
||||
# Filename rules in cuda backend:
|
||||
#
|
||||
# * Use .cu/.cuh if code contains device code, and .cpp/.h if not.
|
||||
# * Device-only kernel code should be put in kernels/ subdir.
|
||||
# * Files in kernels/ subdir should not include files outside.
|
||||
# * Device-only code should be put in device/ subdir.
|
||||
# * Files in device/ subdir should not include files outside.
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
||||
|
||||
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
|
||||
|
||||
# Embed kernel sources in binary for JIT compilation.
|
||||
file(
|
||||
GLOB MLX_JIT_SOURCES
|
||||
RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/device/*.h"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/device/*.cuh")
|
||||
string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES})
|
||||
add_custom_command(
|
||||
OUTPUT gen/cuda_jit_sources.h
|
||||
COMMAND
|
||||
${CMAKE_COMMAND} -DMLX_SOURCE_ROOT=${CMAKE_CURRENT_SOURCE_DIR}
|
||||
-DMLX_JIT_SOURCES=${MLX_JIT_SOURCES_ARG} -P
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/bin2h.cmake"
|
||||
DEPENDS bin2h.cmake ${MLX_JIT_SOURCES})
|
||||
add_custom_target(cuda_jit_sources DEPENDS gen/cuda_jit_sources.h)
|
||||
add_dependencies(mlx cuda_jit_sources)
|
||||
target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
|
||||
|
||||
# Enable defining device lambda functions.
|
||||
target_compile_options(mlx
|
||||
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
|
||||
|
||||
# CUDA 12.8 emits warning #20280-D for copy kernels which is a false positive.
|
||||
# Explicitly pass this flag to suppress the warning, it is safe to set it to
|
||||
# true but the warning wouldn't be suppressed.
|
||||
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
|
||||
target_compile_options(
|
||||
mlx
|
||||
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--static-global-template-stub=false>")
|
||||
endif()
|
||||
|
||||
# Suppress warning when building for compute capability 7 used by V100.
|
||||
target_compile_options(
|
||||
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--Wno-deprecated-gpu-targets>")
|
||||
|
||||
# Compute capability 7 is required for synchronization between CPU/GPU with
|
||||
# managed memory. TODO: Add more architectures for potential performance gain.
|
||||
set(MLX_CUDA_ARCHITECTURES
|
||||
@ -51,6 +109,12 @@ target_link_libraries(mlx PUBLIC $<BUILD_INTERFACE:nvtx3-cpp>)
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
|
||||
|
||||
# Use cublasLt.
|
||||
target_link_libraries(mlx PRIVATE CUDA::cublasLt)
|
||||
|
||||
# Use NVRTC and driver APIs.
|
||||
target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
|
||||
|
||||
# Suppress nvcc warnings on MLX headers.
|
||||
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
||||
--diag_suppress=997>)
|
||||
|
@ -3,6 +3,7 @@
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <fmt/format.h>
|
||||
@ -14,11 +15,16 @@ namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
constexpr int page_size = 16384;
|
||||
|
||||
CudaAllocator::CudaAllocator()
|
||||
: buffer_cache_(
|
||||
getpagesize(),
|
||||
page_size,
|
||||
[](CudaBuffer* buf) { return buf->size; },
|
||||
[this](CudaBuffer* buf) { cuda_free(buf); }) {
|
||||
[this](CudaBuffer* buf) {
|
||||
cuda_free(buf->data);
|
||||
delete buf;
|
||||
}) {
|
||||
// TODO: Set memory limit for multi-device.
|
||||
size_t free, total;
|
||||
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
||||
@ -28,7 +34,14 @@ CudaAllocator::CudaAllocator()
|
||||
|
||||
Buffer CudaAllocator::malloc(size_t size) {
|
||||
// Find available buffer from cache.
|
||||
auto orig_size = size;
|
||||
std::unique_lock lock(mutex_);
|
||||
if (size < page_size) {
|
||||
size = next_power_of_2(size);
|
||||
} else {
|
||||
size = page_size * ((size + page_size - 1) / page_size);
|
||||
}
|
||||
|
||||
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
|
||||
if (!buf) {
|
||||
// If we have a lot of memory pressure or are over the maximum cache size,
|
||||
@ -70,7 +83,8 @@ void CudaAllocator::free(Buffer buffer) {
|
||||
buffer_cache_.recycle_to_cache(buf);
|
||||
} else {
|
||||
lock.unlock();
|
||||
cuda_free(buf);
|
||||
cuda_free(buf->data);
|
||||
delete buf;
|
||||
}
|
||||
}
|
||||
|
||||
@ -87,6 +101,24 @@ void CudaAllocator::register_this_thread() {
|
||||
allowed_threads_.insert(std::this_thread::get_id());
|
||||
}
|
||||
|
||||
void CudaAllocator::cuda_free(void* buf) {
|
||||
// If cuda_free() is called from a unregistered thread, reschedule the call to
|
||||
// worker.
|
||||
{
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
|
||||
if (!worker_) {
|
||||
worker_.reset(new Worker);
|
||||
}
|
||||
worker_->add_task([this, buf]() { this->cuda_free(buf); });
|
||||
worker_->end_batch();
|
||||
worker_->commit();
|
||||
return;
|
||||
}
|
||||
}
|
||||
cudaFree(buf);
|
||||
}
|
||||
|
||||
size_t CudaAllocator::get_active_memory() const {
|
||||
return active_memory_;
|
||||
}
|
||||
@ -125,26 +157,6 @@ void CudaAllocator::clear_cache() {
|
||||
buffer_cache_.clear();
|
||||
}
|
||||
|
||||
void CudaAllocator::cuda_free(CudaBuffer* buf) {
|
||||
// If cuda_free() is called from a unregistered thread, reschedule the call to
|
||||
// worker.
|
||||
{
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
|
||||
if (!worker_) {
|
||||
worker_.reset(new Worker);
|
||||
}
|
||||
worker_->add_task([this, buf]() { this->cuda_free(buf); });
|
||||
worker_->end_batch();
|
||||
worker_->commit();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
cudaFree(buf->data);
|
||||
delete buf;
|
||||
}
|
||||
|
||||
CudaAllocator& allocator() {
|
||||
// By creating the |allocator_| on heap, the destructor of CudaAllocator
|
||||
// will not be called on exit and buffers in the cache will be leaked. This
|
||||
|
@ -34,6 +34,9 @@ class CudaAllocator : public allocator::Allocator {
|
||||
// buffers there would result in dead lock.
|
||||
void register_this_thread();
|
||||
|
||||
// Call cudaFree in the safe thread.
|
||||
void cuda_free(void* buf);
|
||||
|
||||
size_t get_active_memory() const;
|
||||
size_t get_peak_memory() const;
|
||||
void reset_peak_memory();
|
||||
@ -47,8 +50,6 @@ class CudaAllocator : public allocator::Allocator {
|
||||
CudaAllocator();
|
||||
friend CudaAllocator& allocator();
|
||||
|
||||
void cuda_free(CudaBuffer* buf);
|
||||
|
||||
std::mutex worker_mutex_;
|
||||
std::unique_ptr<Worker> worker_;
|
||||
std::set<std::thread::id> allowed_threads_;
|
||||
|
188
mlx/backend/cuda/arg_reduce.cu
Normal file
188
mlx/backend/cuda/arg_reduce.cu
Normal file
@ -0,0 +1,188 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename T>
|
||||
struct IndexValPair {
|
||||
uint32_t index;
|
||||
T val;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ArgMin {
|
||||
constexpr __device__ T init() {
|
||||
return Limits<T>::max();
|
||||
}
|
||||
|
||||
__device__ IndexValPair<T> operator()(
|
||||
const IndexValPair<T>& best,
|
||||
const IndexValPair<T>& current) {
|
||||
if (best.val > current.val ||
|
||||
(best.val == current.val && best.index > current.index)) {
|
||||
return current;
|
||||
} else {
|
||||
return best;
|
||||
}
|
||||
}
|
||||
|
||||
template <int N>
|
||||
__device__ IndexValPair<T>
|
||||
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
|
||||
for (int i = 0; i < N; i++) {
|
||||
if (vals[i] < best.val) {
|
||||
best.val = vals[i];
|
||||
best.index = offset + i;
|
||||
}
|
||||
}
|
||||
return best;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ArgMax {
|
||||
constexpr __device__ T init() {
|
||||
return Limits<T>::min();
|
||||
}
|
||||
|
||||
__device__ IndexValPair<T> operator()(
|
||||
const IndexValPair<T>& best,
|
||||
const IndexValPair<T>& current) {
|
||||
if (best.val < current.val ||
|
||||
(best.val == current.val && best.index > current.index)) {
|
||||
return current;
|
||||
} else {
|
||||
return best;
|
||||
}
|
||||
}
|
||||
|
||||
template <int N>
|
||||
__device__ IndexValPair<T>
|
||||
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
|
||||
for (int i = 0; i < N; i++) {
|
||||
if (vals[i] > best.val) {
|
||||
best.val = vals[i];
|
||||
best.index = offset + i;
|
||||
}
|
||||
}
|
||||
return best;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Op, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void arg_reduce_general(
|
||||
const T* in,
|
||||
uint32_t* out,
|
||||
size_t size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides in_strides,
|
||||
const __grid_constant__ Strides out_strides,
|
||||
int32_t ndim,
|
||||
int64_t axis_stride,
|
||||
int32_t axis_size) {
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
int64_t index = cg::this_grid().block_rank();
|
||||
if (index >= size) {
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
|
||||
int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
|
||||
|
||||
Op op;
|
||||
T init = op.init();
|
||||
IndexValPair<T> best{0, init};
|
||||
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
T vals[N_READS];
|
||||
auto tid = r * BLOCK_DIM + block.thread_index().x;
|
||||
cub::LoadDirectBlocked(
|
||||
tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
|
||||
best = op.reduce_many(best, vals, tid * N_READS);
|
||||
}
|
||||
|
||||
typedef cub::BlockReduce<IndexValPair<T>, BLOCK_DIM> BlockReduceT;
|
||||
__shared__ typename BlockReduceT::TempStorage temp;
|
||||
|
||||
best = BlockReduceT(temp).Reduce(best, op);
|
||||
|
||||
if (block.thread_rank() == 0) {
|
||||
out[out_idx] = best.index;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("ArgReduce::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& s = stream();
|
||||
|
||||
// Prepare the shapes, strides and axis arguments.
|
||||
Shape shape = remove_index(in.shape(), axis_);
|
||||
Strides in_strides = remove_index(in.strides(), axis_);
|
||||
Strides out_strides = out.ndim() == in.ndim()
|
||||
? remove_index(out.strides(), axis_)
|
||||
: out.strides();
|
||||
int64_t axis_stride = in.strides()[axis_];
|
||||
int32_t axis_size = in.shape()[axis_];
|
||||
int32_t ndim = shape.size();
|
||||
|
||||
// ArgReduce.
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_REAL_TYPES_CHECKED(in.dtype(), "ArgReduce", CTYPE, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
constexpr uint32_t N_READS = 4;
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
|
||||
dim3 block_dims{BLOCK_DIM, 1, 1};
|
||||
auto kernel = &cu::arg_reduce_general<
|
||||
InType,
|
||||
cu::ArgMax<InType>,
|
||||
BLOCK_DIM,
|
||||
N_READS>;
|
||||
if (reduce_type_ == ArgReduce::ArgMin) {
|
||||
kernel = &cu::arg_reduce_general<
|
||||
InType,
|
||||
cu::ArgMin<InType>,
|
||||
BLOCK_DIM,
|
||||
N_READS>;
|
||||
}
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in.data<InType>(),
|
||||
out.data<uint32_t>(),
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(in_strides),
|
||||
const_param(out_strides),
|
||||
ndim,
|
||||
axis_stride,
|
||||
axis_size);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
150
mlx/backend/cuda/bin2h.cmake
Normal file
150
mlx/backend/cuda/bin2h.cmake
Normal file
@ -0,0 +1,150 @@
|
||||
# Based on: https://github.com/sivachandran/cmake-bin2h
|
||||
#
|
||||
# Copyright 2020 Sivachandran Paramasivam
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
include(CMakeParseArguments)
|
||||
|
||||
# Function to wrap a given string into multiple lines at the given column
|
||||
# position.
|
||||
#
|
||||
# Parameters:
|
||||
#
|
||||
# * VARIABLE - The name of the CMake variable holding the string.
|
||||
# * AT_COLUMN - The column position at which string will be wrapped.
|
||||
function(WRAP_STRING)
|
||||
set(oneValueArgs VARIABLE AT_COLUMN)
|
||||
cmake_parse_arguments(WRAP_STRING "${options}" "${oneValueArgs}" "" ${ARGN})
|
||||
|
||||
string(LENGTH ${${WRAP_STRING_VARIABLE}} stringLength)
|
||||
math(EXPR offset "0")
|
||||
|
||||
while(stringLength GREATER 0)
|
||||
if(stringLength GREATER ${WRAP_STRING_AT_COLUMN})
|
||||
math(EXPR length "${WRAP_STRING_AT_COLUMN}")
|
||||
else()
|
||||
math(EXPR length "${stringLength}")
|
||||
endif()
|
||||
|
||||
string(SUBSTRING ${${WRAP_STRING_VARIABLE}} ${offset} ${length} line)
|
||||
set(lines "${lines}\n ${line}")
|
||||
|
||||
math(EXPR stringLength "${stringLength} - ${length}")
|
||||
math(EXPR offset "${offset} + ${length}")
|
||||
endwhile()
|
||||
|
||||
set(${WRAP_STRING_VARIABLE}
|
||||
"${lines}"
|
||||
PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
# Function to embed contents of a file as byte array in C/C++ header file(.h).
|
||||
# The header file will contain a byte array and integer variable holding the
|
||||
# size of the array.
|
||||
#
|
||||
# Parameters:
|
||||
#
|
||||
# * SOURCE_FILES - The paths of source files whose contents will be embedded in
|
||||
# the header file.
|
||||
# * VARIABLE_NAME - The name of the variable for the byte array. The string
|
||||
# "_SIZE" will be append to this name and will be used a variable name for
|
||||
# size variable.
|
||||
# * HEADER_FILE - The path of header file.
|
||||
# * APPEND - If specified appends to the header file instead of overwriting it
|
||||
# * HEADER_NAMESPACE - The namespace, where the array should be located in.
|
||||
# * NULL_TERMINATE - If specified a null byte(zero) will be append to the byte
|
||||
# array.
|
||||
#
|
||||
# Usage:
|
||||
#
|
||||
# bin2h(SOURCE_FILE "Logo.png" HEADER_FILE "Logo.h" VARIABLE_NAME "LOGO_PNG")
|
||||
function(BIN2H)
|
||||
set(options APPEND NULL_TERMINATE)
|
||||
set(oneValueArgs VARIABLE_NAME HEADER_FILE HEADER_NAMESPACE)
|
||||
set(multiValueArgs SOURCE_FILES)
|
||||
cmake_parse_arguments(BIN2H "${options}" "${oneValueArgs}"
|
||||
"${multiValueArgs}" ${ARGN})
|
||||
|
||||
set(arrayDefinition "")
|
||||
foreach(SOURCE_FILE IN LISTS BIN2H_SOURCE_FILES)
|
||||
# get filename without extension
|
||||
get_filename_component(FILE_NAME_WE ${SOURCE_FILE} NAME_WE)
|
||||
# convert the filename to a valid C identifier
|
||||
string(MAKE_C_IDENTIFIER "${FILE_NAME_WE}" VALID_FILE_NAME)
|
||||
|
||||
# reads source file contents as hex string
|
||||
file(READ ${SOURCE_FILE} hexString HEX)
|
||||
|
||||
# append null
|
||||
if(BIN2H_NULL_TERMINATE)
|
||||
string(APPEND hexString "00")
|
||||
endif()
|
||||
|
||||
# wraps the hex string into multiple lines
|
||||
wrap_string(VARIABLE hexString AT_COLUMN 24)
|
||||
|
||||
# strip the © in source code
|
||||
string(REGEX REPLACE "c2a9" "2020" arrayValues ${hexString})
|
||||
|
||||
string(REGEX REPLACE "([0-9a-f][0-9a-f])" " 0x\\1," arrayValues
|
||||
${arrayValues})
|
||||
|
||||
# make a full variable name for the array
|
||||
set(FULL_VARIABLE_NAME "${BIN2H_VARIABLE_NAME}_${VALID_FILE_NAME}")
|
||||
|
||||
# declares byte array and the length variables
|
||||
string(APPEND arrayDefinition
|
||||
"constexpr char ${FULL_VARIABLE_NAME}[] = {${arrayValues}\n};\n\n")
|
||||
endforeach()
|
||||
|
||||
# add namespace wrapper if defined
|
||||
if(DEFINED BIN2H_HEADER_NAMESPACE)
|
||||
set(namespaceStart "namespace ${BIN2H_HEADER_NAMESPACE} {")
|
||||
set(namespaceEnd "} // namespace ${BIN2H_HEADER_NAMESPACE}")
|
||||
set(declarations "${namespaceStart}\n\n${arrayDefinition}${namespaceEnd}\n")
|
||||
endif()
|
||||
|
||||
set(arrayIncludes "#pragma once")
|
||||
string(PREPEND declarations "${arrayIncludes}\n\n")
|
||||
|
||||
if(BIN2H_APPEND)
|
||||
file(APPEND ${BIN2H_HEADER_FILE} "${declarations}")
|
||||
else()
|
||||
file(WRITE ${BIN2H_HEADER_FILE} "${declarations}")
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
# ----------------------------- CLI args -----------------------------
|
||||
|
||||
string(REPLACE ":" ";" MLX_JIT_SOURCES_LIST ${MLX_JIT_SOURCES})
|
||||
foreach(source ${MLX_JIT_SOURCES_LIST})
|
||||
list(APPEND MLX_JIT_SOURCES_ABS "${MLX_SOURCE_ROOT}/${source}")
|
||||
endforeach()
|
||||
|
||||
bin2h(
|
||||
SOURCE_FILES
|
||||
${MLX_JIT_SOURCES_ABS}
|
||||
NULL_TERMINATE
|
||||
VARIABLE_NAME
|
||||
"jit_source"
|
||||
HEADER_NAMESPACE
|
||||
"mlx::core"
|
||||
HEADER_FILE
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/gen/cuda_jit_sources.h")
|
292
mlx/backend/cuda/binary.cu
Normal file
292
mlx/backend/cuda/binary.cu
Normal file
@ -0,0 +1,292 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
out[index] = Op{}(a[0], b[0]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
out[index] = Op{}(a[0], b[index]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
out[index] = Op{}(a[index], b[0]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
out[index] = Op{}(a[index], b[index]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
|
||||
__global__ void binary_g_nd(
|
||||
const In* a,
|
||||
const In* b,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
|
||||
index, shape.data(), a_strides.data(), b_strides.data());
|
||||
out[index] = Op{}(a[a_idx], b[b_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void binary_g(
|
||||
const In* a,
|
||||
const In* b,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides a_strides,
|
||||
const __grid_constant__ Strides b_strides,
|
||||
int ndim) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx] = elem_to_loc_4d(
|
||||
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
|
||||
out[index] = Op{}(a[a_idx], b[b_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out>
|
||||
constexpr bool supports_binary_op() {
|
||||
if (std::is_same_v<Op, Add> || std::is_same_v<Op, Divide> ||
|
||||
std::is_same_v<Op, Maximum> || std::is_same_v<Op, Minimum> ||
|
||||
std::is_same_v<Op, Multiply> || std::is_same_v<Op, Subtract> ||
|
||||
std::is_same_v<Op, Power> || std::is_same_v<Op, Remainder>) {
|
||||
return std::is_same_v<In, Out>;
|
||||
}
|
||||
if (std::is_same_v<Op, Equal> || std::is_same_v<Op, Greater> ||
|
||||
std::is_same_v<Op, GreaterEqual> || std::is_same_v<Op, Less> ||
|
||||
std::is_same_v<Op, LessEqual> || std::is_same_v<Op, NotEqual>) {
|
||||
return std::is_same_v<Out, bool>;
|
||||
}
|
||||
if (std::is_same_v<Op, LogicalAnd> || std::is_same_v<Op, LogicalOr>) {
|
||||
return std::is_same_v<Out, bool> && std::is_same_v<In, bool>;
|
||||
}
|
||||
if (std::is_same_v<Op, NaNEqual>) {
|
||||
return std::is_same_v<Out, bool> && is_inexact_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, LogAddExp>) {
|
||||
return std::is_same_v<In, Out> && is_inexact_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, ArcTan2>) {
|
||||
return std::is_same_v<In, Out> && is_floating_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, BitwiseAnd> || std::is_same_v<Op, BitwiseOr> ||
|
||||
std::is_same_v<Op, BitwiseXor>) {
|
||||
return std::is_same_v<In, Out> && std::is_integral_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, LeftShift> || std::is_same_v<Op, RightShift>) {
|
||||
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
|
||||
!std::is_same_v<In, bool>;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
std::string_view op,
|
||||
const Stream& s) {
|
||||
assert(inputs.size() > 1);
|
||||
const auto& a = inputs[0];
|
||||
const auto& b = inputs[1];
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
|
||||
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
|
||||
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||
using InType = cuda_type_t<CTYPE_IN>;
|
||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
if (bopt == BinaryOpType::General) {
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
bool large = a.data_size() > INT32_MAX ||
|
||||
b.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
auto kernel =
|
||||
&cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(a_strides),
|
||||
const_param<NDIM>(b_strides));
|
||||
});
|
||||
} else {
|
||||
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
const_param(b_strides),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
|
||||
if (bopt == BinaryOpType::ScalarVector) {
|
||||
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
|
||||
} else if (bopt == BinaryOpType::VectorScalar) {
|
||||
kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
|
||||
} else if (bopt == BinaryOpType::VectorVector) {
|
||||
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.data_size());
|
||||
});
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Can not do binary op {} on inputs of {} with result of {}.",
|
||||
op,
|
||||
dtype_to_string(a.dtype()),
|
||||
dtype_to_string(out.dtype())));
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
std::string_view op,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
binary_op_gpu_inplace<Op>(inputs, out, op, s);
|
||||
}
|
||||
|
||||
#define BINARY_GPU(func) \
|
||||
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
||||
nvtx3::scoped_range r(#func "::eval_gpu"); \
|
||||
auto& s = out.primitive().stream(); \
|
||||
binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
|
||||
}
|
||||
|
||||
BINARY_GPU(Add)
|
||||
BINARY_GPU(ArcTan2)
|
||||
BINARY_GPU(Divide)
|
||||
BINARY_GPU(Remainder)
|
||||
BINARY_GPU(Greater)
|
||||
BINARY_GPU(GreaterEqual)
|
||||
BINARY_GPU(Less)
|
||||
BINARY_GPU(LessEqual)
|
||||
BINARY_GPU(LogicalAnd)
|
||||
BINARY_GPU(LogicalOr)
|
||||
BINARY_GPU(LogAddExp)
|
||||
BINARY_GPU(Maximum)
|
||||
BINARY_GPU(Minimum)
|
||||
BINARY_GPU(Multiply)
|
||||
BINARY_GPU(NotEqual)
|
||||
BINARY_GPU(Power)
|
||||
BINARY_GPU(Subtract)
|
||||
|
||||
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Equal::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
auto op = get_primitive_string(this);
|
||||
if (equal_nan_) {
|
||||
binary_op_gpu<cu::NaNEqual>(inputs, out, op, s);
|
||||
} else {
|
||||
binary_op_gpu<cu::Equal>(inputs, out, op, s);
|
||||
}
|
||||
}
|
||||
|
||||
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
auto op = get_primitive_string(this);
|
||||
switch (op_) {
|
||||
case BitwiseBinary::And:
|
||||
binary_op_gpu<cu::BitwiseAnd>(inputs, out, op, s);
|
||||
break;
|
||||
case BitwiseBinary::Or:
|
||||
binary_op_gpu<cu::BitwiseOr>(inputs, out, op, s);
|
||||
break;
|
||||
case BitwiseBinary::Xor:
|
||||
binary_op_gpu<cu::BitwiseXor>(inputs, out, op, s);
|
||||
break;
|
||||
case BitwiseBinary::LeftShift:
|
||||
binary_op_gpu<cu::LeftShift>(inputs, out, op, s);
|
||||
break;
|
||||
case BitwiseBinary::RightShift:
|
||||
binary_op_gpu<cu::RightShift>(inputs, out, op, s);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
248
mlx/backend/cuda/binary_two.cu
Normal file
248
mlx/backend/cuda/binary_two.cu
Normal file
@ -0,0 +1,248 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void
|
||||
binary_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto out = Op{}(a[0], b[0]);
|
||||
out_a[0] = out[0];
|
||||
out_b[0] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void
|
||||
binary_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto out = Op{}(a[0], b[index]);
|
||||
out_a[index] = out[0];
|
||||
out_b[index] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void
|
||||
binary_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto out = Op{}(a[index], b[0]);
|
||||
out_a[index] = out[0];
|
||||
out_b[index] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void
|
||||
binary_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto out = Op{}(a[index], b[index]);
|
||||
out_a[index] = out[0];
|
||||
out_b[index] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
|
||||
__global__ void binary_g_nd(
|
||||
const In* a,
|
||||
const In* b,
|
||||
Out* out_a,
|
||||
Out* out_b,
|
||||
IdxT size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
|
||||
index, shape.data(), a_strides.data(), b_strides.data());
|
||||
auto out = Op{}(a[a_idx], b[b_idx]);
|
||||
out_a[index] = out[0];
|
||||
out_b[index] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void binary_g(
|
||||
const In* a,
|
||||
const In* b,
|
||||
Out* out_a,
|
||||
Out* out_b,
|
||||
IdxT size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides a_strides,
|
||||
const __grid_constant__ Strides b_strides,
|
||||
int ndim) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx] = elem_to_loc_4d(
|
||||
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
|
||||
auto out = Op{}(a[a_idx], b[b_idx]);
|
||||
out_a[index] = out[0];
|
||||
out_b[index] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out>
|
||||
constexpr bool supports_binary_op() {
|
||||
if (std::is_same_v<Op, DivMod>) {
|
||||
return std::is_same_v<In, Out> &&
|
||||
(std::is_integral_v<Out> || is_floating_v<Out>);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
std::string_view op,
|
||||
const Stream& s) {
|
||||
assert(inputs.size() > 1);
|
||||
const auto& a = inputs[0];
|
||||
const auto& b = inputs[1];
|
||||
auto& out_a = outputs[0];
|
||||
auto& out_b = outputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out_a, bopt);
|
||||
set_binary_op_output_data(a, b, out_b, bopt);
|
||||
|
||||
if (out_a.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out_a);
|
||||
encoder.set_output_array(out_b);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
|
||||
MLX_SWITCH_ALL_TYPES(out_a.dtype(), CTYPE_OUT, {
|
||||
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||
using InType = cuda_type_t<CTYPE_IN>;
|
||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
if (bopt == BinaryOpType::General) {
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, out_a);
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
bool large = a.data_size() > INT32_MAX ||
|
||||
b.data_size() > INT32_MAX || out_a.data_size() > INT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
auto kernel =
|
||||
&cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out_a, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
out_b.data<OutType>(),
|
||||
out_a.size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(a_strides),
|
||||
const_param<NDIM>(b_strides));
|
||||
});
|
||||
} else {
|
||||
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out_a, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
out_b.data<OutType>(),
|
||||
out_a.size(),
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
const_param(b_strides),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
MLX_SWITCH_BOOL(out_a.data_size() > UINT32_MAX, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
|
||||
if (bopt == BinaryOpType::ScalarVector) {
|
||||
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
|
||||
} else if (bopt == BinaryOpType::VectorScalar) {
|
||||
kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
|
||||
} else if (bopt == BinaryOpType::VectorVector) {
|
||||
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel,
|
||||
out_a.data_size(),
|
||||
out_a.shape(),
|
||||
out_a.strides(),
|
||||
LARGE);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
out_b.data<OutType>(),
|
||||
out_a.data_size());
|
||||
});
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Can not do binary op {} on inputs of {} with result of {}.",
|
||||
op,
|
||||
dtype_to_string(a.dtype()),
|
||||
dtype_to_string(out_a.dtype())));
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
std::string_view op,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, outputs[0], bopt);
|
||||
set_binary_op_output_data(a, b, outputs[1], bopt);
|
||||
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
|
||||
}
|
||||
|
||||
void DivMod::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("DivMod::eval_gpu");
|
||||
auto& s = outputs[0].primitive().stream();
|
||||
binary_op_gpu<cu::DivMod>(inputs, outputs, get_primitive_string(this), s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
230
mlx/backend/cuda/compiled.cpp
Normal file
230
mlx/backend/cuda/compiled.cpp
Normal file
@ -0,0 +1,230 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/jit_module.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
struct FusedKernelBuilder {
|
||||
std::string os;
|
||||
const std::string& kernel_name;
|
||||
const std::vector<array>& inputs;
|
||||
const std::vector<array>& outputs;
|
||||
const std::vector<array>& tape;
|
||||
const std::function<bool(size_t)>& is_constant;
|
||||
|
||||
void build(const char* name, bool contiguous) {
|
||||
NodeNamer namer;
|
||||
|
||||
// Function parameters.
|
||||
std::vector<std::string> params;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
if (is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
const auto& x = inputs[i];
|
||||
const std::string& xname = namer.get_name(x);
|
||||
params.push_back(
|
||||
fmt::format("const {}* {}", dtype_to_cuda_type(x.dtype()), xname));
|
||||
if (!is_scalar(x) && !contiguous) {
|
||||
params.push_back(fmt::format(
|
||||
"const __grid_constant__ cuda::std::array<int64_t, NDIM> {}_strides",
|
||||
xname));
|
||||
}
|
||||
}
|
||||
for (const auto& x : outputs) {
|
||||
params.push_back(fmt::format(
|
||||
"{}* {}", dtype_to_cuda_type(x.dtype()), namer.get_name(x)));
|
||||
}
|
||||
if (!contiguous) {
|
||||
params.push_back(
|
||||
"const __grid_constant__ cuda::std::array<int32_t, NDIM> shape");
|
||||
}
|
||||
params.push_back("IdxT size");
|
||||
|
||||
// Build function signature.
|
||||
if (contiguous) {
|
||||
os += "template <typename IdxT = uint32_t>\n";
|
||||
} else {
|
||||
os += "template <int NDIM, typename IdxT = uint32_t>\n";
|
||||
}
|
||||
os += fmt::format("__global__ void {}(\n", kernel_name + name);
|
||||
for (size_t i = 0; i < params.size(); ++i) {
|
||||
os += " ";
|
||||
os += params[i];
|
||||
if (i != params.size() - 1) {
|
||||
os += ",\n";
|
||||
}
|
||||
}
|
||||
os += ") {\n";
|
||||
|
||||
// Index.
|
||||
os +=
|
||||
" IdxT index = cg::this_grid().thread_rank();\n"
|
||||
" if (index >= size) {\n"
|
||||
" return;\n"
|
||||
" }\n";
|
||||
|
||||
// Read inputs.
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto& x = inputs[i];
|
||||
const std::string& xname = namer.get_name(x);
|
||||
std::string type = dtype_to_cuda_type(x.dtype());
|
||||
std::string value;
|
||||
if (is_constant(i)) {
|
||||
std::ostringstream ss;
|
||||
print_constant(ss, x);
|
||||
value = fmt::format("static_cast<{}>({})", type, ss.str());
|
||||
} else if (is_scalar(x)) {
|
||||
value = fmt::format("{}[0]", xname);
|
||||
} else if (contiguous) {
|
||||
value = fmt::format("{}[index]", xname);
|
||||
} else {
|
||||
std::string index = fmt::format(
|
||||
"elem_to_loc_nd<NDIM>(index, shape.data(), {}_strides.data())",
|
||||
xname);
|
||||
value = fmt::format("{}[{}]", xname, index);
|
||||
}
|
||||
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
|
||||
}
|
||||
|
||||
// Write tape.
|
||||
for (const auto& x : tape) {
|
||||
const std::string& xname = namer.get_name(x);
|
||||
std::string type = dtype_to_cuda_type(x.dtype());
|
||||
std::string value;
|
||||
if (is_static_cast(x.primitive())) {
|
||||
value = fmt::format(
|
||||
"static_cast<{}>(tmp_{})", type, namer.get_name(x.inputs()[0]));
|
||||
} else {
|
||||
std::ostringstream ss;
|
||||
x.primitive().print(ss);
|
||||
value = ss.str();
|
||||
value += "{}(";
|
||||
for (size_t i = 0; i < x.inputs().size() - 1; ++i) {
|
||||
value += fmt::format("tmp_{}, ", namer.get_name(x.inputs()[i]));
|
||||
}
|
||||
value += fmt::format("tmp_{})", namer.get_name(x.inputs().back()));
|
||||
}
|
||||
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
|
||||
}
|
||||
|
||||
// Write output.
|
||||
for (const auto& x : outputs) {
|
||||
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
|
||||
}
|
||||
|
||||
os += "}\n";
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cu
|
||||
|
||||
constexpr const char* g_jit_includes = R"(
|
||||
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/ternary_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
#define inf cuda::std::numeric_limits<float>::infinity()
|
||||
)";
|
||||
|
||||
void Compiled::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("Compiled::eval_gpu");
|
||||
auto& s = stream();
|
||||
|
||||
cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() {
|
||||
// Build source code.
|
||||
cu::FusedKernelBuilder builder{
|
||||
g_jit_includes, lib_name(), inputs_, outputs_, tape_, is_constant_};
|
||||
builder.os +=
|
||||
"namespace mlx::core::cu {\n\n"
|
||||
"namespace cg = cooperative_groups;\n\n";
|
||||
builder.build("_contiguous", true);
|
||||
builder.os += "\n";
|
||||
builder.build("_strided", false);
|
||||
builder.os += "\n} // namespace mlx::core::cu\n";
|
||||
// Build kernel names.
|
||||
std::vector<std::string> kernel_names = {
|
||||
fmt::format("mlx::core::cu::{}_contiguous<uint32_t>", lib_name()),
|
||||
fmt::format("mlx::core::cu::{}_contiguous<int64_t>", lib_name()),
|
||||
};
|
||||
for (int i = 1; i <= MAX_NDIM; ++i) {
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_strided<{}, uint32_t>", lib_name(), i));
|
||||
kernel_names.push_back(
|
||||
fmt::format("mlx::core::cu::{}_strided<{}, int64_t>", lib_name(), i));
|
||||
}
|
||||
return std::make_pair(std::move(builder.os), std::move(kernel_names));
|
||||
});
|
||||
|
||||
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
||||
// handle all broadcasting.
|
||||
auto [contiguous, shape, strides_vec] =
|
||||
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
|
||||
|
||||
// Whether to use large index.
|
||||
bool large = compiled_use_large_index(inputs, outputs, contiguous);
|
||||
|
||||
// Put inputs.
|
||||
int strides_index = 1;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
if (is_constant_(i)) {
|
||||
continue;
|
||||
}
|
||||
const auto& x = inputs[i];
|
||||
mod.append_arg(x);
|
||||
if (!contiguous && !is_scalar(x)) {
|
||||
mod.append_arg(strides_vec[strides_index++]);
|
||||
}
|
||||
}
|
||||
|
||||
// Put outputs.
|
||||
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
|
||||
for (auto& x : outputs) {
|
||||
mod.append_arg(x);
|
||||
}
|
||||
|
||||
// Put shape and size.
|
||||
if (!contiguous) {
|
||||
mod.append_arg(shape);
|
||||
}
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(outputs[0].data_size());
|
||||
} else {
|
||||
mod.append_arg<uint32_t>(outputs[0].data_size());
|
||||
}
|
||||
|
||||
// Launch kernel.
|
||||
const char* index_type = large ? "int64_t" : "uint32_t";
|
||||
std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
|
||||
if (contiguous) {
|
||||
kernel_name += fmt::format("_contiguous<{}>", index_type);
|
||||
} else {
|
||||
kernel_name += fmt::format("_strided<{}, {}>", shape.size(), index_type);
|
||||
}
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
encoder.set_input_array(in);
|
||||
}
|
||||
for (const auto& out : outputs) {
|
||||
encoder.set_output_array(out);
|
||||
}
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
mod.launch_kernel(stream, kernel_name, outputs[0], large);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -1,26 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Shape& data_shape,
|
||||
const Strides& strides_in_pre,
|
||||
const Strides& strides_out_pre,
|
||||
int64_t inp_offset,
|
||||
int64_t out_offset,
|
||||
CopyType ctype,
|
||||
const Stream& s,
|
||||
const std::optional<array>& dynamic_i_offset /* = std::nullopt */,
|
||||
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
|
||||
throw std::runtime_error("copy_gpu_inplace not implemented in CUDA backend.");
|
||||
}
|
||||
|
||||
void fill_gpu(const array& val, array& out, const Stream& s) {
|
||||
throw std::runtime_error("fill_gpu not implemented in CUDA backend.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
87
mlx/backend/cuda/copy.cu
Normal file
87
mlx/backend/cuda/copy.cu
Normal file
@ -0,0 +1,87 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cuda/copy/copy.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Shape& shape,
|
||||
const Strides& strides_in,
|
||||
const Strides& strides_out,
|
||||
int64_t offset_in,
|
||||
int64_t offset_out,
|
||||
CopyType ctype,
|
||||
const Stream& s,
|
||||
const std::optional<array>& dynamic_offset_in,
|
||||
const std::optional<array>& dynamic_offset_out) {
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
|
||||
copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
|
||||
return;
|
||||
}
|
||||
|
||||
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
||||
auto [shape_collapsed, strides_vec] = collapse_contiguous_dims(
|
||||
shape, std::vector{strides_in, strides_out}, INT32_MAX);
|
||||
if (ctype == CopyType::General) {
|
||||
copy_general_input(
|
||||
encoder,
|
||||
ctype,
|
||||
in,
|
||||
out,
|
||||
offset_in,
|
||||
offset_out,
|
||||
shape_collapsed,
|
||||
strides_vec[0]);
|
||||
} else {
|
||||
if (dynamic_offset_in || dynamic_offset_out) {
|
||||
copy_general_dynamic(
|
||||
encoder,
|
||||
ctype,
|
||||
in,
|
||||
out,
|
||||
offset_in,
|
||||
offset_out,
|
||||
shape_collapsed,
|
||||
strides_vec[0],
|
||||
strides_vec[1],
|
||||
dynamic_offset_in ? *dynamic_offset_in : array(0, int64),
|
||||
dynamic_offset_out ? *dynamic_offset_out : array(0, int64));
|
||||
} else {
|
||||
copy_general(
|
||||
encoder,
|
||||
ctype,
|
||||
in,
|
||||
out,
|
||||
offset_in,
|
||||
offset_out,
|
||||
shape_collapsed,
|
||||
strides_vec[0],
|
||||
strides_vec[1]);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
void fill_gpu(const array& in, array& out, const Stream& s) {
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
64
mlx/backend/cuda/copy/copy.cuh
Normal file
64
mlx/backend/cuda/copy/copy.cuh
Normal file
@ -0,0 +1,64 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
|
||||
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
|
||||
using InType = cuda_type_t<CTYPE_IN>; \
|
||||
using OutType = cuda_type_t<CTYPE_OUT>; \
|
||||
__VA_ARGS__; \
|
||||
}); \
|
||||
})
|
||||
|
||||
void copy_contiguous(
|
||||
cu::CommandEncoder& encoder,
|
||||
CopyType ctype,
|
||||
const array& in,
|
||||
array& out,
|
||||
int64_t offset_in,
|
||||
int64_t offset_out);
|
||||
|
||||
void copy_general(
|
||||
cu::CommandEncoder& encoder,
|
||||
CopyType ctype,
|
||||
const array& in,
|
||||
array& out,
|
||||
int64_t offset_in,
|
||||
int64_t offset_out,
|
||||
const Shape& shape,
|
||||
const Strides& strides_in,
|
||||
const Strides& strides_out);
|
||||
|
||||
void copy_general_dynamic(
|
||||
cu::CommandEncoder& encoder,
|
||||
CopyType ctype,
|
||||
const array& in,
|
||||
array& out,
|
||||
int64_t offset_in,
|
||||
int64_t offset_out,
|
||||
const Shape& shape,
|
||||
const Strides& strides_in,
|
||||
const Strides& strides_out,
|
||||
const array& dynamic_offset_in,
|
||||
const array& dynamic_offset_out);
|
||||
|
||||
void copy_general_input(
|
||||
cu::CommandEncoder& encoder,
|
||||
CopyType ctype,
|
||||
const array& in,
|
||||
array& out,
|
||||
int64_t offset_in,
|
||||
int64_t offset_out,
|
||||
const Shape& shape,
|
||||
const Strides& strides_in);
|
||||
|
||||
} // namespace mlx::core
|
57
mlx/backend/cuda/copy/copy_contiguous.cu
Normal file
57
mlx/backend/cuda/copy/copy_contiguous.cu
Normal file
@ -0,0 +1,57 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/copy/copy.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename In, typename Out, typename IdxT>
|
||||
__global__ void copy_s(const In* in, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
out[index] = CastOp<In, Out>{}(in[0]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename In, typename Out, typename IdxT>
|
||||
__global__ void copy_v(const In* in, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
out[index] = CastOp<In, Out>{}(in[index]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void copy_contiguous(
|
||||
cu::CommandEncoder& encoder,
|
||||
CopyType ctype,
|
||||
const array& in,
|
||||
array& out,
|
||||
int64_t in_offset,
|
||||
int64_t out_offset) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
||||
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
auto kernel = cu::copy_s<InType, OutType, IdxT>;
|
||||
if (ctype == CopyType::Vector) {
|
||||
kernel = cu::copy_v<InType, OutType, IdxT>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in.data<InType>() + in_offset,
|
||||
out.data<OutType>() + out_offset,
|
||||
out.data_size());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
100
mlx/backend/cuda/copy/copy_general.cu
Normal file
100
mlx/backend/cuda/copy/copy_general.cu
Normal file
@ -0,0 +1,100 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/copy/copy.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename In, typename Out, typename IdxT, int NDIM>
|
||||
__global__ void copy_gg_nd(
|
||||
const In* in,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
|
||||
index, shape.data(), strides_in.data(), strides_out.data());
|
||||
out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename In, typename Out, typename IdxT>
|
||||
__global__ void copy_gg(
|
||||
const In* in,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides strides_in,
|
||||
const __grid_constant__ Strides strides_out,
|
||||
int ndim) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [idx_in, idx_out] = elem_to_loc_4d(
|
||||
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
|
||||
out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void copy_general(
|
||||
cu::CommandEncoder& encoder,
|
||||
CopyType ctype,
|
||||
const array& in,
|
||||
array& out,
|
||||
int64_t offset_in,
|
||||
int64_t offset_out,
|
||||
const Shape& shape,
|
||||
const Strides& strides_in,
|
||||
const Strides& strides_out) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||
int ndim = shape.size();
|
||||
size_t data_size = 1;
|
||||
for (auto& s : shape)
|
||||
data_size *= s;
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, data_size, shape, out.strides(), large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
data_size,
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(strides_in),
|
||||
const_param<NDIM>(strides_out));
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, data_size, shape, out.strides(), large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
data_size,
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
const_param(strides_out),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
105
mlx/backend/cuda/copy/copy_general_dynamic.cu
Normal file
105
mlx/backend/cuda/copy/copy_general_dynamic.cu
Normal file
@ -0,0 +1,105 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/copy/copy.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename In, typename Out, typename IdxT, int NDIM>
|
||||
__global__ void copy_gg_dynamic_nd(
|
||||
const In* in,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out,
|
||||
const int64_t* offset_in,
|
||||
const int64_t* offset_out) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
|
||||
index, shape.data(), strides_in.data(), strides_out.data());
|
||||
out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename In, typename Out, typename IdxT>
|
||||
__global__ void copy_gg_dynamic(
|
||||
const In* in,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides strides_in,
|
||||
const __grid_constant__ Strides strides_out,
|
||||
int ndim,
|
||||
const int64_t* offset_in,
|
||||
const int64_t* offset_out) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [idx_in, idx_out] = elem_to_loc_4d(
|
||||
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
|
||||
out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void copy_general_dynamic(
|
||||
cu::CommandEncoder& encoder,
|
||||
CopyType ctype,
|
||||
const array& in,
|
||||
array& out,
|
||||
int64_t offset_in,
|
||||
int64_t offset_out,
|
||||
const Shape& shape,
|
||||
const Strides& strides_in,
|
||||
const Strides& strides_out,
|
||||
const array& dynamic_offset_in,
|
||||
const array& dynamic_offset_out) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
auto kernel = cu::copy_gg_dynamic_nd<InType, OutType, IdxT, NDIM>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(strides_in),
|
||||
const_param<NDIM>(strides_out),
|
||||
dynamic_offset_in.data<int64_t>(),
|
||||
dynamic_offset_out.data<int64_t>());
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
const_param(strides_out),
|
||||
ndim,
|
||||
dynamic_offset_in.data<int64_t>(),
|
||||
dynamic_offset_out.data<int64_t>());
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
88
mlx/backend/cuda/copy/copy_general_input.cu
Normal file
88
mlx/backend/cuda/copy/copy_general_input.cu
Normal file
@ -0,0 +1,88 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/copy/copy.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename In, typename Out, typename IdxT, int NDIM>
|
||||
__global__ void copy_g_nd(
|
||||
const In* in,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
IdxT idx_in = elem_to_loc_nd<NDIM>(index, shape.data(), strides_in.data());
|
||||
out[index] = CastOp<In, Out>{}(in[idx_in]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename In, typename Out, typename IdxT>
|
||||
__global__ void copy_g(
|
||||
const In* in,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides strides_in,
|
||||
int ndim) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
IdxT idx_in = elem_to_loc_4d(index, shape.data(), strides_in.data(), ndim);
|
||||
out[index] = CastOp<In, Out>{}(in[idx_in]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void copy_general_input(
|
||||
cu::CommandEncoder& encoder,
|
||||
CopyType ctype,
|
||||
const array& in,
|
||||
array& out,
|
||||
int64_t offset_in,
|
||||
int64_t offset_out,
|
||||
const Shape& shape,
|
||||
const Strides& strides_in) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
auto kernel = cu::copy_g_nd<InType, OutType, IdxT, NDIM>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(strides_in));
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto kernel = cu::copy_g<InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
11
mlx/backend/cuda/cuda.cpp
Normal file
11
mlx/backend/cuda/cuda.cpp
Normal file
@ -0,0 +1,11 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/cuda.h"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
10
mlx/backend/cuda/cuda.h
Normal file
10
mlx/backend/cuda/cuda.h
Normal file
@ -0,0 +1,10 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
/* Check if the CUDA backend is available. */
|
||||
bool is_available();
|
||||
|
||||
} // namespace mlx::core::cu
|
@ -6,6 +6,7 @@
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <future>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@ -34,14 +35,26 @@ CommandEncoder& DeviceStream::get_encoder() {
|
||||
}
|
||||
|
||||
Device::Device(int device) : device_(device) {
|
||||
CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
|
||||
&compute_capability_major_, cudaDevAttrComputeCapabilityMajor, device_));
|
||||
CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
|
||||
&compute_capability_minor_, cudaDevAttrComputeCapabilityMinor, device_));
|
||||
// Validate the requirements of device.
|
||||
int attr = 0;
|
||||
cudaDeviceGetAttribute(&attr, cudaDevAttrConcurrentManagedAccess, device_);
|
||||
CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
|
||||
&attr, cudaDevAttrConcurrentManagedAccess, device_));
|
||||
if (attr != 1) {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Device {} does not support synchronization in managed memory.",
|
||||
device_));
|
||||
}
|
||||
// The cublasLt handle is used by matmul.
|
||||
make_current();
|
||||
cublasLtCreate(<_);
|
||||
}
|
||||
|
||||
Device::~Device() {
|
||||
cublasLtDestroy(lt_);
|
||||
}
|
||||
|
||||
void Device::make_current() {
|
||||
@ -95,6 +108,16 @@ void CommandEncoder::commit() {
|
||||
worker_.commit(stream_.last_cuda_stream());
|
||||
}
|
||||
|
||||
void CommandEncoder::synchronize() {
|
||||
stream().synchronize();
|
||||
auto p = std::make_shared<std::promise<void>>();
|
||||
std::future<void> f = p->get_future();
|
||||
add_completed_handler([p = std::move(p)]() { p->set_value(); });
|
||||
worker_.end_batch();
|
||||
commit();
|
||||
f.wait();
|
||||
}
|
||||
|
||||
Device& device(mlx::core::Device device) {
|
||||
static std::unordered_map<int, Device> devices;
|
||||
auto it = devices.find(device.index);
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
#include <cublasLt.h>
|
||||
#include <thrust/execution_policy.h>
|
||||
|
||||
#include <unordered_map>
|
||||
@ -46,6 +47,7 @@ class DeviceStream {
|
||||
class Device {
|
||||
public:
|
||||
explicit Device(int device);
|
||||
~Device();
|
||||
|
||||
Device(const Device&) = delete;
|
||||
Device& operator=(const Device&) = delete;
|
||||
@ -58,9 +60,21 @@ class Device {
|
||||
int cuda_device() const {
|
||||
return device_;
|
||||
}
|
||||
int compute_capability_major() const {
|
||||
return compute_capability_major_;
|
||||
}
|
||||
int compute_capability_minor() const {
|
||||
return compute_capability_minor_;
|
||||
}
|
||||
cublasLtHandle_t lt_handle() const {
|
||||
return lt_;
|
||||
}
|
||||
|
||||
private:
|
||||
int device_;
|
||||
int compute_capability_major_;
|
||||
int compute_capability_minor_;
|
||||
cublasLtHandle_t lt_;
|
||||
std::unordered_map<int, DeviceStream> streams_;
|
||||
};
|
||||
|
||||
@ -109,6 +123,9 @@ class CommandEncoder {
|
||||
return has_gpu_work_;
|
||||
}
|
||||
|
||||
// Wait until kernels and completion handlers are finished
|
||||
void synchronize();
|
||||
|
||||
private:
|
||||
Device& device_;
|
||||
DeviceStream& stream_;
|
||||
|
72
mlx/backend/cuda/device/atomic_ops.cuh
Normal file
72
mlx/backend/cuda/device/atomic_ops.cuh
Normal file
@ -0,0 +1,72 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
|
||||
#include <cuda/atomic>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void atomic_add(T* out, T val) {
|
||||
cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
|
||||
ref += val;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void atomic_prod(T* out, T val) {
|
||||
cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
|
||||
T old = ref.load();
|
||||
while (!ref.compare_exchange_strong(old, old * val)) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void atomic_max(T* out, T val) {
|
||||
cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
|
||||
ref.fetch_max(val);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void atomic_min(T* out, T val) {
|
||||
cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
|
||||
ref.fetch_min(val);
|
||||
}
|
||||
|
||||
// Somehow cuda::atomic_ref does not provide atomic add for following types.
|
||||
template <typename T>
|
||||
inline __device__ void atomic_add_general(T* out, T val) {
|
||||
cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
|
||||
T old = ref.load();
|
||||
while (!ref.compare_exchange_strong(old, old + val)) {
|
||||
}
|
||||
}
|
||||
|
||||
inline __device__ void atomic_add(__half* out, __half val) {
|
||||
atomicAdd(out, val);
|
||||
}
|
||||
|
||||
inline __device__ void atomic_add(cuComplex* out, cuComplex val) {
|
||||
#if __CUDA_ARCH__ < 900
|
||||
atomic_add_general(out, val);
|
||||
#else
|
||||
atomicAdd(out, val);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {
|
||||
#if __CUDA_ARCH__ < 800
|
||||
#if CCCL_VERSION >= 2008000
|
||||
atomic_add_general(out, val);
|
||||
#else
|
||||
bool cccl_version_too_old_for_bfloat16_atomic_add = false;
|
||||
assert(cccl_version_too_old_for_bfloat16_atomic_add);
|
||||
#endif
|
||||
#else
|
||||
atomicAdd(out, val);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
307
mlx/backend/cuda/device/binary_ops.cuh
Normal file
307
mlx/backend/cuda/device/binary_ops.cuh
Normal file
@ -0,0 +1,307 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <cuda/std/array>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
struct Add {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x + y;
|
||||
}
|
||||
};
|
||||
|
||||
struct FloorDivide {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
return x / y;
|
||||
} else {
|
||||
return truncf(x / y);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Divide {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x / y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Remainder {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
if constexpr (cuda::std::is_signed_v<T>) {
|
||||
auto r = x % y;
|
||||
if (r != 0 && (r < 0 != y < 0)) {
|
||||
r += y;
|
||||
}
|
||||
return r;
|
||||
} else {
|
||||
return x % y;
|
||||
}
|
||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return x % y;
|
||||
} else {
|
||||
T r = fmod(x, y);
|
||||
if (r != 0 && (r < 0 != y < 0)) {
|
||||
r = r + y;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Equal {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T x, T y) {
|
||||
return x == y;
|
||||
}
|
||||
};
|
||||
|
||||
struct NaNEqual {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T x, T y) {
|
||||
if constexpr (std::is_same_v<T, cuComplex>) {
|
||||
return x == y ||
|
||||
(isnan(cuCrealf(x)) && isnan(cuCrealf(y)) && isnan(cuCimagf(x)) &&
|
||||
isnan(cuCimagf(y))) ||
|
||||
(cuCrealf(x) == cuCrealf(y) && isnan(cuCimagf(x)) &&
|
||||
isnan(cuCimagf(y))) ||
|
||||
(isnan(cuCrealf(x)) && isnan(cuCrealf(y)) &&
|
||||
cuCimagf(x) == cuCimagf(y));
|
||||
} else {
|
||||
return x == y || (isnan(x) && isnan(y));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Greater {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T x, T y) {
|
||||
return x > y;
|
||||
}
|
||||
};
|
||||
|
||||
struct GreaterEqual {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T x, T y) {
|
||||
return x >= y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Less {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T x, T y) {
|
||||
return x < y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LessEqual {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T x, T y) {
|
||||
return x <= y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LogAddExp {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
if (isnan(x) || isnan(y)) {
|
||||
return cuda::std::numeric_limits<T>::quiet_NaN();
|
||||
}
|
||||
T maxval = max(x, y);
|
||||
T minval = min(x, y);
|
||||
return (minval == -cuda::std::numeric_limits<T>::infinity() ||
|
||||
maxval == cuda::std::numeric_limits<T>::infinity())
|
||||
? maxval
|
||||
: T(float(maxval) + log1p(expf(minval - maxval)));
|
||||
};
|
||||
|
||||
__device__ cuComplex operator()(cuComplex x, cuComplex y) {
|
||||
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
|
||||
isnan(cuCimagf(y))) {
|
||||
return {
|
||||
cuda::std::numeric_limits<float>::quiet_NaN(),
|
||||
cuda::std::numeric_limits<float>::quiet_NaN()};
|
||||
}
|
||||
float inf = cuda::std::numeric_limits<float>::infinity();
|
||||
auto maxval = x > y ? x : y;
|
||||
auto minval = x < y ? x : y;
|
||||
if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf)
|
||||
return maxval;
|
||||
float m = exp(cuCrealf(minval) - cuCrealf(maxval));
|
||||
cuComplex dexp{
|
||||
m * cos(cuCimagf(minval) - cuCimagf(maxval)),
|
||||
m * sin(cuCimagf(minval) - cuCimagf(maxval)),
|
||||
};
|
||||
return maxval + log1p(dexp);
|
||||
}
|
||||
};
|
||||
|
||||
struct Maximum {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
return max(x, y);
|
||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
|
||||
return x;
|
||||
}
|
||||
return x > y ? x : y;
|
||||
} else {
|
||||
if (isnan(x)) {
|
||||
return x;
|
||||
}
|
||||
return x > y ? x : y;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Minimum {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
return min(x, y);
|
||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
|
||||
return x;
|
||||
}
|
||||
return x < y ? x : y;
|
||||
} else {
|
||||
if (isnan(x)) {
|
||||
return x;
|
||||
}
|
||||
return x < y ? x : y;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Multiply {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x * y;
|
||||
}
|
||||
};
|
||||
|
||||
struct NotEqual {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T x, T y) {
|
||||
if constexpr (std::is_same_v<T, cuComplex>) {
|
||||
return cuCrealf(x) != cuCrealf(y) || cuCimagf(x) != cuCimagf(y);
|
||||
} else {
|
||||
return x != y;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Power {
|
||||
template <typename T>
|
||||
__device__ T operator()(T base, T exp) {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
T res = 1;
|
||||
while (exp) {
|
||||
if (exp & 1) {
|
||||
res *= base;
|
||||
}
|
||||
exp >>= 1;
|
||||
base *= base;
|
||||
}
|
||||
return res;
|
||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
if (base.y == 0 && base.x == 0) {
|
||||
if (isnan(exp.x) || isnan(exp.y)) {
|
||||
auto nan = cuda::std::numeric_limits<float>::quiet_NaN();
|
||||
return make_cuFloatComplex(nan, nan);
|
||||
}
|
||||
return make_cuFloatComplex(0.0, 0.0);
|
||||
}
|
||||
auto x_theta = atan2f(base.y, base.x);
|
||||
auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y);
|
||||
auto mag = expf(exp.x * x_ln_r - exp.y * x_theta);
|
||||
auto phase = exp.y * x_ln_r + exp.x * x_theta;
|
||||
return make_cuFloatComplex(mag * cosf(phase), mag * sinf(phase));
|
||||
} else {
|
||||
return powf(base, exp);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Subtract {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x - y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LogicalAnd {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x && y;
|
||||
};
|
||||
};
|
||||
|
||||
struct LogicalOr {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x || y;
|
||||
};
|
||||
};
|
||||
|
||||
struct BitwiseAnd {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x & y;
|
||||
};
|
||||
};
|
||||
|
||||
struct BitwiseOr {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x | y;
|
||||
};
|
||||
};
|
||||
|
||||
struct BitwiseXor {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x ^ y;
|
||||
};
|
||||
};
|
||||
|
||||
struct LeftShift {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x << y;
|
||||
};
|
||||
};
|
||||
|
||||
struct RightShift {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x >> y;
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcTan2 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T y, T x) {
|
||||
return atan2f(y, x);
|
||||
}
|
||||
};
|
||||
|
||||
struct DivMod {
|
||||
template <typename T>
|
||||
__device__ cuda::std::array<T, 2> operator()(T x, T y) {
|
||||
return {FloorDivide{}(x, y), Remainder{}(x, y)};
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
71
mlx/backend/cuda/device/cast_op.cuh
Normal file
71
mlx/backend/cuda/device/cast_op.cuh
Normal file
@ -0,0 +1,71 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <thrust/iterator/transform_iterator.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// An op that does static_cast, with custom conversions for some types.
|
||||
template <typename SrcT, typename DstT, typename = void>
|
||||
struct CastOp {
|
||||
static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, DstT>;
|
||||
|
||||
__device__ DstT operator()(SrcT x) {
|
||||
return static_cast<DstT>(x);
|
||||
}
|
||||
};
|
||||
|
||||
// Converting a complex number to real number discards the imaginary part.
|
||||
template <typename DstT>
|
||||
struct CastOp<
|
||||
cuComplex,
|
||||
DstT,
|
||||
cuda::std::enable_if_t<!cuda::std::is_same_v<cuComplex, DstT>>> {
|
||||
static constexpr bool is_castable = cuda::std::is_convertible_v<float, DstT>;
|
||||
|
||||
__device__ DstT operator()(cuComplex x) {
|
||||
static_assert(!cuda::std::is_same_v<cuComplex, DstT>);
|
||||
return static_cast<DstT>(cuCrealf(x));
|
||||
}
|
||||
};
|
||||
|
||||
// Allow converting a real number to complex number.
|
||||
template <typename SrcT>
|
||||
struct CastOp<
|
||||
SrcT,
|
||||
cuComplex,
|
||||
cuda::std::enable_if_t<!cuda::std::is_same_v<SrcT, cuComplex>>> {
|
||||
static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, float>;
|
||||
|
||||
__device__ cuComplex operator()(SrcT x) {
|
||||
static_assert(!cuda::std::is_same_v<SrcT, cuComplex>);
|
||||
return cuComplex{static_cast<float>(x), 0};
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
struct CastOp<
|
||||
SrcT,
|
||||
DstT,
|
||||
cuda::std::enable_if_t<cuda::std::is_same_v<SrcT, DstT>>> {
|
||||
static constexpr bool is_castable = true;
|
||||
|
||||
__device__ SrcT operator()(SrcT x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
// Return an iterator that cast the value to DstT using CastOp.
|
||||
template <typename DstT, typename Iterator>
|
||||
__host__ __device__ auto make_cast_iterator(Iterator it) {
|
||||
using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type;
|
||||
if constexpr (std::is_same_v<SrcT, DstT>) {
|
||||
return it;
|
||||
} else {
|
||||
return thrust::make_transform_iterator(it, CastOp<SrcT, DstT>{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
12
mlx/backend/cuda/device/config.h
Normal file
12
mlx/backend/cuda/device/config.h
Normal file
@ -0,0 +1,12 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
// This file is used by both CUDA kernel code and host-only C++ code.
|
||||
|
||||
#pragma once
|
||||
|
||||
// The maximum dimensions of shape/strides passed as kernel parameters.
|
||||
#define MAX_NDIM 10
|
||||
|
||||
// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
|
||||
// warpSize variable exists, using it would prevent compile-time optimizations.
|
||||
#define WARP_SIZE 32
|
240
mlx/backend/cuda/device/cucomplex_math.cuh
Normal file
240
mlx/backend/cuda/device/cucomplex_math.cuh
Normal file
@ -0,0 +1,240 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
// Copyright © 2017-2024 The Simons Foundation, Inc.
|
||||
//
|
||||
// FINUFFT is licensed under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance with the
|
||||
// License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Forked from
|
||||
// https://github.com/flatironinstitute/finufft/blob/main/include/cufinufft/contrib/helper_math.h
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuComplex.h>
|
||||
|
||||
// This header provides some helper functions for cuComplex types.
|
||||
// It mainly wraps existing CUDA implementations to provide operator overloads
|
||||
// e.g. cuAdd, cuSub, cuMul, cuDiv, cuCreal, cuCimag, cuCabs, cuCarg, cuConj are
|
||||
// all provided by CUDA
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator+(const cuDoubleComplex& a, const cuDoubleComplex& b) {
|
||||
return cuCadd(a, b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator-(const cuDoubleComplex& a, const cuDoubleComplex& b) {
|
||||
return cuCsub(a, b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator*(const cuDoubleComplex& a, const cuDoubleComplex& b) {
|
||||
return cuCmul(a, b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator/(const cuDoubleComplex& a, const cuDoubleComplex& b) {
|
||||
return cuCdiv(a, b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator%(const cuDoubleComplex& a, const cuDoubleComplex& b) {
|
||||
double r = cuCreal(a) - (floorf(cuCreal(a) / cuCreal(b)) * cuCreal(b));
|
||||
double i = cuCimag(a) - (floorf(cuCimag(a) / cuCimag(b)) * cuCimag(b));
|
||||
return make_cuDoubleComplex(r, i);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator==(
|
||||
const cuDoubleComplex& a,
|
||||
const cuDoubleComplex& b) {
|
||||
return cuCreal(a) == cuCreal(b) && cuCimag(a) == cuCimag(b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator!=(
|
||||
const cuDoubleComplex& a,
|
||||
const cuDoubleComplex& b) {
|
||||
return !(a == b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator>(
|
||||
const cuDoubleComplex& a,
|
||||
const cuDoubleComplex& b) {
|
||||
double mag_a = sqrt(cuCreal(a) * cuCreal(a) + cuCimag(a) * cuCimag(a));
|
||||
double mag_b = sqrt(cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b));
|
||||
return mag_a > mag_b;
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator>=(
|
||||
const cuDoubleComplex& a,
|
||||
const cuDoubleComplex& b) {
|
||||
return a > b || a == b;
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator<(
|
||||
const cuDoubleComplex& a,
|
||||
const cuDoubleComplex& b) {
|
||||
return b > a;
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator<=(
|
||||
const cuDoubleComplex& a,
|
||||
const cuDoubleComplex& b) {
|
||||
return b > a || a == b;
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator+(const cuDoubleComplex& a, double b) {
|
||||
return make_cuDoubleComplex(cuCreal(a) + b, cuCimag(a));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator+(double a, const cuDoubleComplex& b) {
|
||||
return make_cuDoubleComplex(a + cuCreal(b), cuCimag(b));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator-(const cuDoubleComplex& a, double b) {
|
||||
return make_cuDoubleComplex(cuCreal(a) - b, cuCimag(a));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator-(double a, const cuDoubleComplex& b) {
|
||||
return make_cuDoubleComplex(a - cuCreal(b), -cuCimag(b));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator*(const cuDoubleComplex& a, double b) {
|
||||
return make_cuDoubleComplex(cuCreal(a) * b, cuCimag(a) * b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator*(double a, const cuDoubleComplex& b) {
|
||||
return make_cuDoubleComplex(a * cuCreal(b), a * cuCimag(b));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator/(const cuDoubleComplex& a, double b) {
|
||||
return make_cuDoubleComplex(cuCreal(a) / b, cuCimag(a) / b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator/(double a, const cuDoubleComplex& b) {
|
||||
double denom = cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b);
|
||||
return make_cuDoubleComplex(
|
||||
(a * cuCreal(b)) / denom, (-a * cuCimag(b)) / denom);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator+(const cuFloatComplex& a, const cuFloatComplex& b) {
|
||||
return cuCaddf(a, b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator-(const cuFloatComplex& a, const cuFloatComplex& b) {
|
||||
return cuCsubf(a, b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator*(const cuFloatComplex& a, const cuFloatComplex& b) {
|
||||
return cuCmulf(a, b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator/(const cuFloatComplex& a, const cuFloatComplex& b) {
|
||||
return cuCdivf(a, b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator%(const cuFloatComplex& a, const cuFloatComplex& b) {
|
||||
float r = cuCrealf(a) - (floorf(cuCrealf(a) / cuCrealf(b)) * cuCrealf(b));
|
||||
float i = cuCimagf(a) - (floorf(cuCimagf(a) / cuCimagf(b)) * cuCimagf(b));
|
||||
return make_cuFloatComplex(r, i);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator==(
|
||||
const cuFloatComplex& a,
|
||||
const cuFloatComplex& b) {
|
||||
return cuCrealf(a) == cuCrealf(b) && cuCimagf(a) == cuCimagf(b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator!=(
|
||||
const cuFloatComplex& a,
|
||||
const cuFloatComplex& b) {
|
||||
return !(a == b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator>(
|
||||
const cuFloatComplex& a,
|
||||
const cuFloatComplex& b) {
|
||||
float mag_a = sqrt(cuCrealf(a) * cuCrealf(a) + cuCimagf(a) * cuCimagf(a));
|
||||
float mag_b = sqrt(cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b));
|
||||
return mag_a > mag_b;
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator>=(
|
||||
const cuFloatComplex& a,
|
||||
const cuFloatComplex& b) {
|
||||
return a > b || a == b;
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator<(
|
||||
const cuFloatComplex& a,
|
||||
const cuFloatComplex& b) {
|
||||
return b > a;
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator<=(
|
||||
const cuFloatComplex& a,
|
||||
const cuFloatComplex& b) {
|
||||
return b > a || a == b;
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator+(const cuFloatComplex& a, float b) {
|
||||
return make_cuFloatComplex(cuCrealf(a) + b, cuCimagf(a));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator+(float a, const cuFloatComplex& b) {
|
||||
return make_cuFloatComplex(a + cuCrealf(b), cuCimagf(b));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator-(const cuFloatComplex& a, float b) {
|
||||
return make_cuFloatComplex(cuCrealf(a) - b, cuCimagf(a));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator-(float a, const cuFloatComplex& b) {
|
||||
return make_cuFloatComplex(a - cuCrealf(b), -cuCimagf(b));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator*(const cuFloatComplex& a, float b) {
|
||||
return make_cuFloatComplex(cuCrealf(a) * b, cuCimagf(a) * b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator*(float a, const cuFloatComplex& b) {
|
||||
return make_cuFloatComplex(a * cuCrealf(b), a * cuCimagf(b));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator/(const cuFloatComplex& a, float b) {
|
||||
return make_cuFloatComplex(cuCrealf(a) / b, cuCimagf(a) / b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator/(float a, const cuFloatComplex& b) {
|
||||
float denom = cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b);
|
||||
return make_cuFloatComplex(
|
||||
(a * cuCrealf(b)) / denom, (-a * cuCimagf(b)) / denom);
|
||||
}
|
194
mlx/backend/cuda/device/fp16_math.cuh
Normal file
194
mlx/backend/cuda/device/fp16_math.cuh
Normal file
@ -0,0 +1,194 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda/std/limits>
|
||||
#include <cuda/std/type_traits>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Unary ops for half types.
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
|
||||
#define MLX_DEFINE_UNARY_OP(NAME, HALF_OP) \
|
||||
template <typename T> \
|
||||
__forceinline__ __device__ auto NAME(T x) { \
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) { \
|
||||
return HALF_OP(x); \
|
||||
} else { \
|
||||
return ::NAME(x); \
|
||||
} \
|
||||
}
|
||||
#else
|
||||
#define MLX_DEFINE_UNARY_OP(NAME, HALF_OP) \
|
||||
template <typename T> \
|
||||
__forceinline__ __device__ auto NAME(T x) { \
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) { \
|
||||
return HALF_OP(x); \
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) { \
|
||||
return HALF_OP(x); \
|
||||
} else { \
|
||||
return ::NAME(x); \
|
||||
} \
|
||||
}
|
||||
#endif
|
||||
|
||||
#define MLX_DEFINE_UNARY_OP_FALLBCK(NAME) \
|
||||
template <typename T> \
|
||||
__forceinline__ __device__ auto NAME(T x) { \
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) { \
|
||||
return ::NAME(__half2float(x)); \
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) { \
|
||||
return ::NAME(__bfloat162float(x)); \
|
||||
} else { \
|
||||
return ::NAME(x); \
|
||||
} \
|
||||
}
|
||||
|
||||
MLX_DEFINE_UNARY_OP(abs, __habs)
|
||||
MLX_DEFINE_UNARY_OP(ceil, hceil)
|
||||
MLX_DEFINE_UNARY_OP(cos, hcos)
|
||||
MLX_DEFINE_UNARY_OP(exp, hexp)
|
||||
MLX_DEFINE_UNARY_OP(floor, hfloor)
|
||||
MLX_DEFINE_UNARY_OP(isnan, __hisnan)
|
||||
MLX_DEFINE_UNARY_OP(log, hlog)
|
||||
MLX_DEFINE_UNARY_OP(log2, hlog2)
|
||||
MLX_DEFINE_UNARY_OP(log10, hlog10)
|
||||
MLX_DEFINE_UNARY_OP(rint, hrint)
|
||||
MLX_DEFINE_UNARY_OP(rsqrt, hrsqrt)
|
||||
MLX_DEFINE_UNARY_OP(sin, hsin)
|
||||
MLX_DEFINE_UNARY_OP(sqrt, hsqrt)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(acos)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(acosh)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(asin)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(asinh)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(atan)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(atanh)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(cosh)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(log1p)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(sinh)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(tan)
|
||||
#if __CUDA_ARCH__ >= 1280
|
||||
MLX_DEFINE_UNARY_OP(tanh, htanh)
|
||||
#else
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(tanh)
|
||||
#endif
|
||||
|
||||
#undef MLX_DEFINE_UNARY_OP
|
||||
#undef MLX_DEFINE_UNARY_OP_FALLBCK
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Binary ops for half types.
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
|
||||
#define MLX_DEFINE_BINARY_OP(NAME, HALF_OP) \
|
||||
template <typename T> \
|
||||
__forceinline__ __device__ auto NAME(T x, T y) { \
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) { \
|
||||
return HALF_OP(x, y); \
|
||||
} else { \
|
||||
return ::NAME(x, y); \
|
||||
} \
|
||||
}
|
||||
#else
|
||||
#define MLX_DEFINE_BINARY_OP(NAME, HALF_OP) \
|
||||
template <typename T> \
|
||||
__forceinline__ __device__ auto NAME(T x, T y) { \
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) { \
|
||||
return HALF_OP(x, y); \
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) { \
|
||||
return HALF_OP(x, y); \
|
||||
} else { \
|
||||
return ::NAME(x, y); \
|
||||
} \
|
||||
}
|
||||
#endif
|
||||
|
||||
MLX_DEFINE_BINARY_OP(max, __hmax)
|
||||
MLX_DEFINE_BINARY_OP(min, __hmin)
|
||||
|
||||
#undef MLX_DEFINE_BINARY_OP
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ T fmod(T x, T y) {
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) {
|
||||
return __float2half(::fmod(__half2float(x), __half2float(y)));
|
||||
#if CUDART_VERSION >= 12000 || __CUDA_ARCH__ >= 800
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
|
||||
return __float2bfloat16(::fmod(__bfloat162float(x), __bfloat162float(y)));
|
||||
#endif
|
||||
} else {
|
||||
return ::fmod(x, y);
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Additional C++ operator overrides between half types and native types.
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename U>
|
||||
constexpr bool is_integral_except =
|
||||
cuda::std::is_integral_v<T> && !cuda::std::is_same_v<T, U>;
|
||||
|
||||
template <typename T, typename U>
|
||||
constexpr bool is_arithmetic_except =
|
||||
cuda::std::is_arithmetic_v<T> && !cuda::std::is_same_v<T, U>;
|
||||
|
||||
#define MLX_DEFINE_HALF_OP(HALF, HALF2FLOAT, FLOAT2HALF, OP) \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_integral_except<T, HALF>>> \
|
||||
__forceinline__ __device__ HALF operator OP(HALF x, T y) { \
|
||||
return FLOAT2HALF(HALF2FLOAT(x) OP static_cast<float>(y)); \
|
||||
} \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_integral_except<T, HALF>>> \
|
||||
__forceinline__ __device__ HALF operator OP(T x, HALF y) { \
|
||||
return FLOAT2HALF(static_cast<float>(x) OP HALF2FLOAT(y)); \
|
||||
}
|
||||
|
||||
#define MLX_DEFINE_HALF_CMP(HALF, HALF2FLOAT, OP) \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_arithmetic_except<T, HALF>>> \
|
||||
__forceinline__ __device__ bool operator OP(HALF x, T y) { \
|
||||
return HALF2FLOAT(x) OP static_cast<float>(y); \
|
||||
} \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_arithmetic_except<T, HALF>>> \
|
||||
__forceinline__ __device__ bool operator OP(T x, HALF y) { \
|
||||
return static_cast<float>(y) OP HALF2FLOAT(x); \
|
||||
}
|
||||
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, +)
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, -)
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, *)
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, /)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, +)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, -)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, *)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, /)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, <)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, >)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, <=)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, >=)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, ==)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, !=)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <=)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >=)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, ==)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, !=)
|
||||
|
||||
#undef MLX_DEFINE_HALF_OP
|
||||
#undef MLX_DEFINE_HALF_CMP
|
||||
|
||||
} // namespace mlx::core::cu
|
53
mlx/backend/cuda/device/gather.cuh
Normal file
53
mlx/backend/cuda/device/gather.cuh
Normal file
@ -0,0 +1,53 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device/indexing.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename T, typename IdxT, int NIDX, int IDX_NDIM, typename LocT>
|
||||
__global__ void gather(
|
||||
const T* src,
|
||||
T* out,
|
||||
LocT size,
|
||||
const __grid_constant__ Shape src_shape,
|
||||
const __grid_constant__ Strides src_strides,
|
||||
int32_t src_ndim,
|
||||
const __grid_constant__ Shape slice_sizes,
|
||||
uint32_t slice_size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NIDX> axes,
|
||||
const __grid_constant__ cuda::std::array<IdxT*, NIDX> indices,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NIDX * IDX_NDIM>
|
||||
indices_shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NIDX * IDX_NDIM>
|
||||
indices_strides) {
|
||||
LocT out_idx = cg::this_grid().thread_rank();
|
||||
if (out_idx >= size) {
|
||||
return;
|
||||
}
|
||||
|
||||
LocT src_elem = out_idx % slice_size;
|
||||
LocT idx_elem = out_idx / slice_size;
|
||||
|
||||
LocT src_loc =
|
||||
elem_to_loc(src_elem, slice_sizes.data(), src_strides.data(), src_ndim);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
LocT idx_loc = elem_to_loc_nd<IDX_NDIM>(
|
||||
idx_elem,
|
||||
indices_shape.data() + i * IDX_NDIM,
|
||||
indices_strides.data() + i * IDX_NDIM);
|
||||
int32_t axis = axes[i];
|
||||
LocT idx_val = absolute_index(indices[i][idx_loc], src_shape[axis]);
|
||||
src_loc += idx_val * src_strides[axis];
|
||||
}
|
||||
|
||||
out[out_idx] = src[src_loc];
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
65
mlx/backend/cuda/device/gather_axis.cuh
Normal file
65
mlx/backend/cuda/device/gather_axis.cuh
Normal file
@ -0,0 +1,65 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device/indexing.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename IdxT,
|
||||
int NDIM,
|
||||
bool SrcC,
|
||||
bool IdxC,
|
||||
typename LocT>
|
||||
__global__ void gather_axis(
|
||||
const T* src,
|
||||
const IdxT* indices,
|
||||
T* out,
|
||||
LocT idx_size_pre,
|
||||
LocT idx_size_axis,
|
||||
LocT idx_size_post,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> src_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> idx_strides,
|
||||
int32_t axis,
|
||||
int32_t axis_size,
|
||||
int64_t src_stride_axis,
|
||||
int64_t idx_stride_axis) {
|
||||
LocT index = cg::this_grid().thread_rank();
|
||||
if (index >= idx_size_pre * idx_size_axis * idx_size_post) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto [x, y, z] = index_to_dims(index, idx_size_axis, idx_size_pre);
|
||||
|
||||
LocT elem_idx = z * idx_size_post;
|
||||
|
||||
LocT idx_loc = y * idx_stride_axis;
|
||||
if constexpr (IdxC) {
|
||||
idx_loc += elem_idx * idx_size_axis + x;
|
||||
} else {
|
||||
idx_loc +=
|
||||
elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), idx_strides.data());
|
||||
}
|
||||
|
||||
auto idx_val = absolute_index(indices[idx_loc], axis_size);
|
||||
|
||||
LocT src_loc = idx_val * src_stride_axis;
|
||||
if constexpr (SrcC) {
|
||||
src_loc += elem_idx * axis_size + x;
|
||||
} else {
|
||||
src_loc +=
|
||||
elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), src_strides.data());
|
||||
}
|
||||
|
||||
LocT out_idx = y * idx_size_post + elem_idx * idx_size_axis + x;
|
||||
|
||||
out[out_idx] = src[src_loc];
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
30
mlx/backend/cuda/device/indexing.cuh
Normal file
30
mlx/backend/cuda/device/indexing.cuh
Normal file
@ -0,0 +1,30 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <cuda/std/tuple>
|
||||
#include <cuda/std/type_traits>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// Convert an absolute index to positions in a 3d grid, assuming the index is
|
||||
// calculated with:
|
||||
// index = x * dim1 * dim2 + y * dim2 + z
|
||||
template <typename T>
|
||||
inline __host__ __device__ cuda::std::tuple<T, T, T>
|
||||
index_to_dims(T index, T dim1, T dim2) {
|
||||
T x = index / (dim1 * dim2);
|
||||
T y = (index % (dim1 * dim2)) / dim2;
|
||||
T z = index % dim2;
|
||||
return cuda::std::make_tuple(x, y, z);
|
||||
}
|
||||
|
||||
// Get absolute index from possible negative index.
|
||||
template <typename IdxT>
|
||||
inline __host__ __device__ auto absolute_index(IdxT idx, int32_t size) {
|
||||
if constexpr (cuda::std::is_unsigned_v<IdxT>) {
|
||||
return idx;
|
||||
} else {
|
||||
return static_cast<int32_t>(idx < 0 ? idx + size : idx);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
68
mlx/backend/cuda/device/scatter.cuh
Normal file
68
mlx/backend/cuda/device/scatter.cuh
Normal file
@ -0,0 +1,68 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device/indexing.cuh"
|
||||
#include "mlx/backend/cuda/device/scatter_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename IdxT,
|
||||
typename Op,
|
||||
int NIDX,
|
||||
int IDX_NDIM,
|
||||
typename LocT>
|
||||
__global__ void scatter(
|
||||
const T* upd,
|
||||
T* out,
|
||||
LocT size,
|
||||
const __grid_constant__ Shape upd_shape,
|
||||
const __grid_constant__ Strides upd_strides,
|
||||
int32_t upd_ndim,
|
||||
LocT upd_post_idx_size,
|
||||
const __grid_constant__ Shape out_shape,
|
||||
const __grid_constant__ Strides out_strides,
|
||||
int32_t out_ndim,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NIDX> axes,
|
||||
const __grid_constant__ cuda::std::array<IdxT*, NIDX> indices,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NIDX * IDX_NDIM>
|
||||
indices_shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NIDX * IDX_NDIM>
|
||||
indices_strides) {
|
||||
LocT upd_idx = cg::this_grid().thread_rank();
|
||||
if (upd_idx >= size) {
|
||||
return;
|
||||
}
|
||||
|
||||
LocT out_elem = upd_idx % upd_post_idx_size;
|
||||
LocT idx_elem = upd_idx / upd_post_idx_size;
|
||||
|
||||
LocT out_idx = elem_to_loc(
|
||||
out_elem, upd_shape.data() + IDX_NDIM, out_strides.data(), out_ndim);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
LocT idx_loc = elem_to_loc_nd<IDX_NDIM>(
|
||||
idx_elem,
|
||||
indices_shape.data() + i * IDX_NDIM,
|
||||
indices_strides.data() + i * IDX_NDIM);
|
||||
int32_t axis = axes[i];
|
||||
LocT idx_val = absolute_index(indices[i][idx_loc], out_shape[axis]);
|
||||
out_idx += idx_val * out_strides[axis];
|
||||
}
|
||||
|
||||
LocT upd_loc = elem_to_loc(
|
||||
out_elem + idx_elem * upd_post_idx_size,
|
||||
upd_shape.data(),
|
||||
upd_strides.data(),
|
||||
upd_ndim);
|
||||
|
||||
Op{}(out + out_idx, upd[upd_loc]);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
67
mlx/backend/cuda/device/scatter_axis.cuh
Normal file
67
mlx/backend/cuda/device/scatter_axis.cuh
Normal file
@ -0,0 +1,67 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device/indexing.cuh"
|
||||
#include "mlx/backend/cuda/device/scatter_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename IdxT,
|
||||
typename Op,
|
||||
int NDIM,
|
||||
bool UpdC,
|
||||
bool IdxC,
|
||||
typename LocT>
|
||||
__global__ void scatter_axis(
|
||||
const T* upd,
|
||||
const IdxT* indices,
|
||||
T* out,
|
||||
LocT idx_size_pre,
|
||||
LocT idx_size_axis,
|
||||
LocT idx_size_post,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> upd_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> idx_strides,
|
||||
int32_t axis,
|
||||
int32_t axis_size,
|
||||
int64_t upd_stride_axis,
|
||||
int64_t idx_stride_axis) {
|
||||
LocT index = cg::this_grid().thread_rank();
|
||||
if (index >= idx_size_pre * idx_size_axis * idx_size_post) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto [x, y, z] = index_to_dims(index, idx_size_axis, idx_size_pre);
|
||||
|
||||
LocT elem_idx = z * idx_size_post;
|
||||
|
||||
LocT idx_loc = y * idx_stride_axis;
|
||||
if constexpr (IdxC) {
|
||||
idx_loc += elem_idx * idx_size_axis + x;
|
||||
} else {
|
||||
idx_loc +=
|
||||
elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), idx_strides.data());
|
||||
}
|
||||
|
||||
auto idx_val = absolute_index(indices[idx_loc], axis_size);
|
||||
|
||||
LocT upd_loc = y * upd_stride_axis;
|
||||
if constexpr (UpdC) {
|
||||
upd_loc += elem_idx * idx_size_axis + x;
|
||||
} else {
|
||||
upd_loc +=
|
||||
elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), upd_strides.data());
|
||||
}
|
||||
|
||||
LocT out_idx = idx_val * idx_size_post + elem_idx * axis_size + x;
|
||||
|
||||
Op{}(out + out_idx, upd[upd_loc]);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
44
mlx/backend/cuda/device/scatter_ops.cuh
Normal file
44
mlx/backend/cuda/device/scatter_ops.cuh
Normal file
@ -0,0 +1,44 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device/atomic_ops.cuh"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
struct ScatterAssign {
|
||||
template <typename T>
|
||||
__device__ void operator()(T* out, T val) const {
|
||||
*out = val;
|
||||
}
|
||||
};
|
||||
|
||||
struct ScatterSum {
|
||||
template <typename T>
|
||||
__device__ void operator()(T* out, T val) const {
|
||||
atomic_add(out, val);
|
||||
}
|
||||
};
|
||||
|
||||
struct ScatterProd {
|
||||
template <typename T>
|
||||
__device__ void operator()(T* out, T val) const {
|
||||
atomic_prod(out, val);
|
||||
}
|
||||
};
|
||||
|
||||
struct ScatterMax {
|
||||
template <typename T>
|
||||
__device__ void operator()(T* out, T val) const {
|
||||
atomic_max(out, val);
|
||||
}
|
||||
};
|
||||
|
||||
struct ScatterMin {
|
||||
template <typename T>
|
||||
__device__ void operator()(T* out, T val) const {
|
||||
atomic_min(out, val);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
13
mlx/backend/cuda/device/ternary_ops.cuh
Normal file
13
mlx/backend/cuda/device/ternary_ops.cuh
Normal file
@ -0,0 +1,13 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
struct Select {
|
||||
template <typename T>
|
||||
__device__ T operator()(bool condition, T x, T y) {
|
||||
return condition ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
368
mlx/backend/cuda/device/unary_ops.cuh
Normal file
368
mlx/backend/cuda/device/unary_ops.cuh
Normal file
@ -0,0 +1,368 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <math_constants.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
struct Abs {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_unsigned_v<T>) {
|
||||
return x;
|
||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return {sqrt(cuCrealf(x) * cuCrealf(x) + cuCimagf(x) * cuCimagf(x)), 0};
|
||||
} else {
|
||||
return abs(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcCos {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return acos(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcCosh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return acosh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcSin {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return asin(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcSinh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return asinh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTan {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return atan(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTanh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return atanh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct BitwiseInvert {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return ~x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Ceil {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
return x;
|
||||
} else {
|
||||
return ceil(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Conjugate {
|
||||
__device__ cuComplex operator()(cuComplex x) {
|
||||
return {cuCrealf(x), -cuCimagf(x)};
|
||||
}
|
||||
};
|
||||
|
||||
struct Cos {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return {
|
||||
cos(cuCrealf(x)) * cosh(cuCimagf(x)),
|
||||
-sin(cuCrealf(x)) * sinh(cuCimagf(x))};
|
||||
} else {
|
||||
return cos(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Cosh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return {
|
||||
cosh(cuCrealf(x)) * cos(cuCimagf(x)),
|
||||
sinh(cuCrealf(x)) * sin(cuCimagf(x))};
|
||||
} else {
|
||||
return cosh(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Erf {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) {
|
||||
return erf(__half2float(x));
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
|
||||
return erf(__bfloat162float(x));
|
||||
} else {
|
||||
return erf(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct ErfInv {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) {
|
||||
return erfinv(__half2float(x));
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
|
||||
return erfinv(__bfloat162float(x));
|
||||
} else {
|
||||
return erfinv(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Exp {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
auto m = exp(cuCrealf(x));
|
||||
return {m * cos(cuCimagf(x)), m * sinh(cuCimagf(x))};
|
||||
} else {
|
||||
return exp(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Expm1 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) {
|
||||
return expm1(__half2float(x));
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
|
||||
return expm1(__bfloat162float(x));
|
||||
} else {
|
||||
return expm1(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Floor {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
return x;
|
||||
} else {
|
||||
return floor(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Imag {
|
||||
__device__ float operator()(cuComplex x) {
|
||||
return cuCimagf(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Log {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
auto r = log(cuCrealf(Abs{}(x)));
|
||||
auto i = atan2f(cuCimagf(x), cuCrealf(x));
|
||||
return {r, i};
|
||||
} else {
|
||||
return log(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Log2 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
auto y = Log{}(x);
|
||||
return {cuCrealf(y) / CUDART_LN2_F, cuCimagf(y) / CUDART_LN2_F};
|
||||
} else {
|
||||
return log2(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Log10 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
auto y = Log{}(x);
|
||||
return {cuCrealf(y) / CUDART_LNT_F, cuCimagf(y) / CUDART_LNT_F};
|
||||
return y;
|
||||
} else {
|
||||
return log10(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Log1p {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return log1p(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct LogicalNot {
|
||||
__device__ bool operator()(bool x) {
|
||||
return !x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Negative {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return 0 - x;
|
||||
} else {
|
||||
return -x;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Real {
|
||||
__device__ float operator()(cuComplex x) {
|
||||
return cuCrealf(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Round {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return {rint(cuCrealf(x)), rint(cuCimagf(x))};
|
||||
} else {
|
||||
return rint(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Rsqrt {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return rsqrt(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Sigmoid {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
T y = 1 / (1 + exp(-abs(x)));
|
||||
return (x < 0) ? 1 - y : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Sign {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_unsigned_v<T>) {
|
||||
return x != 0;
|
||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
if (cuCrealf(x) == 0 && cuCimagf(x) == 0) {
|
||||
return x;
|
||||
} else {
|
||||
return x / Abs()(x);
|
||||
}
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
|
||||
return static_cast<float>((x > T(0.f)) - (x < T(0.f)));
|
||||
} else {
|
||||
return (x > T(0)) - (x < T(0));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Sin {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return {
|
||||
sin(cuCrealf(x)) * cosh(cuCimagf(x)),
|
||||
cos(cuCrealf(x)) * sinh(cuCimagf(x))};
|
||||
} else {
|
||||
return sin(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Sinh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return {
|
||||
sinh(cuCrealf(x)) * cos(cuCimagf(x)),
|
||||
cosh(cuCrealf(x)) * sin(cuCimagf(x))};
|
||||
} else {
|
||||
return sinh(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Square {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return x * x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Sqrt {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return sqrt(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Tan {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
float tan_a = tan(cuCrealf(x));
|
||||
float tanh_b = tanh(cuCimagf(x));
|
||||
float t1 = tan_a * tanh_b;
|
||||
float denom = 1. + t1 * t1;
|
||||
return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
|
||||
} else {
|
||||
return tan(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Tanh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
float tanh_a = tanh(cuCrealf(x));
|
||||
float tan_b = tan(cuCimagf(x));
|
||||
float t1 = tanh_a * tan_b;
|
||||
float denom = 1. + t1 * t1;
|
||||
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
|
||||
} else {
|
||||
return tanh(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
358
mlx/backend/cuda/device/utils.cuh
Normal file
358
mlx/backend/cuda/device/utils.cuh
Normal file
@ -0,0 +1,358 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
// This file must not include any host-only code, utilies that work under both
|
||||
// host and device can be put here.
|
||||
//
|
||||
// See more about the requirements at:
|
||||
// https://docs.nvidia.com/cuda/nvrtc/#language
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device/config.h"
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda/std/array>
|
||||
#include <cuda/std/limits>
|
||||
#include <cuda/std/tuple>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CUDA kernel utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// To pass shape/strides to kernels via constant memory, their size must be
|
||||
// known at compile time.
|
||||
using Shape = cuda::std::array<int32_t, MAX_NDIM>;
|
||||
using Strides = cuda::std::array<int64_t, MAX_NDIM>;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Type limits utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename = void>
|
||||
struct Limits {
|
||||
static constexpr __host__ __device__ T max() {
|
||||
return cuda::std::numeric_limits<T>::max();
|
||||
}
|
||||
static constexpr __host__ __device__ T min() {
|
||||
return cuda::std::numeric_limits<T>::min();
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_max() {
|
||||
return cuda::std::numeric_limits<T>::max();
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_min() {
|
||||
return cuda::std::numeric_limits<T>::min();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Limits<
|
||||
T,
|
||||
cuda::std::enable_if_t<
|
||||
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double>>> {
|
||||
static constexpr __host__ __device__ T max() {
|
||||
return cuda::std::numeric_limits<T>::infinity();
|
||||
}
|
||||
static constexpr __host__ __device__ T min() {
|
||||
return -cuda::std::numeric_limits<T>::infinity();
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_max() {
|
||||
return cuda::std::numeric_limits<T>::max();
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_min() {
|
||||
return cuda::std::numeric_limits<T>::lowest();
|
||||
}
|
||||
};
|
||||
|
||||
// CUDA 11 does not have host side arithmatic operators for half types.
|
||||
template <typename T>
|
||||
struct Limits<
|
||||
T,
|
||||
cuda::std::enable_if_t<
|
||||
cuda::std::is_same_v<T, __half> ||
|
||||
cuda::std::is_same_v<T, __nv_bfloat16>>> {
|
||||
static constexpr __host__ __device__ T max() {
|
||||
return cuda::std::numeric_limits<T>::infinity();
|
||||
}
|
||||
static constexpr __host__ __device__ T min() {
|
||||
#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
|
||||
return -cuda::std::numeric_limits<T>::infinity();
|
||||
#else
|
||||
return -cuda::std::numeric_limits<float>::infinity();
|
||||
#endif
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_max() {
|
||||
return cuda::std::numeric_limits<T>::max();
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_min() {
|
||||
#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
|
||||
return cuda::std::numeric_limits<T>::lowest();
|
||||
#else
|
||||
return cuda::std::numeric_limits<float>::lowest();
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Limits<bool> {
|
||||
static constexpr __host__ __device__ bool max() {
|
||||
return true;
|
||||
}
|
||||
static constexpr __host__ __device__ bool min() {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Limits<cuComplex> {
|
||||
static constexpr __host__ __device__ cuComplex max() {
|
||||
return {Limits<float>::max(), Limits<float>::max()};
|
||||
}
|
||||
static constexpr __host__ __device__ cuComplex min() {
|
||||
return {Limits<float>::min(), Limits<float>::min()};
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Indexing utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename IdxT = int64_t>
|
||||
inline __host__ __device__ IdxT
|
||||
elem_to_loc(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
|
||||
IdxT loc = 0;
|
||||
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
||||
loc += (elem % shape[i]) * IdxT(strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
// Optimize when the ndim is known at compile time.
|
||||
template <int NDIM, typename IdxT = int64_t>
|
||||
inline __host__ __device__ IdxT
|
||||
elem_to_loc_nd(IdxT elem, const int* shape, const int64_t* strides) {
|
||||
IdxT loc = 0;
|
||||
#pragma unroll
|
||||
for (int i = NDIM - 1; i >= 0; --i) {
|
||||
loc += (elem % shape[i]) * IdxT(strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <int NDIM, typename IdxT = int64_t>
|
||||
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
|
||||
IdxT elem,
|
||||
const int* shape,
|
||||
const int64_t* a_strides,
|
||||
const int64_t* b_strides) {
|
||||
IdxT a_loc = 0;
|
||||
IdxT b_loc = 0;
|
||||
#pragma unroll
|
||||
for (int i = NDIM - 1; i >= 0; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return cuda::std::make_tuple(a_loc, b_loc);
|
||||
}
|
||||
|
||||
template <int NDIM, typename IdxT = int64_t>
|
||||
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
|
||||
IdxT elem,
|
||||
const int* shape,
|
||||
const int64_t* a_strides,
|
||||
const int64_t* b_strides,
|
||||
const int64_t* c_strides) {
|
||||
IdxT a_loc = 0;
|
||||
IdxT b_loc = 0;
|
||||
IdxT c_loc = 0;
|
||||
#pragma unroll
|
||||
for (int i = NDIM - 1; i >= 0; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||
c_loc += dim_idx * IdxT(c_strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
||||
}
|
||||
|
||||
// Optimized version when ndim is larger than 4.
|
||||
template <typename IdxT = int64_t>
|
||||
inline __host__ __device__ IdxT
|
||||
elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
|
||||
IdxT loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
loc += (elem % shape[i]) * IdxT(strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <typename IdxT = int64_t>
|
||||
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
|
||||
IdxT elem,
|
||||
const int* shape,
|
||||
const int64_t* a_strides,
|
||||
const int64_t* b_strides,
|
||||
int ndim) {
|
||||
IdxT a_loc = 0;
|
||||
IdxT b_loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return cuda::std::make_tuple(a_loc, b_loc);
|
||||
}
|
||||
|
||||
template <typename IdxT = int64_t>
|
||||
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
|
||||
IdxT elem,
|
||||
const int* shape,
|
||||
const int64_t* a_strides,
|
||||
const int64_t* b_strides,
|
||||
const int64_t* c_strides,
|
||||
int ndim) {
|
||||
IdxT a_loc = 0;
|
||||
IdxT b_loc = 0;
|
||||
IdxT c_loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||
c_loc += dim_idx * IdxT(c_strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Elem to loc in a loop utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int DIM, bool General = true, typename OffsetT = size_t>
|
||||
struct LoopedElemToLoc {
|
||||
int dim;
|
||||
LoopedElemToLoc<DIM - 1, General, OffsetT> inner_looper;
|
||||
OffsetT offset{0};
|
||||
int index{0};
|
||||
|
||||
__device__ LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}
|
||||
|
||||
__device__ void next(const int* shape, const int64_t* strides) {
|
||||
if (dim == 0) {
|
||||
return;
|
||||
}
|
||||
index++;
|
||||
offset += OffsetT(strides[dim - 1]);
|
||||
if (index >= shape[dim - 1]) {
|
||||
index = 0;
|
||||
inner_looper.next(shape, strides);
|
||||
offset = inner_looper.offset;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void next(int n, const int* shape, const int64_t* strides) {
|
||||
if (dim == 0) {
|
||||
return;
|
||||
}
|
||||
index += n;
|
||||
offset += n * OffsetT(strides[dim - 1]);
|
||||
|
||||
if (index >= shape[dim - 1]) {
|
||||
int extra = index - shape[dim - 1];
|
||||
if (extra >= shape[dim - 1]) {
|
||||
inner_looper.next(1 + extra / shape[dim - 1], shape, strides);
|
||||
extra = extra % shape[dim - 1];
|
||||
} else {
|
||||
inner_looper.next(shape, strides);
|
||||
}
|
||||
index = 0;
|
||||
offset = inner_looper.offset;
|
||||
if (extra > 0) {
|
||||
next(extra, shape, strides);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ OffsetT location() {
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OffsetT>
|
||||
struct LoopedElemToLoc<1, true, OffsetT> {
|
||||
int dim;
|
||||
OffsetT offset{0};
|
||||
int index{0};
|
||||
|
||||
__device__ LoopedElemToLoc(int dim) : dim(dim) {}
|
||||
|
||||
__device__ void next(const int* shape, const int64_t* strides) {
|
||||
index++;
|
||||
if (dim > 1) {
|
||||
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
|
||||
} else {
|
||||
offset += OffsetT(strides[0]);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void next(int n, const int* shape, const int64_t* strides) {
|
||||
index += n;
|
||||
if (dim > 1) {
|
||||
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
|
||||
} else {
|
||||
offset = index * OffsetT(strides[0]);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ OffsetT location() {
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OffsetT>
|
||||
struct LoopedElemToLoc<1, false, OffsetT> {
|
||||
OffsetT offset{0};
|
||||
|
||||
__device__ LoopedElemToLoc(int) {}
|
||||
|
||||
__device__ void next(const int*, const int64_t* strides) {
|
||||
offset += OffsetT(strides[0]);
|
||||
}
|
||||
|
||||
__device__ void next(int n, const int*, const int64_t* strides) {
|
||||
offset += n * OffsetT(strides[0]);
|
||||
}
|
||||
|
||||
__device__ OffsetT location() {
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
|
||||
inline __device__ cuComplex log1p(cuComplex in) {
|
||||
float x = cuCrealf(in);
|
||||
float y = cuCimagf(in);
|
||||
float zabs = sqrt(x * x + y * y);
|
||||
float theta = atan2f(y, x + 1);
|
||||
if (zabs < 0.5f) {
|
||||
float r = x * (2 + x) + y * y;
|
||||
if (r == 0) { // handle underflow
|
||||
return {x, theta};
|
||||
}
|
||||
return {0.5f * log1pf(r), theta};
|
||||
} else {
|
||||
auto z0 = sqrt((x + 1) * (x + 1) + y * y);
|
||||
return {log(z0), theta};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
@ -62,7 +62,7 @@ void finalize(Stream s) {
|
||||
|
||||
void synchronize(Stream s) {
|
||||
nvtx3::scoped_range r("gpu::synchronize");
|
||||
cu::get_stream(s).synchronize();
|
||||
cu::get_command_encoder(s).synchronize();
|
||||
}
|
||||
|
||||
} // namespace mlx::core::gpu
|
||||
|
@ -1,5 +1,6 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/event.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
@ -111,12 +112,12 @@ __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
|
||||
|
||||
SharedEvent::SharedEvent() {
|
||||
// Allocate cuda::atomic on managed memory.
|
||||
allocator::Buffer buffer = allocator::malloc(sizeof(Atomic));
|
||||
Atomic* ac = static_cast<Atomic*>(buffer.raw_ptr());
|
||||
Atomic* ac;
|
||||
CHECK_CUDA_ERROR(cudaMallocManaged(&ac, sizeof(Atomic)));
|
||||
new (ac) Atomic(0);
|
||||
ac_ = std::shared_ptr<Atomic>(ac, [buffer](Atomic* ptr) {
|
||||
ac_ = std::shared_ptr<Atomic>(ac, [](Atomic* ptr) {
|
||||
ptr->~Atomic();
|
||||
allocator::free(buffer);
|
||||
allocator().cuda_free(ptr);
|
||||
});
|
||||
}
|
||||
|
||||
@ -155,7 +156,10 @@ void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
|
||||
void SharedEvent::signal(Stream s, uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
scheduler::enqueue(s, [*this, value]() mutable { signal(value); });
|
||||
// Signal through a GPU stream so the atomic is updated in GPU - updating
|
||||
// the atomic in CPU sometimes does not get GPU notified.
|
||||
static CudaStream stream(device(mlx::core::Device::gpu));
|
||||
scheduler::enqueue(s, [*this, value]() mutable { signal(stream, value); });
|
||||
} else {
|
||||
auto& encoder = get_command_encoder(s);
|
||||
encoder.launch_kernel(
|
||||
|
29
mlx/backend/cuda/fence.cpp
Normal file
29
mlx/backend/cuda/fence.cpp
Normal file
@ -0,0 +1,29 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/fence.h"
|
||||
#include "mlx/backend/cuda/event.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
struct FenceImpl {
|
||||
uint32_t count;
|
||||
cu::SharedEvent event;
|
||||
};
|
||||
|
||||
Fence::Fence(Stream s) {
|
||||
fence_ = std::shared_ptr<void>(
|
||||
new FenceImpl{0}, [](void* ptr) { delete static_cast<FenceImpl*>(ptr); });
|
||||
}
|
||||
|
||||
void Fence::wait(Stream s, const array&) {
|
||||
auto* fence = static_cast<FenceImpl*>(fence_.get());
|
||||
fence->event.wait(fence->count);
|
||||
}
|
||||
|
||||
void Fence::update(Stream s, const array&) {
|
||||
auto* fence = static_cast<FenceImpl*>(fence_.get());
|
||||
fence->count++;
|
||||
fence->event.signal(s, fence->count);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -1,70 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/event.h"
|
||||
#include "mlx/fence.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
__host__ __device__ void busy_wait(cuda::atomic<uint64_t>* ac, uint64_t value) {
|
||||
while (true) {
|
||||
// In theory the atomic_thread_fence is not needed, but for CUDA 11 without
|
||||
// it the load() may never return new value.
|
||||
cuda::atomic_thread_fence(cuda::memory_order_seq_cst);
|
||||
uint64_t current = ac->load();
|
||||
if (current >= value) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void busy_wait_kernel(cuda::atomic<uint64_t>* ac, uint64_t value) {
|
||||
busy_wait(ac, value);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
struct FenceImpl {
|
||||
uint32_t count;
|
||||
cu::SharedEvent event;
|
||||
};
|
||||
|
||||
Fence::Fence(Stream s) {
|
||||
fence_ = std::shared_ptr<void>(
|
||||
new FenceImpl{0}, [](void* ptr) { delete static_cast<FenceImpl*>(ptr); });
|
||||
}
|
||||
|
||||
void Fence::wait(Stream s, const array&) {
|
||||
auto* fence = static_cast<FenceImpl*>(fence_.get());
|
||||
// We can't use SharedEvent::wait because it could hang in CUDA 11, see also:
|
||||
// https://github.com/ml-explore/mlx/issues/2137
|
||||
const auto& ac = fence->event.atomic();
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
scheduler::enqueue(s, [ac, count = fence->count]() {
|
||||
nvtx3::scoped_range r("Fence::wait()");
|
||||
busy_wait(ac.get(), count);
|
||||
});
|
||||
} else {
|
||||
nvtx3::scoped_range r("Fence::wait(s)");
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.launch_kernel(
|
||||
encoder.stream().last_cuda_stream(), [&](cudaStream_t stream) {
|
||||
busy_wait_kernel<<<1, 1, 0>>>(ac.get(), fence->count);
|
||||
});
|
||||
encoder.add_completed_handler([ac]() {});
|
||||
encoder.end_encoding();
|
||||
}
|
||||
}
|
||||
|
||||
void Fence::update(Stream s, const array&) {
|
||||
auto* fence = static_cast<FenceImpl*>(fence_.get());
|
||||
fence->count++;
|
||||
fence->event.signal(s, fence->count);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
420
mlx/backend/cuda/indexing.cpp
Normal file
420
mlx/backend/cuda/indexing.cpp
Normal file
@ -0,0 +1,420 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/jit_module.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include "cuda_jit_sources.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr const char* g_scatter_ops[] = {"Max", "Min", "Sum", "Prod", "Assign"};
|
||||
|
||||
void append_indices_arg(
|
||||
cu::JitModule& mod,
|
||||
const std::vector<array>& inputs,
|
||||
int nidx,
|
||||
int idx_ndim) {
|
||||
std::vector<const void*> indices(nidx);
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
indices[i] = inputs[i + 1].data<void>();
|
||||
}
|
||||
mod.append_arg(std::move(indices));
|
||||
std::vector<int32_t> indices_shape(nidx * idx_ndim);
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
std::copy_n(
|
||||
inputs[i + 1].shape().begin(),
|
||||
idx_ndim,
|
||||
indices_shape.data() + i * idx_ndim);
|
||||
}
|
||||
mod.append_arg(std::move(indices_shape));
|
||||
std::vector<int64_t> indices_strides(nidx * idx_ndim);
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
std::copy_n(
|
||||
inputs[i + 1].strides().begin(),
|
||||
idx_ndim,
|
||||
indices_strides.data() + i * idx_ndim);
|
||||
}
|
||||
mod.append_arg(std::move(indices_strides));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Gather::eval_gpu");
|
||||
assert(inputs.size() > 0);
|
||||
const auto& src = inputs[0];
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
int nidx = inputs.size() - 1;
|
||||
Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
|
||||
int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
|
||||
|
||||
bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) ||
|
||||
(src.size() > INT32_MAX) || (out.size() > INT32_MAX);
|
||||
|
||||
uint32_t slice_size = std::accumulate(
|
||||
slice_sizes_.begin(), slice_sizes_.end(), 1, std::multiplies<uint32_t>());
|
||||
|
||||
std::string module_name = fmt::format(
|
||||
"gather_{}_{}_{}",
|
||||
dtype_to_string(out.dtype()),
|
||||
dtype_to_string(idx_dtype),
|
||||
nidx);
|
||||
|
||||
auto& s = stream();
|
||||
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
|
||||
std::vector<std::string> kernel_names;
|
||||
for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
|
||||
for (int large = 0; large <= 1; ++large) {
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::gather<{}, {}, {}, {}, {}>",
|
||||
dtype_to_cuda_type(out.dtype()),
|
||||
dtype_to_cuda_type(idx_dtype),
|
||||
nidx,
|
||||
ndim,
|
||||
large ? "int64_t" : "int32_t"));
|
||||
}
|
||||
}
|
||||
return std::make_pair(jit_source_gather, std::move(kernel_names));
|
||||
});
|
||||
|
||||
mod.append_arg(src);
|
||||
mod.append_arg(out);
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(out.size());
|
||||
} else {
|
||||
mod.append_arg<int32_t>(out.size());
|
||||
}
|
||||
mod.append_ndim_arg(src.shape());
|
||||
mod.append_ndim_arg(src.strides());
|
||||
mod.append_arg<int32_t>(src.ndim());
|
||||
mod.append_ndim_arg(slice_sizes_);
|
||||
mod.append_arg(slice_size);
|
||||
mod.append_arg(axes_);
|
||||
append_indices_arg(mod, inputs, nidx, idx_ndim);
|
||||
|
||||
std::string kernel_name = fmt::format(
|
||||
"mlx::core::cu::gather<{}, {}, {}, {}, {}>",
|
||||
dtype_to_cuda_type(out.dtype()),
|
||||
dtype_to_cuda_type(idx_dtype),
|
||||
nidx,
|
||||
idx_ndim,
|
||||
large ? "int64_t" : "int32_t");
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
encoder.set_input_array(in);
|
||||
}
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
mod.launch_kernel(stream, kernel_name, out, large);
|
||||
});
|
||||
}
|
||||
|
||||
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Gather::eval_gpu");
|
||||
assert(inputs.size() > 1);
|
||||
auto& upd = inputs.back();
|
||||
|
||||
// Copy src into out.
|
||||
CopyType copy_type;
|
||||
if (inputs[0].data_size() == 1) {
|
||||
copy_type = CopyType::Scalar;
|
||||
} else if (inputs[0].flags().row_contiguous) {
|
||||
copy_type = CopyType::Vector;
|
||||
} else {
|
||||
copy_type = CopyType::General;
|
||||
}
|
||||
copy_gpu(inputs[0], out, copy_type);
|
||||
|
||||
// Empty update.
|
||||
if (upd.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
int nidx = axes_.size();
|
||||
Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
|
||||
int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
|
||||
|
||||
bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) ||
|
||||
(upd.size() > INT32_MAX) || (out.size() > INT32_MAX);
|
||||
|
||||
int32_t upd_post_idx_size = std::accumulate(
|
||||
upd.shape().begin() + idx_ndim,
|
||||
upd.shape().end(),
|
||||
1,
|
||||
std::multiplies<int32_t>());
|
||||
|
||||
const char* op = g_scatter_ops[reduce_type_];
|
||||
std::string module_name = fmt::format(
|
||||
"scatter_{}_{}_{}_{}",
|
||||
dtype_to_string(out.dtype()),
|
||||
dtype_to_string(idx_dtype),
|
||||
op,
|
||||
nidx);
|
||||
|
||||
auto& s = stream();
|
||||
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
|
||||
std::vector<std::string> kernel_names;
|
||||
for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
|
||||
for (int large = 0; large <= 1; ++large) {
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>",
|
||||
dtype_to_cuda_type(out.dtype()),
|
||||
dtype_to_cuda_type(idx_dtype),
|
||||
op,
|
||||
nidx,
|
||||
ndim,
|
||||
large ? "int64_t" : "int32_t"));
|
||||
}
|
||||
}
|
||||
return std::make_pair(jit_source_scatter, std::move(kernel_names));
|
||||
});
|
||||
|
||||
mod.append_arg(upd);
|
||||
mod.append_arg(out);
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(upd.size());
|
||||
} else {
|
||||
mod.append_arg<int32_t>(upd.size());
|
||||
}
|
||||
mod.append_ndim_arg(upd.shape());
|
||||
mod.append_ndim_arg(upd.strides());
|
||||
mod.append_arg<int32_t>(upd.ndim());
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(upd_post_idx_size);
|
||||
} else {
|
||||
mod.append_arg<int32_t>(upd_post_idx_size);
|
||||
}
|
||||
mod.append_ndim_arg(out.shape());
|
||||
mod.append_ndim_arg(out.strides());
|
||||
mod.append_arg<int32_t>(out.ndim());
|
||||
mod.append_arg(axes_);
|
||||
append_indices_arg(mod, inputs, nidx, idx_ndim);
|
||||
|
||||
std::string kernel_name = fmt::format(
|
||||
"mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>",
|
||||
dtype_to_cuda_type(out.dtype()),
|
||||
dtype_to_cuda_type(idx_dtype),
|
||||
op,
|
||||
nidx,
|
||||
idx_ndim,
|
||||
large ? "int64_t" : "int32_t");
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
encoder.set_input_array(in);
|
||||
}
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
mod.launch_kernel(stream, kernel_name, upd, large);
|
||||
});
|
||||
}
|
||||
|
||||
void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("GatherAxis::eval_gpu");
|
||||
assert(inputs.size() > 1);
|
||||
const auto& src = inputs[0];
|
||||
const auto& idx = inputs[1];
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
|
||||
|
||||
std::string module_name = fmt::format(
|
||||
"gather_axis_{}_{}",
|
||||
dtype_to_string(out.dtype()),
|
||||
dtype_to_string(idx.dtype()));
|
||||
|
||||
auto& s = stream();
|
||||
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
|
||||
std::vector<std::string> kernel_names;
|
||||
for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
|
||||
for (int contiguous = 0; contiguous < 4; ++contiguous) {
|
||||
for (int large = 0; large <= 1; ++large) {
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>",
|
||||
dtype_to_cuda_type(out.dtype()),
|
||||
dtype_to_cuda_type(idx.dtype()),
|
||||
ndim,
|
||||
contiguous & 1 ? true : false,
|
||||
contiguous & 2 ? true : false,
|
||||
large ? "int64_t" : "int32_t"));
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::make_pair(jit_source_gather_axis, std::move(kernel_names));
|
||||
});
|
||||
|
||||
size_t idx_size_pre = 1;
|
||||
size_t idx_size_post = 1;
|
||||
for (int i = 0; i < axis_; ++i) {
|
||||
idx_size_pre *= idx.shape(i);
|
||||
}
|
||||
for (int i = axis_ + 1; i < idx.ndim(); ++i) {
|
||||
idx_size_post *= idx.shape(i);
|
||||
}
|
||||
size_t idx_size_axis = idx.shape(axis_);
|
||||
|
||||
mod.append_arg(src);
|
||||
mod.append_arg(idx);
|
||||
mod.append_arg(out);
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(idx_size_pre);
|
||||
mod.append_arg<int64_t>(idx_size_axis);
|
||||
mod.append_arg<int64_t>(idx_size_post);
|
||||
} else {
|
||||
mod.append_arg<int32_t>(idx_size_pre);
|
||||
mod.append_arg<int32_t>(idx_size_axis);
|
||||
mod.append_arg<int32_t>(idx_size_post);
|
||||
}
|
||||
mod.append_arg(remove_index(idx.shape(), axis_));
|
||||
mod.append_arg(remove_index(src.strides(), axis_));
|
||||
mod.append_arg(remove_index(idx.strides(), axis_));
|
||||
mod.append_arg<int32_t>(axis_);
|
||||
mod.append_arg(src.shape(axis_));
|
||||
mod.append_arg(src.strides(axis_));
|
||||
mod.append_arg(idx.strides(axis_));
|
||||
|
||||
std::string kernel_name = fmt::format(
|
||||
"mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>",
|
||||
dtype_to_cuda_type(out.dtype()),
|
||||
dtype_to_cuda_type(idx.dtype()),
|
||||
src.ndim() - 1,
|
||||
src.flags().row_contiguous,
|
||||
idx.flags().row_contiguous,
|
||||
large ? "int64_t" : "int32_t");
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
encoder.set_input_array(in);
|
||||
}
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
mod.launch_kernel(stream, kernel_name, idx, large);
|
||||
});
|
||||
}
|
||||
|
||||
void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("ScatterAxis::eval_gpu");
|
||||
assert(inputs.size() > 2);
|
||||
const auto& src = inputs[0];
|
||||
const auto& idx = inputs[1];
|
||||
const auto& upd = inputs[2];
|
||||
|
||||
// Copy src into out.
|
||||
CopyType copy_type;
|
||||
if (src.data_size() == 1) {
|
||||
copy_type = CopyType::Scalar;
|
||||
} else if (src.flags().row_contiguous) {
|
||||
copy_type = CopyType::Vector;
|
||||
} else {
|
||||
copy_type = CopyType::General;
|
||||
}
|
||||
copy_gpu(src, out, copy_type);
|
||||
|
||||
// Empty update.
|
||||
if (upd.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
|
||||
|
||||
const char* op = reduce_type_ == ScatterAxis::Sum ? "Sum" : "Assign";
|
||||
std::string module_name = fmt::format(
|
||||
"scatter_axis_{}_{}_{}",
|
||||
dtype_to_string(out.dtype()),
|
||||
dtype_to_string(idx.dtype()),
|
||||
op);
|
||||
|
||||
auto& s = stream();
|
||||
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
|
||||
std::vector<std::string> kernel_names;
|
||||
for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
|
||||
for (int contiguous = 0; contiguous < 4; ++contiguous) {
|
||||
for (int large = 0; large <= 1; ++large) {
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>",
|
||||
dtype_to_cuda_type(out.dtype()),
|
||||
dtype_to_cuda_type(idx.dtype()),
|
||||
op,
|
||||
ndim,
|
||||
contiguous & 1 ? true : false,
|
||||
contiguous & 2 ? true : false,
|
||||
large ? "int64_t" : "int32_t"));
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::make_pair(jit_source_scatter_axis, std::move(kernel_names));
|
||||
});
|
||||
|
||||
size_t idx_size_pre = 1;
|
||||
size_t idx_size_post = 1;
|
||||
for (int i = 0; i < axis_; ++i) {
|
||||
idx_size_pre *= idx.shape(i);
|
||||
}
|
||||
for (int i = axis_ + 1; i < idx.ndim(); ++i) {
|
||||
idx_size_post *= idx.shape(i);
|
||||
}
|
||||
size_t idx_size_axis = idx.shape(axis_);
|
||||
|
||||
mod.append_arg(upd);
|
||||
mod.append_arg(idx);
|
||||
mod.append_arg(out);
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(idx_size_pre);
|
||||
mod.append_arg<int64_t>(idx_size_axis);
|
||||
mod.append_arg<int64_t>(idx_size_post);
|
||||
} else {
|
||||
mod.append_arg<int32_t>(idx_size_pre);
|
||||
mod.append_arg<int32_t>(idx_size_axis);
|
||||
mod.append_arg<int32_t>(idx_size_post);
|
||||
}
|
||||
mod.append_arg(remove_index(idx.shape(), axis_));
|
||||
mod.append_arg(remove_index(upd.strides(), axis_));
|
||||
mod.append_arg(remove_index(idx.strides(), axis_));
|
||||
mod.append_arg<int32_t>(axis_);
|
||||
mod.append_arg(out.shape(axis_));
|
||||
mod.append_arg(upd.strides(axis_));
|
||||
mod.append_arg(idx.strides(axis_));
|
||||
|
||||
std::string kernel_name = fmt::format(
|
||||
"mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>",
|
||||
dtype_to_cuda_type(out.dtype()),
|
||||
dtype_to_cuda_type(idx.dtype()),
|
||||
op,
|
||||
idx.ndim() - 1,
|
||||
upd.flags().row_contiguous,
|
||||
idx.flags().row_contiguous,
|
||||
large ? "int64_t" : "int32_t");
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
encoder.set_input_array(in);
|
||||
}
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
mod.launch_kernel(stream, kernel_name, idx, large);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
121
mlx/backend/cuda/iterators/general_iterator.cuh
Normal file
121
mlx/backend/cuda/iterators/general_iterator.cuh
Normal file
@ -0,0 +1,121 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <thrust/iterator/iterator_adaptor.h>
|
||||
#include <cuda/std/utility>
|
||||
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// Iterating non-contiguous array.
|
||||
template <typename Iterator, typename IdxT = int64_t>
|
||||
class general_iterator
|
||||
: public thrust::
|
||||
iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator> {
|
||||
public:
|
||||
using super_t =
|
||||
thrust::iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator>;
|
||||
|
||||
using reference = typename super_t::reference;
|
||||
using difference_type = typename super_t::difference_type;
|
||||
|
||||
__host__ __device__ general_iterator(
|
||||
Iterator it,
|
||||
IdxT index,
|
||||
int ndim,
|
||||
Shape shape,
|
||||
Strides strides)
|
||||
: super_t(it),
|
||||
index_(index),
|
||||
ndim_(ndim),
|
||||
shape_(cuda::std::move(shape)),
|
||||
strides_(cuda::std::move(strides)) {}
|
||||
|
||||
__host__ __device__ IdxT index() const {
|
||||
return index_;
|
||||
}
|
||||
|
||||
__host__ __device__ const Shape& shape() const {
|
||||
return shape_;
|
||||
}
|
||||
|
||||
__host__ __device__ const Strides& strides() const {
|
||||
return strides_;
|
||||
}
|
||||
|
||||
private:
|
||||
friend class thrust::iterator_core_access;
|
||||
|
||||
__host__ __device__ bool equal(const general_iterator& other) const {
|
||||
return this->base() == other.base() && this->index() == other.index();
|
||||
}
|
||||
|
||||
__host__ __device__ void advance(difference_type n) {
|
||||
this->index_ += n;
|
||||
}
|
||||
|
||||
__host__ __device__ void increment() {
|
||||
this->index_ += 1;
|
||||
}
|
||||
|
||||
__host__ __device__ void decrement() {
|
||||
this->index_ -= 1;
|
||||
}
|
||||
|
||||
__host__ __device__ difference_type
|
||||
distance_to(const general_iterator& other) const {
|
||||
_CCCL_ASSERT(
|
||||
this->base() == other.base(),
|
||||
"Underlying iterator must point to same base iterator");
|
||||
return other.index() - this->index();
|
||||
}
|
||||
|
||||
// The dereference is device-only to avoid accidental running in host.
|
||||
__device__ typename super_t::reference dereference() const {
|
||||
IdxT offset = elem_to_loc(index_, shape_.data(), strides_.data(), ndim_);
|
||||
return *(this->base() + offset);
|
||||
}
|
||||
|
||||
IdxT index_;
|
||||
int ndim_;
|
||||
Shape shape_;
|
||||
Strides strides_;
|
||||
};
|
||||
|
||||
template <typename IdxT, typename Iterator>
|
||||
__host__ __device__ auto make_general_iterator(
|
||||
Iterator it,
|
||||
IdxT index,
|
||||
int ndim,
|
||||
Shape shape,
|
||||
Strides strides) {
|
||||
return general_iterator<Iterator, IdxT>(
|
||||
it, index, ndim, cuda::std::move(shape), cuda::std::move(strides));
|
||||
}
|
||||
|
||||
template <typename IdxT, typename Iterator>
|
||||
auto make_general_iterator(
|
||||
Iterator it,
|
||||
const std::vector<int32_t>& shape,
|
||||
const std::vector<int64_t>& strides) {
|
||||
return make_general_iterator<IdxT>(
|
||||
it, 0, shape.size(), const_param(shape), const_param(strides));
|
||||
}
|
||||
|
||||
template <typename IdxT, typename Iterator>
|
||||
auto make_general_iterators(
|
||||
Iterator it,
|
||||
IdxT size,
|
||||
const std::vector<int32_t>& shape,
|
||||
const std::vector<int64_t>& strides) {
|
||||
auto ndim = shape.size();
|
||||
auto shape_arg = const_param(shape);
|
||||
auto strides_arg = const_param(strides);
|
||||
return std::make_pair(
|
||||
make_general_iterator<IdxT>(it, 0, ndim, shape_arg, strides_arg),
|
||||
make_general_iterator<IdxT>(it, size, ndim, shape_arg, strides_arg));
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
60
mlx/backend/cuda/iterators/strided_iterator.cuh
Normal file
60
mlx/backend/cuda/iterators/strided_iterator.cuh
Normal file
@ -0,0 +1,60 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <thrust/iterator/iterator_adaptor.h>
|
||||
#include <thrust/iterator/iterator_facade.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// RandomAccessIterator for strided access to array entries.
|
||||
template <typename Iterator, typename Stride = int64_t>
|
||||
class strided_iterator
|
||||
: public thrust::
|
||||
iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator> {
|
||||
public:
|
||||
using super_t =
|
||||
thrust::iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator>;
|
||||
|
||||
using reference = typename super_t::reference;
|
||||
using difference_type = typename super_t::difference_type;
|
||||
|
||||
__host__ __device__ strided_iterator(Iterator it, Stride stride)
|
||||
: super_t(it), stride_(stride) {}
|
||||
|
||||
__host__ __device__ Stride stride() const {
|
||||
return stride_;
|
||||
}
|
||||
|
||||
private:
|
||||
friend class thrust::iterator_core_access;
|
||||
|
||||
__host__ __device__ bool equal(const strided_iterator& other) const {
|
||||
return this->base() == other.base();
|
||||
}
|
||||
|
||||
__host__ __device__ void advance(difference_type n) {
|
||||
this->base_reference() += n * stride_;
|
||||
}
|
||||
|
||||
__host__ __device__ void increment() {
|
||||
this->base_reference() += stride_;
|
||||
}
|
||||
|
||||
__host__ __device__ void decrement() {
|
||||
this->base_reference() -= stride_;
|
||||
}
|
||||
|
||||
__host__ __device__ difference_type
|
||||
distance_to(const strided_iterator& other) const {
|
||||
const difference_type dist = other.base() - this->base();
|
||||
_CCCL_ASSERT(
|
||||
dist % stride() == 0,
|
||||
"Underlying iterator difference must be divisible by the stride");
|
||||
return dist / stride();
|
||||
}
|
||||
|
||||
Stride stride_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
364
mlx/backend/cuda/jit_module.cpp
Normal file
364
mlx/backend/cuda/jit_module.cpp
Normal file
@ -0,0 +1,364 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/jit_module.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
|
||||
#include "cuda_jit_sources.h"
|
||||
|
||||
#include <cstdlib>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <unordered_map>
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <nvrtc.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
namespace {
|
||||
|
||||
#define CHECK_NVRTC_ERROR(cmd) check_nvrtc_error(#cmd, (cmd))
|
||||
|
||||
void check_nvrtc_error(const char* name, nvrtcResult err) {
|
||||
if (err != NVRTC_SUCCESS) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("{} failed: {}", name, nvrtcGetErrorString(err)));
|
||||
}
|
||||
}
|
||||
|
||||
#define CHECK_CU_ERROR(cmd) check_cu_error(#cmd, (cmd))
|
||||
|
||||
void check_cu_error(const char* name, CUresult err) {
|
||||
if (err != CUDA_SUCCESS) {
|
||||
const char* err_str = "Unknown error";
|
||||
cuGetErrorString(err, &err_str);
|
||||
throw std::runtime_error(fmt::format("{} failed: {}", name, err_str));
|
||||
}
|
||||
}
|
||||
|
||||
// Return the location of the CUDA toolkit.
|
||||
const std::string& cuda_home() {
|
||||
static std::string home = []() -> std::string {
|
||||
const char* home = std::getenv("CUDA_HOME");
|
||||
if (home) {
|
||||
return home;
|
||||
}
|
||||
home = std::getenv("CUDA_PATH");
|
||||
if (home) {
|
||||
return home;
|
||||
}
|
||||
#if defined(__linux__)
|
||||
home = "/usr/local/cuda";
|
||||
if (std::filesystem::exists(home)) {
|
||||
return home;
|
||||
}
|
||||
#endif
|
||||
throw std::runtime_error(
|
||||
"Environment variable CUDA_HOME or CUDA_PATH is not set.");
|
||||
}();
|
||||
return home;
|
||||
}
|
||||
|
||||
// Get the cache directory for storing compiled results.
|
||||
const std::filesystem::path& ptx_cache_dir() {
|
||||
static std::filesystem::path cache = []() -> std::filesystem::path {
|
||||
std::filesystem::path cache;
|
||||
if (auto c = std::getenv("MLX_PTX_CACHE"); c) {
|
||||
cache = c;
|
||||
} else {
|
||||
cache = std::filesystem::temp_directory_path() / "mlx" / "ptx";
|
||||
}
|
||||
if (!std::filesystem::exists(cache)) {
|
||||
std::error_code error;
|
||||
if (!std::filesystem::create_directories(cache, error)) {
|
||||
return std::filesystem::path();
|
||||
}
|
||||
}
|
||||
return cache;
|
||||
}();
|
||||
return cache;
|
||||
}
|
||||
|
||||
// Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.
|
||||
bool read_cached_ptx(
|
||||
const std::filesystem::path& cache_dir,
|
||||
const std::string& module_name,
|
||||
std::vector<char>* ptx,
|
||||
std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
|
||||
if (cache_dir.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto ptx_path = cache_dir / (module_name + ".ptx");
|
||||
std::error_code error;
|
||||
auto ptx_size = std::filesystem::file_size(ptx_path, error);
|
||||
if (error) {
|
||||
return false;
|
||||
}
|
||||
std::ifstream ptx_file(ptx_path, std::ios::binary);
|
||||
if (!ptx_file.good()) {
|
||||
return false;
|
||||
}
|
||||
ptx->resize(ptx_size);
|
||||
ptx_file.read(ptx->data(), ptx_size);
|
||||
|
||||
std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
|
||||
std::string line;
|
||||
while (std::getline(txt_file, line)) {
|
||||
auto tab = line.find('\t');
|
||||
if (tab != std::string::npos) {
|
||||
ptx_kernels->emplace_back(line.substr(0, tab), line.substr(tab + 1));
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Write the |ptx| and |ptx_kernels| to |cache_dir| with |name|.
|
||||
void write_cached_ptx(
|
||||
const std::filesystem::path& cache_dir,
|
||||
const std::string& module_name,
|
||||
const std::vector<char>& ptx,
|
||||
const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
|
||||
if (cache_dir.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary);
|
||||
if (!ptx.empty()) {
|
||||
ptx_file.write(&ptx.front(), ptx.size());
|
||||
}
|
||||
std::ofstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
|
||||
for (const auto& [name, mangled] : ptx_kernels) {
|
||||
txt_file << name << "\t" << mangled << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// Return if |device|'s version is not newer than |major|.|minor| version.
|
||||
inline bool version_lower_equal(Device& device, int major, int minor) {
|
||||
if (device.compute_capability_major() < major) {
|
||||
return true;
|
||||
} else if (device.compute_capability_major() == major) {
|
||||
return device.compute_capability_minor() <= minor;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Return whether NVRTC supports compiling to |device|'s SASS code.
|
||||
bool compiler_supports_device_sass(Device& device) {
|
||||
int nvrtc_major, nvrtc_minor;
|
||||
CHECK_NVRTC_ERROR(nvrtcVersion(&nvrtc_major, &nvrtc_minor));
|
||||
if (nvrtc_major < 9) {
|
||||
return false;
|
||||
} else if (nvrtc_major == 9) {
|
||||
return version_lower_equal(device, 7, 2);
|
||||
} else if (nvrtc_major == 10) {
|
||||
return version_lower_equal(device, 7, 5);
|
||||
} else if (nvrtc_major == 11 && nvrtc_minor == 0) {
|
||||
return version_lower_equal(device, 8, 0);
|
||||
} else if (nvrtc_major == 11 && nvrtc_minor < 8) {
|
||||
return version_lower_equal(device, 8, 6);
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
#define INCLUDE_PREFIX "mlx/backend/cuda/device/"
|
||||
|
||||
constexpr const char* g_include_names[] = {
|
||||
INCLUDE_PREFIX "atomic_ops.cuh",
|
||||
INCLUDE_PREFIX "binary_ops.cuh",
|
||||
INCLUDE_PREFIX "cast_op.cuh",
|
||||
INCLUDE_PREFIX "config.h",
|
||||
INCLUDE_PREFIX "cucomplex_math.cuh",
|
||||
INCLUDE_PREFIX "fp16_math.cuh",
|
||||
INCLUDE_PREFIX "indexing.cuh",
|
||||
INCLUDE_PREFIX "scatter_ops.cuh",
|
||||
INCLUDE_PREFIX "unary_ops.cuh",
|
||||
INCLUDE_PREFIX "ternary_ops.cuh",
|
||||
INCLUDE_PREFIX "utils.cuh",
|
||||
};
|
||||
|
||||
#undef INCLUDE_PREFIX
|
||||
|
||||
constexpr const char* g_headers[] = {
|
||||
jit_source_atomic_ops,
|
||||
jit_source_binary_ops,
|
||||
jit_source_cast_op,
|
||||
jit_source_config,
|
||||
jit_source_cucomplex_math,
|
||||
jit_source_fp16_math,
|
||||
jit_source_indexing,
|
||||
jit_source_scatter_ops,
|
||||
jit_source_unary_ops,
|
||||
jit_source_ternary_ops,
|
||||
jit_source_utils,
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
JitModule::JitModule(
|
||||
Device& device,
|
||||
const std::string& module_name,
|
||||
const KernelBuilder& builder) {
|
||||
// Check cache.
|
||||
std::vector<char> ptx;
|
||||
std::vector<std::pair<std::string, std::string>> ptx_kernels;
|
||||
if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
|
||||
// Create program.
|
||||
auto [source_code, kernel_names] = builder();
|
||||
nvrtcProgram prog;
|
||||
CHECK_NVRTC_ERROR(nvrtcCreateProgram(
|
||||
&prog,
|
||||
source_code.c_str(),
|
||||
(module_name + ".cu").c_str(),
|
||||
std::size(g_headers),
|
||||
g_headers,
|
||||
g_include_names));
|
||||
std::unique_ptr<nvrtcProgram, void (*)(nvrtcProgram*)> prog_freer(
|
||||
&prog,
|
||||
[](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); });
|
||||
for (const auto& name : kernel_names) {
|
||||
CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str()));
|
||||
}
|
||||
|
||||
// Compile program.
|
||||
bool use_sass = compiler_supports_device_sass(device);
|
||||
std::string compute = fmt::format(
|
||||
"--gpu-architecture={}_{}{}",
|
||||
use_sass ? "sm" : "compute",
|
||||
device.compute_capability_major(),
|
||||
device.compute_capability_minor());
|
||||
std::string include = fmt::format("--include-path={}/include", cuda_home());
|
||||
const char* args[] = {compute.c_str(), include.c_str()};
|
||||
nvrtcResult compile_result =
|
||||
nvrtcCompileProgram(prog, std::size(args), args);
|
||||
if (compile_result != NVRTC_SUCCESS) {
|
||||
size_t log_size;
|
||||
CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));
|
||||
std::vector<char> log(log_size + 1, 0);
|
||||
CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data()));
|
||||
throw std::runtime_error(
|
||||
fmt::format("Failed to compile kernel: {}.", log.data()));
|
||||
}
|
||||
|
||||
// Get mangled names of kernel names.
|
||||
for (const auto& name : kernel_names) {
|
||||
const char* mangled;
|
||||
CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled));
|
||||
ptx_kernels.emplace_back(name, mangled);
|
||||
}
|
||||
|
||||
// Get ptx data.
|
||||
size_t ptx_size;
|
||||
if (use_sass) {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size));
|
||||
} else {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
|
||||
}
|
||||
ptx.resize(ptx_size, 0);
|
||||
if (use_sass) {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
|
||||
} else {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
|
||||
}
|
||||
write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels);
|
||||
}
|
||||
|
||||
// Load module.
|
||||
char jit_log[4089] = {};
|
||||
CUjit_option options[] = {
|
||||
CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES};
|
||||
void* values[] = {jit_log, reinterpret_cast<void*>(std::size(jit_log) - 1)};
|
||||
CUresult jit_result = cuModuleLoadDataEx(
|
||||
&module_, ptx.data(), std::size(options), options, values);
|
||||
if (jit_result != CUDA_SUCCESS) {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Failed to load compiled {} kernel: {}.", module_name, jit_log));
|
||||
}
|
||||
|
||||
// Load kernels.
|
||||
for (const auto& [name, mangled] : ptx_kernels) {
|
||||
CUfunction kernel;
|
||||
CHECK_CU_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
|
||||
kernels_[name] = kernel;
|
||||
}
|
||||
}
|
||||
|
||||
JitModule::~JitModule() {
|
||||
CHECK_CU_ERROR(cuModuleUnload(module_));
|
||||
}
|
||||
|
||||
void JitModule::launch_kernel(
|
||||
CUstream stream,
|
||||
const std::string& kernel_name,
|
||||
const array& arr,
|
||||
bool large,
|
||||
int work_per_thread) {
|
||||
CUfunction kernel = get_kernel(kernel_name);
|
||||
size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread);
|
||||
int _, block_dim;
|
||||
CHECK_CU_ERROR(
|
||||
cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
|
||||
if (block_dim > nthreads) {
|
||||
block_dim = nthreads;
|
||||
}
|
||||
Dims num_blocks{1, 1, 1};
|
||||
if (large) {
|
||||
num_blocks =
|
||||
get_2d_grid_dims_common(arr.shape(), arr.strides(), work_per_thread);
|
||||
std::get<0>(num_blocks) =
|
||||
(std::get<0>(num_blocks) + block_dim - 1) / block_dim;
|
||||
} else {
|
||||
std::get<0>(num_blocks) = (nthreads + block_dim - 1) / block_dim;
|
||||
}
|
||||
launch_kernel(stream, kernel, num_blocks, Dims{block_dim, 1, 1});
|
||||
}
|
||||
|
||||
void JitModule::launch_kernel(
|
||||
CUstream stream,
|
||||
CUfunction kernel,
|
||||
Dims num_blocks,
|
||||
Dims block_dims) {
|
||||
CHECK_CU_ERROR(cuLaunchKernel(
|
||||
kernel,
|
||||
std::get<0>(num_blocks),
|
||||
std::get<1>(num_blocks),
|
||||
std::get<2>(num_blocks),
|
||||
std::get<0>(block_dims),
|
||||
std::get<1>(block_dims),
|
||||
std::get<2>(block_dims),
|
||||
0,
|
||||
stream,
|
||||
args_.data(),
|
||||
nullptr));
|
||||
args_.clear();
|
||||
storage_.clear();
|
||||
}
|
||||
|
||||
CUfunction JitModule::get_kernel(const std::string& kernel_name) {
|
||||
auto it = kernels_.find(kernel_name);
|
||||
if (it == kernels_.end()) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("There is no kernel named {}.", kernel_name));
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
void JitModule::append_ptr_arg(const void* v) {
|
||||
args_.push_back(const_cast<void*>(v));
|
||||
}
|
||||
|
||||
JitModule& get_jit_module(
|
||||
const mlx::core::Device& device,
|
||||
const std::string& name,
|
||||
const KernelBuilder& builder) {
|
||||
static std::unordered_map<std::string, JitModule> map;
|
||||
auto it = map.find(name);
|
||||
if (it == map.end()) {
|
||||
it = map.try_emplace(name, cu::device(device), name, builder).first;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
113
mlx/backend/cuda/jit_module.h
Normal file
113
mlx/backend/cuda/jit_module.h
Normal file
@ -0,0 +1,113 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cuda/device/config.h"
|
||||
|
||||
#include <deque>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
#include <cuda.h>
|
||||
#include <fmt/format.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
class Device;
|
||||
|
||||
using KernelBuilderResult = std::pair<
|
||||
/* source code */ std::string,
|
||||
/* kernel names */ std::vector<std::string>>;
|
||||
using KernelBuilder = std::function<KernelBuilderResult()>;
|
||||
|
||||
class JitModule {
|
||||
public:
|
||||
JitModule(
|
||||
Device& device,
|
||||
const std::string& module_name,
|
||||
const KernelBuilder& builder);
|
||||
~JitModule();
|
||||
|
||||
JitModule(const JitModule&) = delete;
|
||||
JitModule& operator=(const JitModule&) = delete;
|
||||
|
||||
void append_arg(const array& a) {
|
||||
append_arg(reinterpret_cast<CUdeviceptr>(a.data<void>()));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void append_arg(T val) {
|
||||
storage_.emplace_back(val);
|
||||
append_ptr_arg(&storage_.back());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void append_arg(std::vector<T> vec) {
|
||||
if (vec.empty()) {
|
||||
// The nullptr can not be used as arg, pass something not null.
|
||||
append_arg(std::monostate{});
|
||||
} else {
|
||||
append_ptr_arg(vec.data());
|
||||
storage_.emplace_back(std::move(vec));
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure the arg is copied to an array with size of NDIM.
|
||||
template <size_t NDIM = MAX_NDIM, typename T>
|
||||
void append_ndim_arg(const std::vector<T>& vec) {
|
||||
if (vec.size() > NDIM) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("ndim can not be larger than {}.", NDIM));
|
||||
}
|
||||
std::vector<T> copied(NDIM);
|
||||
std::copy(vec.begin(), vec.end(), copied.data());
|
||||
append_arg(std::move(copied));
|
||||
}
|
||||
|
||||
// Launch kernel with |kernel_name| that each thread works on
|
||||
// |work_per_thread| elements of |arr|.
|
||||
void launch_kernel(
|
||||
CUstream stream,
|
||||
const std::string& kernel_name,
|
||||
const array& arr,
|
||||
bool large,
|
||||
int work_per_thread = 1);
|
||||
|
||||
void launch_kernel(
|
||||
CUstream stream,
|
||||
CUfunction kernel,
|
||||
Dims num_blocks,
|
||||
Dims block_dims);
|
||||
|
||||
CUfunction get_kernel(const std::string& kernel_name);
|
||||
|
||||
private:
|
||||
void append_ptr_arg(const void* v);
|
||||
|
||||
CUmodule module_{nullptr};
|
||||
std::unordered_map<std::string, CUfunction> kernels_;
|
||||
std::vector<void*> args_;
|
||||
|
||||
// The cuLaunchKernel API requires passing pointers to arguments so store
|
||||
// temporary values untill kernel is launched.
|
||||
using Arg = std::variant<
|
||||
std::monostate,
|
||||
CUdeviceptr,
|
||||
int32_t,
|
||||
uint32_t,
|
||||
int64_t,
|
||||
std::vector<const void*>,
|
||||
std::vector<int32_t>,
|
||||
std::vector<int64_t>>;
|
||||
std::deque<Arg> storage_;
|
||||
};
|
||||
|
||||
JitModule& get_jit_module(
|
||||
const mlx::core::Device& device,
|
||||
const std::string& name,
|
||||
const KernelBuilder& builder);
|
||||
|
||||
} // namespace mlx::core::cu
|
@ -23,4 +23,11 @@ dim3 get_2d_grid_dims(
|
||||
return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
|
||||
}
|
||||
|
||||
std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2) {
|
||||
auto [grid, block] = get_grid_and_block_common(dim0, dim1, dim2);
|
||||
auto [gx, gy, gz] = grid;
|
||||
auto [bx, by, bz] = block;
|
||||
return std::make_pair(dim3(gx, gy, gz), dim3(bx, by, bz));
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -1,19 +1,77 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
// This file includes host-only utilies for writing CUDA kernels, the difference
|
||||
// from backend/cuda/kernels/utils.cuh is that the latter file only include
|
||||
// from backend/cuda/device/utils.cuh is that the latter file only include
|
||||
// device-only code.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <fmt/format.h>
|
||||
#include <cuda/cmath>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Convert a number between 1~3 to constexpr.
|
||||
#define MLX_SWITCH_1_2_3(N, NDIM, ...) \
|
||||
switch (N) { \
|
||||
case 1: { \
|
||||
constexpr int NDIM = 1; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case 2: { \
|
||||
constexpr int NDIM = 2; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case 3: { \
|
||||
constexpr int NDIM = 3; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
}
|
||||
|
||||
// Like MLX_SWITCH_ALL_TYPES but for booleans.
|
||||
#define MLX_SWITCH_BOOL(BOOL, BOOL_ALIAS, ...) \
|
||||
if (BOOL) { \
|
||||
constexpr bool BOOL_ALIAS = true; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr bool BOOL_ALIAS = false; \
|
||||
__VA_ARGS__; \
|
||||
}
|
||||
|
||||
// Convert a block_dim to constexpr between WARP_SIZE and WARP_SIZE ^ 2.
|
||||
#define MLX_SWITCH_BLOCK_DIM(NUM_THREADS, BLOCK_DIM, ...) \
|
||||
{ \
|
||||
uint32_t _num_threads = NUM_THREADS; \
|
||||
if (_num_threads <= WARP_SIZE) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE; \
|
||||
__VA_ARGS__; \
|
||||
} else if (_num_threads <= WARP_SIZE * 2) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 2; \
|
||||
__VA_ARGS__; \
|
||||
} else if (_num_threads <= WARP_SIZE * 4) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 4; \
|
||||
__VA_ARGS__; \
|
||||
} else if (_num_threads <= WARP_SIZE * 8) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 8; \
|
||||
__VA_ARGS__; \
|
||||
} else if (_num_threads <= WARP_SIZE * 16) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 16; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * WARP_SIZE; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
}
|
||||
|
||||
// Maps CPU types to CUDA types.
|
||||
template <typename T>
|
||||
struct CTypeToCudaType {
|
||||
@ -38,6 +96,29 @@ struct CTypeToCudaType<complex64_t> {
|
||||
template <typename T>
|
||||
using cuda_type_t = typename CTypeToCudaType<T>::type;
|
||||
|
||||
// Type traits for detecting floating numbers.
|
||||
template <typename T>
|
||||
inline constexpr bool is_floating_v =
|
||||
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
|
||||
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
|
||||
|
||||
// Type traits for detecting complex or real floating point numbers.
|
||||
template <typename T>
|
||||
inline constexpr bool is_inexact_v =
|
||||
is_floating_v<T> || cuda::std::is_same_v<T, complex64_t>;
|
||||
|
||||
// Utility to copy data from vector to array in host.
|
||||
template <int NDIM = MAX_NDIM, typename T = int32_t>
|
||||
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
|
||||
if (vec.size() > NDIM) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("ndim can not be larger than {}.", NDIM));
|
||||
}
|
||||
cuda::std::array<T, NDIM> result;
|
||||
std::copy_n(vec.begin(), vec.size(), result.begin());
|
||||
return result;
|
||||
}
|
||||
|
||||
// Compute the grid and block dimensions, check backend/common/utils.h for docs.
|
||||
dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
|
||||
dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides);
|
||||
@ -45,5 +126,49 @@ dim3 get_2d_grid_dims(
|
||||
const Shape& shape,
|
||||
const Strides& strides,
|
||||
size_t divisor);
|
||||
std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
|
||||
|
||||
// Return a block size that achieves maximum potential occupancy for kernel.
|
||||
template <typename T>
|
||||
inline uint max_occupancy_block_dim(T kernel) {
|
||||
int _, block_dim;
|
||||
CHECK_CUDA_ERROR(cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
|
||||
return block_dim;
|
||||
}
|
||||
|
||||
// Get the num_blocks and block_dims that maximize occupancy for |kernel|,
|
||||
// assuming each thread handles |work_per_thread| elements of |arr|.
|
||||
template <typename T>
|
||||
inline std::tuple<dim3, uint> get_launch_args(
|
||||
T kernel,
|
||||
size_t size,
|
||||
const Shape& shape,
|
||||
const Strides& strides,
|
||||
bool large,
|
||||
int work_per_thread = 1) {
|
||||
size_t nthreads = cuda::ceil_div(size, work_per_thread);
|
||||
uint block_dim = max_occupancy_block_dim(kernel);
|
||||
if (block_dim > nthreads) {
|
||||
block_dim = nthreads;
|
||||
}
|
||||
dim3 num_blocks;
|
||||
if (large) {
|
||||
num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
|
||||
num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
|
||||
} else {
|
||||
num_blocks.x = cuda::ceil_div(nthreads, block_dim);
|
||||
}
|
||||
return std::make_tuple(num_blocks, block_dim);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::tuple<dim3, uint> get_launch_args(
|
||||
T kernel,
|
||||
const array& arr,
|
||||
bool large,
|
||||
int work_per_thread = 1) {
|
||||
return get_launch_args(
|
||||
kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -1,76 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda/std/limits>
|
||||
#include <cuda/std/type_traits>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Additional C++ operator overrides between half types and native types.
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename U>
|
||||
constexpr bool is_integral_except =
|
||||
cuda::std::is_integral_v<T> && !cuda::std::is_same_v<T, U>;
|
||||
|
||||
template <typename T, typename U>
|
||||
constexpr bool is_arithmetic_except =
|
||||
cuda::std::is_arithmetic_v<T> && !cuda::std::is_same_v<T, U>;
|
||||
|
||||
#define MLX_DEFINE_HALF_OP(HALF, HALF2FLOAT, FLOAT2HALF, OP) \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_integral_except<T, HALF>>> \
|
||||
__forceinline__ __device__ HALF operator OP(HALF x, T y) { \
|
||||
return FLOAT2HALF(HALF2FLOAT(x) OP static_cast<float>(y)); \
|
||||
} \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_integral_except<T, HALF>>> \
|
||||
__forceinline__ __device__ HALF operator OP(T x, HALF y) { \
|
||||
return FLOAT2HALF(static_cast<float>(x) OP HALF2FLOAT(y)); \
|
||||
}
|
||||
|
||||
#define MLX_DEFINE_HALF_CMP(HALF, HALF2FLOAT, OP) \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_arithmetic_except<T, HALF>>> \
|
||||
__forceinline__ __device__ bool operator OP(HALF x, T y) { \
|
||||
return HALF2FLOAT(x) OP static_cast<float>(y); \
|
||||
} \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_arithmetic_except<T, HALF>>> \
|
||||
__forceinline__ __device__ bool operator OP(T x, HALF y) { \
|
||||
return static_cast<float>(y) OP HALF2FLOAT(x); \
|
||||
}
|
||||
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, +)
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, -)
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, *)
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, /)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, +)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, -)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, *)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, /)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, <)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, >)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, <=)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, >=)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, ==)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, !=)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <=)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >=)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, ==)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, !=)
|
||||
|
||||
#undef MLX_DEFINE_HALF_OP
|
||||
#undef MLX_DEFINE_HALF_CMP
|
||||
|
||||
} // namespace mlx::core::cu
|
389
mlx/backend/cuda/layer_norm.cu
Normal file
389
mlx/backend/cuda/layer_norm.cu
Normal file
@ -0,0 +1,389 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
inline __device__ float3 plus_f3(const float3& a, const float3& b) {
|
||||
return {a.x + b.x, a.y + b.y, a.z + b.z};
|
||||
}
|
||||
|
||||
// Similar to cub::BlockReduce, but result is broadcasted to every thread.
|
||||
template <typename T, int BLOCK_DIM>
|
||||
struct BlockBroadcastReduce {
|
||||
static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
|
||||
static_assert(BLOCK_DIM % WARP_SIZE == 0);
|
||||
using TempStorage = T[BLOCK_DIM / WARP_SIZE];
|
||||
|
||||
cg::thread_block& block;
|
||||
TempStorage& temp;
|
||||
|
||||
template <typename Op>
|
||||
__device__ T Reduce(const T& input, const Op& op, const T& init_value) {
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
T x = cg::reduce(warp, input, op);
|
||||
if (warp.thread_rank() == 0) {
|
||||
temp[warp.meta_group_rank()] = x;
|
||||
}
|
||||
block.sync();
|
||||
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
|
||||
: init_value;
|
||||
return cg::reduce(warp, x, op);
|
||||
}
|
||||
|
||||
__device__ T Sum(const T& input) {
|
||||
return Reduce(input, cg::plus<T>{}, T{});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void layer_norm(
|
||||
const T* x,
|
||||
const T* w,
|
||||
const T* b,
|
||||
T* out,
|
||||
float eps,
|
||||
int32_t axis_size,
|
||||
int64_t w_stride,
|
||||
int64_t b_stride) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM>;
|
||||
__shared__ typename BlockReduceT::TempStorage temp;
|
||||
|
||||
x += grid.block_rank() * axis_size;
|
||||
out += grid.block_rank() * axis_size;
|
||||
|
||||
// Sum.
|
||||
float sum = 0;
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS] = {};
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||
sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
|
||||
}
|
||||
sum = BlockReduceT{block, temp}.Sum(sum);
|
||||
|
||||
// Mean.
|
||||
float mean = sum / axis_size;
|
||||
|
||||
// Normalizer.
|
||||
float normalizer = 0;
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float t = static_cast<float>(xn[i]) - mean;
|
||||
normalizer += t * t;
|
||||
}
|
||||
}
|
||||
normalizer = BlockReduceT{block, temp}.Sum(normalizer);
|
||||
normalizer = rsqrt(normalizer / axis_size + eps);
|
||||
|
||||
// Outputs.
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
T wn[N_READS];
|
||||
T bn[N_READS];
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
cub::LoadDirectBlocked(index, strided_iterator(b, b_stride), bn, axis_size);
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
|
||||
xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
|
||||
}
|
||||
cub::StoreDirectBlocked(index, out, xn, axis_size);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void layer_norm_vjp(
|
||||
const T* x,
|
||||
const T* w,
|
||||
const T* g,
|
||||
T* gx,
|
||||
T* gw,
|
||||
float eps,
|
||||
int32_t axis_size,
|
||||
int64_t w_stride) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
|
||||
using BlockReduceF3 = BlockBroadcastReduce<float3, BLOCK_DIM>;
|
||||
__shared__ union {
|
||||
typename BlockReduceF::TempStorage f;
|
||||
typename BlockReduceF3::TempStorage f3;
|
||||
} temp;
|
||||
|
||||
x += grid.block_rank() * axis_size;
|
||||
g += grid.block_rank() * axis_size;
|
||||
gx += grid.block_rank() * axis_size;
|
||||
gw += grid.block_rank() * axis_size;
|
||||
|
||||
// Sum.
|
||||
float sum = 0;
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS] = {};
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||
sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
|
||||
}
|
||||
sum = BlockReduceF{block, temp.f}.Sum(sum);
|
||||
|
||||
// Mean.
|
||||
float mean = sum / axis_size;
|
||||
|
||||
// Normalizer.
|
||||
float3 factors = {};
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
T xn[N_READS];
|
||||
T wn[N_READS] = {};
|
||||
T gn[N_READS] = {};
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
|
||||
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float t = static_cast<float>(xn[i]) - mean;
|
||||
float wi = wn[i];
|
||||
float gi = gn[i];
|
||||
float wg = wi * gi;
|
||||
factors = plus_f3(factors, {wg, wg * t, t * t});
|
||||
}
|
||||
}
|
||||
factors = BlockReduceF3{block, temp.f3}.Reduce(factors, plus_f3, {});
|
||||
float meanwg = factors.x / axis_size;
|
||||
float meanwgxc = factors.y / axis_size;
|
||||
float normalizer2 = 1 / (factors.z / axis_size + eps);
|
||||
float normalizer = sqrt(normalizer2);
|
||||
|
||||
// Outputs.
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
T wn[N_READS];
|
||||
T gn[N_READS];
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
|
||||
float wi = wn[i];
|
||||
float gi = gn[i];
|
||||
xn[i] = normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2;
|
||||
if constexpr (HAS_W) {
|
||||
wn[i] = gi * xi;
|
||||
}
|
||||
}
|
||||
cub::StoreDirectBlocked(index, gx, xn, axis_size);
|
||||
if constexpr (HAS_W) {
|
||||
cub::StoreDirectBlocked(index, gw, wn, axis_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
namespace fast {
|
||||
|
||||
bool LayerNorm::use_fallback(Stream s) {
|
||||
return s.device == Device::cpu;
|
||||
}
|
||||
|
||||
// TODO: There are duplicate code with backend/metal/normalization.cpp
|
||||
void LayerNorm::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("LayerNorm::eval_gpu");
|
||||
auto& s = stream();
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Make sure that the last dimension is contiguous.
|
||||
auto set_output = [&s, &out](const array& x) {
|
||||
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
|
||||
if (no_copy && x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
}
|
||||
if (no_copy) {
|
||||
if (x.is_donatable()) {
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
}
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
|
||||
const array x = set_output(inputs[0]);
|
||||
const array& w = inputs[1];
|
||||
const array& b = inputs[2];
|
||||
|
||||
int32_t axis_size = x.shape().back();
|
||||
int32_t n_rows = x.data_size() / axis_size;
|
||||
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||
int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "layernorm", CTYPE, {
|
||||
using DataType = cuda_type_t<CTYPE>;
|
||||
constexpr uint32_t N_READS = 4;
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = cu::layer_norm<DataType, BLOCK_DIM, N_READS>;
|
||||
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
b.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride,
|
||||
b_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void LayerNormVJP::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("LayerNormVJP::eval_gpu");
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
// Ensure row contiguity. We could relax this step by checking that the array
|
||||
// is contiguous (no broadcasts or holes) and that the input strides are the
|
||||
// same as the cotangent strides but for now this is simpler.
|
||||
auto check_input = [&s](const array& x) -> std::pair<array, bool> {
|
||||
if (x.flags().row_contiguous) {
|
||||
return {x, false};
|
||||
}
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
return {x_copy, true};
|
||||
};
|
||||
bool donate_x = inputs[0].is_donatable();
|
||||
bool donate_g = inputs[3].is_donatable();
|
||||
auto [x, copied] = check_input(inputs[0]);
|
||||
donate_x |= copied;
|
||||
const array& w = inputs[1];
|
||||
const array& b = inputs[2];
|
||||
auto [g, g_copied] = check_input(inputs[3]);
|
||||
donate_g |= g_copied;
|
||||
array& gx = outputs[0];
|
||||
array& gw = outputs[1];
|
||||
array& gb = outputs[2];
|
||||
|
||||
// Check whether we had a weight.
|
||||
bool has_w = w.ndim() != 0;
|
||||
|
||||
// Allocate space for the outputs.
|
||||
bool g_in_gx = false;
|
||||
if (donate_x) {
|
||||
gx.copy_shared_buffer(x);
|
||||
} else if (donate_g) {
|
||||
gx.copy_shared_buffer(g);
|
||||
g_in_gx = true;
|
||||
} else {
|
||||
gx.set_data(allocator::malloc(gx.nbytes()));
|
||||
}
|
||||
if (g_copied && !g_in_gx) {
|
||||
encoder.add_temporary(g);
|
||||
}
|
||||
|
||||
int32_t axis_size = x.shape().back();
|
||||
int32_t n_rows = x.data_size() / axis_size;
|
||||
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||
|
||||
// Allocate a temporary to store the gradients for w and allocate the output
|
||||
// gradient accumulators.
|
||||
array gw_temp =
|
||||
(has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
|
||||
if (has_w) {
|
||||
if (!g_in_gx && donate_g) {
|
||||
gw_temp.copy_shared_buffer(g);
|
||||
} else {
|
||||
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
|
||||
encoder.add_temporary(gw_temp);
|
||||
}
|
||||
}
|
||||
gw.set_data(allocator::malloc(gw.nbytes()));
|
||||
gb.set_data(allocator::malloc(gb.nbytes()));
|
||||
|
||||
// Finish with the gradient for b in case we had a b.
|
||||
if (gb.ndim() == 1 && gb.size() == axis_size) {
|
||||
ReductionPlan plan(
|
||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||
col_reduce(encoder, g, gb, Reduce::ReduceType::Sum, {0}, plan);
|
||||
}
|
||||
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(g);
|
||||
encoder.set_output_array(gx);
|
||||
encoder.set_output_array(gw_temp);
|
||||
encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "layernorm_vjp", CTYPE, {
|
||||
using DataType = cuda_type_t<CTYPE>;
|
||||
constexpr int N_READS = 4;
|
||||
MLX_SWITCH_BOOL(has_w, HAS_W, {
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = cu::layer_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>;
|
||||
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
g.data<DataType>(),
|
||||
gx.data<DataType>(),
|
||||
gw_temp.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
if (has_w) {
|
||||
ReductionPlan plan(
|
||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||
col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
159
mlx/backend/cuda/logsumexp.cu
Normal file
159
mlx/backend/cuda/logsumexp.cu
Normal file
@ -0,0 +1,159 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <cub/block/block_load.cuh>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T softmax_exp(T x) {
|
||||
// Softmax doesn't need high precision exponential cause x is gonna be in
|
||||
// (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
|
||||
return __expf(x);
|
||||
}
|
||||
|
||||
template <typename T, typename AccT, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void logsumexp(const T* in, T* out, int axis_size) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
in += grid.block_rank() * axis_size;
|
||||
|
||||
cg::greater<AccT> max_op;
|
||||
cg::plus<AccT> plus_op;
|
||||
|
||||
// Thread reduce.
|
||||
AccT prevmax;
|
||||
AccT maxval = Limits<AccT>::finite_min();
|
||||
AccT normalizer = 0;
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
|
||||
AccT vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
r * BLOCK_DIM + block.thread_rank(),
|
||||
make_cast_iterator<AccT>(in),
|
||||
vals,
|
||||
axis_size,
|
||||
Limits<AccT>::min());
|
||||
prevmax = maxval;
|
||||
maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
|
||||
// Online normalizer calculation for softmax:
|
||||
// https://github.com/NVIDIA/online-softmax
|
||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
normalizer = normalizer + softmax_exp(vals[i] - maxval);
|
||||
}
|
||||
}
|
||||
|
||||
// First warp reduce.
|
||||
prevmax = maxval;
|
||||
maxval = cg::reduce(warp, maxval, max_op);
|
||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||
normalizer = cg::reduce(warp, normalizer, plus_op);
|
||||
|
||||
__shared__ AccT local_max[WARP_SIZE];
|
||||
__shared__ AccT local_normalizer[WARP_SIZE];
|
||||
|
||||
// Write to shared memory and do second warp reduce.
|
||||
prevmax = maxval;
|
||||
if (warp.thread_rank() == 0) {
|
||||
local_max[warp.meta_group_rank()] = maxval;
|
||||
}
|
||||
block.sync();
|
||||
maxval = warp.thread_rank() < warp.meta_group_size()
|
||||
? local_max[warp.thread_rank()]
|
||||
: Limits<AccT>::finite_min();
|
||||
maxval = cg::reduce(warp, maxval, max_op);
|
||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||
if (warp.thread_rank() == 0) {
|
||||
local_normalizer[warp.meta_group_rank()] = normalizer;
|
||||
}
|
||||
block.sync();
|
||||
normalizer = warp.thread_rank() < warp.meta_group_size()
|
||||
? local_normalizer[warp.thread_rank()]
|
||||
: AccT{};
|
||||
normalizer = cg::reduce(warp, normalizer, plus_op);
|
||||
|
||||
// Write output.
|
||||
if (block.thread_rank() == 0) {
|
||||
out[grid.block_rank()] = isinf(maxval) ? maxval : log(normalizer) + maxval;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("LogSumExp::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
// Make sure that the last dimension is contiguous.
|
||||
auto ensure_contiguous = [&s, &encoder](const array& x) {
|
||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
encoder.add_temporary(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
|
||||
auto in = ensure_contiguous(inputs[0]);
|
||||
if (in.flags().row_contiguous) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
} else {
|
||||
auto n = in.shape(-1);
|
||||
auto flags = in.flags();
|
||||
auto strides = in.strides();
|
||||
for (auto& s : strides) {
|
||||
s /= n;
|
||||
}
|
||||
bool col_contig = strides[0] == 1;
|
||||
for (int i = 1; col_contig && i < strides.size(); ++i) {
|
||||
col_contig &=
|
||||
(out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
|
||||
}
|
||||
flags.col_contiguous = col_contig;
|
||||
out.set_data(
|
||||
allocator::malloc(in.nbytes() / n),
|
||||
in.data_size() / n,
|
||||
std::move(strides),
|
||||
flags);
|
||||
}
|
||||
|
||||
int axis_size = in.shape().back();
|
||||
int n_rows = in.data_size() / axis_size;
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "logsumexp", CTYPE, {
|
||||
using DataType = cuda_type_t<CTYPE>;
|
||||
constexpr int N_READS = 4;
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = cu::logsumexp<DataType, float, BLOCK_DIM, N_READS>;
|
||||
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||
in.data<DataType>(), out.data<DataType>(), axis_size);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
496
mlx/backend/cuda/matmul.cpp
Normal file
496
mlx/backend/cuda/matmul.cpp
Normal file
@ -0,0 +1,496 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/matmul.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <cublasLt.h>
|
||||
#include <fmt/format.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
#include <numeric>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
|
||||
|
||||
void check_cublas_error(const char* name, cublasStatus_t err) {
|
||||
if (err != CUBLAS_STATUS_SUCCESS) {
|
||||
// TODO: Use cublasGetStatusString when it is widely available.
|
||||
throw std::runtime_error(
|
||||
fmt::format("{} failed with code: {}.", name, static_cast<int>(err)));
|
||||
}
|
||||
}
|
||||
|
||||
class MatMul {
|
||||
public:
|
||||
MatMul(
|
||||
Device& device,
|
||||
Dtype dtype,
|
||||
bool a_transposed,
|
||||
uint64_t a_rows,
|
||||
uint64_t a_cols,
|
||||
int64_t lda,
|
||||
bool b_transposed,
|
||||
uint64_t b_rows,
|
||||
uint64_t b_cols,
|
||||
int64_t ldb,
|
||||
int32_t batch_count,
|
||||
int64_t a_batch_stride,
|
||||
int64_t b_batch_stride) {
|
||||
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
|
||||
|
||||
auto scale_type = dtype_to_cuda_type(dtype);
|
||||
if (dtype == bfloat16 || dtype == float16) {
|
||||
scale_type = CUDA_R_32F;
|
||||
}
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
|
||||
&matmul_desc_, dtype_to_compute_type(dtype), scale_type));
|
||||
int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc_,
|
||||
CUBLASLT_MATMUL_DESC_POINTER_MODE,
|
||||
&pointer_mode,
|
||||
sizeof(int32_t)));
|
||||
cublasOperation_t op = CUBLAS_OP_N;
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc_,
|
||||
CUBLASLT_MATMUL_DESC_TRANSA,
|
||||
&op,
|
||||
sizeof(cublasOperation_t)));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc_,
|
||||
CUBLASLT_MATMUL_DESC_TRANSB,
|
||||
&op,
|
||||
sizeof(cublasOperation_t)));
|
||||
|
||||
auto type = dtype_to_cuda_type(dtype);
|
||||
a_desc_ = create_matrix_layout(
|
||||
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
|
||||
b_desc_ = create_matrix_layout(
|
||||
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
|
||||
out_desc_ = create_matrix_layout(
|
||||
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
|
||||
|
||||
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
|
||||
// for Hopper+:
|
||||
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
|
||||
uint64_t MiB = 1024 * 1024;
|
||||
uint64_t workspace_size =
|
||||
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
|
||||
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
|
||||
pref_,
|
||||
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||
&workspace_size,
|
||||
sizeof(uint64_t)));
|
||||
}
|
||||
|
||||
MatMul(
|
||||
Device& device,
|
||||
Dtype dtype,
|
||||
bool a_transposed,
|
||||
uint64_t a_rows,
|
||||
uint64_t a_cols,
|
||||
int64_t lda,
|
||||
bool b_transposed,
|
||||
uint64_t b_rows,
|
||||
uint64_t b_cols,
|
||||
int64_t ldb,
|
||||
bool c_transposed,
|
||||
int64_t ldc,
|
||||
int32_t batch_count,
|
||||
int64_t a_batch_stride,
|
||||
int64_t b_batch_stride,
|
||||
int64_t c_batch_stride)
|
||||
: MatMul(
|
||||
device,
|
||||
dtype,
|
||||
a_transposed,
|
||||
a_rows,
|
||||
a_cols,
|
||||
lda,
|
||||
b_transposed,
|
||||
b_rows,
|
||||
b_cols,
|
||||
ldb,
|
||||
batch_count,
|
||||
a_batch_stride,
|
||||
b_batch_stride) {
|
||||
auto type = dtype_to_cuda_type(dtype);
|
||||
c_desc_ = create_matrix_layout(
|
||||
type, a_rows, b_cols, c_transposed, ldc, batch_count, c_batch_stride);
|
||||
}
|
||||
|
||||
~MatMul() {
|
||||
cublasLtMatrixLayoutDestroy(a_desc_);
|
||||
cublasLtMatrixLayoutDestroy(b_desc_);
|
||||
cublasLtMatrixLayoutDestroy(c_desc_);
|
||||
cublasLtMatrixLayoutDestroy(out_desc_);
|
||||
cublasLtMatmulDescDestroy(matmul_desc_);
|
||||
}
|
||||
|
||||
void run(
|
||||
cu::CommandEncoder& encoder,
|
||||
void* out,
|
||||
void* a,
|
||||
void* b,
|
||||
void* c = nullptr,
|
||||
float alpha = 1,
|
||||
float beta = 0) {
|
||||
if (heuristic_.state != CUBLAS_STATUS_SUCCESS) {
|
||||
int ret = 0;
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(
|
||||
encoder.device().lt_handle(),
|
||||
matmul_desc_,
|
||||
a_desc_,
|
||||
b_desc_,
|
||||
out_desc_,
|
||||
out_desc_,
|
||||
pref_,
|
||||
1,
|
||||
&heuristic_,
|
||||
&ret));
|
||||
if (ret == 0) {
|
||||
throw std::runtime_error("Can not find algorithm for matmul.");
|
||||
}
|
||||
}
|
||||
|
||||
void* workspace_ptr = nullptr;
|
||||
if (heuristic_.workspaceSize > 0) {
|
||||
array workspace(
|
||||
allocator::malloc(heuristic_.workspaceSize),
|
||||
{static_cast<int>(heuristic_.workspaceSize)},
|
||||
int8);
|
||||
encoder.add_temporary(workspace);
|
||||
workspace_ptr = workspace.data<void>();
|
||||
}
|
||||
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmul(
|
||||
encoder.device().lt_handle(),
|
||||
matmul_desc_,
|
||||
&alpha,
|
||||
a,
|
||||
a_desc_,
|
||||
b,
|
||||
b_desc_,
|
||||
&beta,
|
||||
c ? c : out,
|
||||
c ? c_desc_ : out_desc_,
|
||||
out,
|
||||
out_desc_,
|
||||
&heuristic_.algo,
|
||||
workspace_ptr,
|
||||
heuristic_.workspaceSize,
|
||||
stream));
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
|
||||
switch (dtype) {
|
||||
case float16:
|
||||
return CUBLAS_COMPUTE_32F;
|
||||
case bfloat16:
|
||||
return CUBLAS_COMPUTE_32F;
|
||||
case float32:
|
||||
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
|
||||
: CUBLAS_COMPUTE_32F;
|
||||
case float64:
|
||||
case complex64:
|
||||
return CUBLAS_COMPUTE_64F;
|
||||
default:
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Unsupported dtype in MatMul: {}.", dtype_to_string(dtype)));
|
||||
}
|
||||
}
|
||||
|
||||
cudaDataType_t dtype_to_cuda_type(Dtype dtype) {
|
||||
switch (dtype) {
|
||||
case float16:
|
||||
return CUDA_R_16F;
|
||||
case bfloat16:
|
||||
return CUDA_R_16BF;
|
||||
case float32:
|
||||
return CUDA_R_32F;
|
||||
case float64:
|
||||
return CUDA_R_64F;
|
||||
case complex64:
|
||||
return CUDA_C_32F;
|
||||
default:
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Unsupported dtype in MatMul: {}.", dtype_to_string(dtype)));
|
||||
}
|
||||
}
|
||||
|
||||
cublasLtMatrixLayout_t create_matrix_layout(
|
||||
cudaDataType_t type,
|
||||
uint64_t rows,
|
||||
uint64_t cols,
|
||||
bool transposed,
|
||||
int64_t ld,
|
||||
int32_t batch_count,
|
||||
int64_t batch_stride) {
|
||||
cublasLtMatrixLayout_t desc;
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
|
||||
cublasLtOrder_t order =
|
||||
transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
||||
desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
|
||||
if (batch_count > 1) {
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
||||
desc,
|
||||
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&batch_count,
|
||||
sizeof(int32_t)));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
||||
desc,
|
||||
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
|
||||
&batch_stride,
|
||||
sizeof(int64_t)));
|
||||
}
|
||||
return desc;
|
||||
}
|
||||
|
||||
cublasLtMatmulDesc_t matmul_desc_{nullptr};
|
||||
cublasLtMatmulPreference_t pref_{nullptr};
|
||||
cublasLtMatrixLayout_t a_desc_{nullptr};
|
||||
cublasLtMatrixLayout_t b_desc_{nullptr};
|
||||
cublasLtMatrixLayout_t c_desc_{nullptr};
|
||||
cublasLtMatrixLayout_t out_desc_{nullptr};
|
||||
cublasLtMatmulHeuristicResult_t heuristic_;
|
||||
};
|
||||
|
||||
} // namespace cu
|
||||
|
||||
namespace {
|
||||
|
||||
std::tuple<bool, int64_t, array>
|
||||
check_transpose(std::vector<array>& copies, const Stream& s, const array& arr) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
auto sty = arr.strides()[arr.ndim() - 1];
|
||||
if (sty == 1 && stx == arr.shape(-1)) {
|
||||
return std::make_tuple(false, stx, arr);
|
||||
} else if (stx == 1 && sty == arr.shape(-2)) {
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_gpu(arr, arr_copy, CopyType::General, s);
|
||||
copies.push_back(arr_copy);
|
||||
return std::make_tuple(false, arr.shape(-1), arr_copy);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Matmul::eval_gpu");
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
assert(inputs.size() == 2);
|
||||
auto& a_pre = inputs[0];
|
||||
auto& b_pre = inputs[1];
|
||||
// Return 0s if either input is empty.
|
||||
if (a_pre.size() == 0 || b_pre.size() == 0) {
|
||||
array zero(0, a_pre.dtype());
|
||||
encoder.add_temporary(zero);
|
||||
fill_gpu(zero, out, s);
|
||||
return;
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Init checks and prep
|
||||
|
||||
int M = a_pre.shape(-2);
|
||||
int N = b_pre.shape(-1);
|
||||
int K = a_pre.shape(-1);
|
||||
|
||||
// Keep a vector with copies to be cleared in the completed buffer to release
|
||||
// the arrays
|
||||
std::vector<array> copies;
|
||||
auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre);
|
||||
|
||||
for (auto& temp : copies) {
|
||||
encoder.add_temporary(temp);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Check and collapse batch dimensions
|
||||
|
||||
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
|
||||
|
||||
auto batch_count = out.size() / (M * N);
|
||||
|
||||
// Collapse batches into M if needed
|
||||
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
|
||||
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
|
||||
b_batch_strides.back() == 0) {
|
||||
M *= batch_shape.back();
|
||||
batch_count = 1;
|
||||
|
||||
a_batch_strides = {0};
|
||||
b_batch_strides = {0};
|
||||
batch_shape = {1};
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Invoke cublasLt
|
||||
|
||||
cu::MatMul matmul(
|
||||
encoder.device(),
|
||||
a.dtype(),
|
||||
a_transposed,
|
||||
M,
|
||||
K,
|
||||
lda,
|
||||
b_transposed,
|
||||
K,
|
||||
N,
|
||||
ldb,
|
||||
batch_shape.back(),
|
||||
a_batch_strides.back(),
|
||||
b_batch_strides.back());
|
||||
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
auto nbatch = batch_count / batch_shape.back();
|
||||
if (nbatch == 1) {
|
||||
matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>());
|
||||
return;
|
||||
}
|
||||
|
||||
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
||||
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
||||
for (size_t i = 0; i < nbatch; ++i) {
|
||||
matmul.run(
|
||||
encoder,
|
||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
||||
a.data<int8_t>() + a.itemsize() * a_it.loc,
|
||||
b.data<int8_t>() + b.itemsize() * b_it.loc);
|
||||
a_it.step();
|
||||
b_it.step();
|
||||
}
|
||||
}
|
||||
|
||||
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("AddMM::eval_gpu");
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
assert(inputs.size() == 3);
|
||||
auto& a_pre = inputs[0];
|
||||
auto& b_pre = inputs[1];
|
||||
auto& c_pre = inputs[2];
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Init checks and prep
|
||||
|
||||
int M = a_pre.shape(-2);
|
||||
int N = b_pre.shape(-1);
|
||||
int K = a_pre.shape(-1);
|
||||
|
||||
// Keep a vector with copies to be cleared in the completed buffer to release
|
||||
// the arrays
|
||||
std::vector<array> copies;
|
||||
auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre);
|
||||
auto [c_transposed, ldc, c] = check_transpose(copies, s, c_pre);
|
||||
|
||||
for (auto& temp : copies) {
|
||||
encoder.add_temporary(temp);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Check and collapse batch dimensions
|
||||
|
||||
auto [batch_shape, a_batch_strides, b_batch_strides, c_batch_strides] =
|
||||
collapse_batches(a, b, c);
|
||||
|
||||
auto batch_count = out.size() / (M * N);
|
||||
|
||||
// Collapse batches into M if needed
|
||||
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
|
||||
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
|
||||
c_batch_strides.back() == M * c.strides()[c.ndim() - 2] &&
|
||||
b_batch_strides.back() == 0) {
|
||||
M *= batch_shape.back();
|
||||
batch_count = 1;
|
||||
|
||||
a_batch_strides = {0};
|
||||
b_batch_strides = {0};
|
||||
c_batch_strides = {0};
|
||||
batch_shape = {1};
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Invoke cublasLt
|
||||
|
||||
cu::MatMul matmul(
|
||||
encoder.device(),
|
||||
a.dtype(),
|
||||
a_transposed,
|
||||
M,
|
||||
K,
|
||||
lda,
|
||||
b_transposed,
|
||||
K,
|
||||
N,
|
||||
ldb,
|
||||
c_transposed,
|
||||
ldc,
|
||||
batch_shape.back(),
|
||||
a_batch_strides.back(),
|
||||
b_batch_strides.back(),
|
||||
c_batch_strides.back());
|
||||
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_input_array(c);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto nbatch = batch_count / batch_shape.back();
|
||||
if (nbatch == 1) {
|
||||
matmul.run(
|
||||
encoder,
|
||||
out.data<int8_t>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
c.data<int8_t>(),
|
||||
alpha_,
|
||||
beta_);
|
||||
return;
|
||||
}
|
||||
|
||||
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
||||
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
||||
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
|
||||
for (size_t i = 0; i < nbatch; ++i) {
|
||||
matmul.run(
|
||||
encoder,
|
||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
||||
a.data<int8_t>() + a.itemsize() * a_it.loc,
|
||||
b.data<int8_t>() + b.itemsize() * b_it.loc,
|
||||
c.data<int8_t>() + c.itemsize() * c_it.loc,
|
||||
alpha_,
|
||||
beta_);
|
||||
a_it.step();
|
||||
b_it.step();
|
||||
c_it.step();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
11
mlx/backend/cuda/no_cuda.cpp
Normal file
11
mlx/backend/cuda/no_cuda.cpp
Normal file
@ -0,0 +1,11 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/cuda.h"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
bool is_available() {
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
@ -1,9 +1,9 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/arange.cuh"
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/kernels/arange.cuh"
|
||||
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
@ -71,101 +71,26 @@ bool fast::ScaledDotProductAttention::use_fallback(
|
||||
throw std::runtime_error(#func " has no CUDA implementation."); \
|
||||
}
|
||||
|
||||
NO_GPU(Abs)
|
||||
NO_GPU(Add)
|
||||
NO_GPU(AddMM)
|
||||
NO_GPU(ArcCos)
|
||||
NO_GPU(ArcCosh)
|
||||
NO_GPU(ArcSin)
|
||||
NO_GPU(ArcSinh)
|
||||
NO_GPU(ArcTan)
|
||||
NO_GPU(ArcTan2)
|
||||
NO_GPU(ArcTanh)
|
||||
NO_GPU(ArgPartition)
|
||||
NO_GPU(ArgReduce)
|
||||
NO_GPU(ArgSort)
|
||||
NO_GPU(BitwiseBinary)
|
||||
NO_GPU(BitwiseInvert)
|
||||
NO_GPU(BlockMaskedMM)
|
||||
NO_GPU(Ceil)
|
||||
NO_GPU_MULTI(Compiled)
|
||||
NO_GPU(Conjugate)
|
||||
NO_GPU(Convolution)
|
||||
NO_GPU(Cos)
|
||||
NO_GPU(Cosh)
|
||||
NO_GPU(Divide)
|
||||
NO_GPU_MULTI(DivMod)
|
||||
NO_GPU(DynamicSlice)
|
||||
NO_GPU(DynamicSliceUpdate)
|
||||
NO_GPU(Remainder)
|
||||
NO_GPU(Equal)
|
||||
NO_GPU(Erf)
|
||||
NO_GPU(ErfInv)
|
||||
NO_GPU(Exp)
|
||||
NO_GPU(Expm1)
|
||||
NO_GPU(FFT)
|
||||
NO_GPU(Floor)
|
||||
NO_GPU(Gather)
|
||||
NO_GPU(GatherAxis)
|
||||
NO_GPU(GatherMM)
|
||||
NO_GPU(GatherQMM)
|
||||
NO_GPU(Greater)
|
||||
NO_GPU(GreaterEqual)
|
||||
NO_GPU(Hadamard)
|
||||
NO_GPU(Imag)
|
||||
NO_GPU(Less)
|
||||
NO_GPU(LessEqual)
|
||||
NO_GPU(Load)
|
||||
NO_GPU(Log)
|
||||
NO_GPU(Log1p)
|
||||
NO_GPU(LogicalNot)
|
||||
NO_GPU(LogicalAnd)
|
||||
NO_GPU(LogicalOr)
|
||||
NO_GPU(LogAddExp)
|
||||
NO_GPU(LogSumExp)
|
||||
NO_GPU_MULTI(LUF)
|
||||
NO_GPU(Matmul)
|
||||
NO_GPU(Maximum)
|
||||
NO_GPU(Minimum)
|
||||
NO_GPU(Multiply)
|
||||
NO_GPU(Negative)
|
||||
NO_GPU(NotEqual)
|
||||
NO_GPU(Partition)
|
||||
NO_GPU(Power)
|
||||
NO_GPU_MULTI(QRF)
|
||||
NO_GPU(QuantizedMatmul)
|
||||
NO_GPU(RandomBits)
|
||||
NO_GPU(Real)
|
||||
NO_GPU(Reduce)
|
||||
NO_GPU(Round)
|
||||
NO_GPU(Scan)
|
||||
NO_GPU(Scatter)
|
||||
NO_GPU(ScatterAxis)
|
||||
NO_GPU(Select)
|
||||
NO_GPU(Sigmoid)
|
||||
NO_GPU(Sign)
|
||||
NO_GPU(Sin)
|
||||
NO_GPU(Sinh)
|
||||
NO_GPU(SliceUpdate)
|
||||
NO_GPU(Softmax)
|
||||
NO_GPU(Sort)
|
||||
NO_GPU(Square)
|
||||
NO_GPU(Sqrt)
|
||||
NO_GPU(Subtract)
|
||||
NO_GPU_MULTI(SVD)
|
||||
NO_GPU(Tan)
|
||||
NO_GPU(Tanh)
|
||||
NO_GPU(Inverse)
|
||||
NO_GPU(Cholesky)
|
||||
NO_GPU_MULTI(Eig)
|
||||
NO_GPU_MULTI(Eigh)
|
||||
|
||||
namespace fast {
|
||||
NO_GPU_USE_FALLBACK(LayerNorm)
|
||||
NO_GPU_MULTI(LayerNormVJP)
|
||||
NO_GPU_USE_FALLBACK(RMSNorm)
|
||||
NO_GPU_MULTI(RMSNormVJP)
|
||||
NO_GPU_USE_FALLBACK(RoPE)
|
||||
NO_GPU(ScaledDotProductAttention)
|
||||
NO_GPU_MULTI(AffineQuantize)
|
||||
NO_GPU_MULTI(CustomKernel)
|
||||
|
189
mlx/backend/cuda/random.cu
Normal file
189
mlx/backend/cuda/random.cu
Normal file
@ -0,0 +1,189 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
__constant__ constexpr uint32_t rotations[2][4] = {
|
||||
{13, 15, 26, 6},
|
||||
{17, 29, 16, 24}};
|
||||
|
||||
union rbits {
|
||||
uint2 val;
|
||||
uint8_t bytes[2][4];
|
||||
};
|
||||
|
||||
__device__ rbits threefry2x32_hash(uint2 key, uint2 count) {
|
||||
uint32_t ks[] = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA};
|
||||
|
||||
rbits v;
|
||||
v.val.x = count.x + ks[0];
|
||||
v.val.y = count.y + ks[1];
|
||||
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
for (auto r : rotations[i % 2]) {
|
||||
v.val.x += v.val.y;
|
||||
v.val.y = (v.val.y << r) | (v.val.y >> (32 - r));
|
||||
v.val.y ^= v.val.x;
|
||||
}
|
||||
v.val.x += ks[(i + 1) % 3];
|
||||
v.val.y += ks[(i + 2) % 3] + i + 1;
|
||||
}
|
||||
|
||||
return v;
|
||||
}
|
||||
|
||||
__global__ void rbitsc(
|
||||
const uint32_t* keys,
|
||||
uint8_t* out,
|
||||
dim3 grid_dims,
|
||||
bool odd,
|
||||
uint32_t bytes_per_key) {
|
||||
auto grid = cg::this_grid();
|
||||
uint thread_index = grid.thread_rank();
|
||||
uint index_x = thread_index % grid_dims.x;
|
||||
uint index_y = thread_index / grid_dims.x;
|
||||
if (index_x >= grid_dims.x || index_y >= grid_dims.y) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto kidx = 2 * index_x;
|
||||
auto key = uint2{keys[kidx], keys[kidx + 1]};
|
||||
auto half_size = grid_dims.y - odd;
|
||||
out += index_x * bytes_per_key;
|
||||
bool drop_last = odd && (index_y == half_size);
|
||||
auto bits = threefry2x32_hash(
|
||||
key, uint2{index_y, drop_last ? 0 : index_y + grid_dims.y});
|
||||
size_t idx = size_t(index_y) << 2;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
out[idx + i] = bits.bytes[0][i];
|
||||
}
|
||||
if (!drop_last) {
|
||||
idx = (drop_last ? 0 : size_t(index_y) + grid_dims.y) << 2;
|
||||
if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) {
|
||||
int edge_bytes = (bytes_per_key % 4);
|
||||
for (int i = 0; i < edge_bytes; ++i) {
|
||||
out[idx + i] = bits.bytes[1][i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
out[idx + i] = bits.bytes[1][i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void rbits(
|
||||
const uint32_t* keys,
|
||||
uint8_t* out,
|
||||
dim3 grid_dims,
|
||||
bool odd,
|
||||
uint32_t bytes_per_key,
|
||||
int32_t ndim,
|
||||
const __grid_constant__ Shape key_shape,
|
||||
const __grid_constant__ Strides key_strides) {
|
||||
auto grid = cg::this_grid();
|
||||
uint thread_index = grid.thread_rank();
|
||||
uint index_x = thread_index % grid_dims.x;
|
||||
uint index_y = thread_index / grid_dims.x;
|
||||
if (index_x >= grid_dims.x || index_y >= grid_dims.y) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto kidx = 2 * index_x;
|
||||
auto k1_elem = elem_to_loc(kidx, key_shape.data(), key_strides.data(), ndim);
|
||||
auto k2_elem =
|
||||
elem_to_loc(kidx + 1, key_shape.data(), key_strides.data(), ndim);
|
||||
auto key = uint2{keys[k1_elem], keys[k2_elem]};
|
||||
auto half_size = grid_dims.y - odd;
|
||||
out += size_t(index_x) * bytes_per_key;
|
||||
bool drop_last = odd && (index_y == half_size);
|
||||
auto bits = threefry2x32_hash(
|
||||
key, uint2{index_y, drop_last ? 0 : index_y + grid_dims.y});
|
||||
size_t idx = size_t(index_y) << 2;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
out[idx + i] = bits.bytes[0][i];
|
||||
}
|
||||
if (!drop_last) {
|
||||
idx = (drop_last ? 0 : size_t(index_y) + grid_dims.y) << 2;
|
||||
if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) {
|
||||
int edge_bytes = (bytes_per_key % 4);
|
||||
for (int i = 0; i < edge_bytes; ++i) {
|
||||
out[idx + i] = bits.bytes[1][i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
out[idx + i] = bits.bytes[1][i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("RandomBits::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
// keys has shape (N1, ..., NK, 2)
|
||||
// out has shape (N1, ..., NK, M1, M2, ...)
|
||||
auto& keys = inputs[0];
|
||||
uint32_t num_keys = keys.size() / 2;
|
||||
|
||||
uint32_t elems_per_key = out.size() / num_keys;
|
||||
uint32_t bytes_per_key = out.itemsize() * elems_per_key;
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t out_per_key = (bytes_per_key + 4 - 1) / 4;
|
||||
uint32_t half_size = out_per_key / 2;
|
||||
bool odd = out_per_key % 2;
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(keys);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dim3 grid_dims{num_keys, half_size + odd};
|
||||
int64_t total = grid_dims.x * grid_dims.y;
|
||||
int32_t threads_y = 1;
|
||||
while ((total / threads_y) >= (1U << 31)) {
|
||||
threads_y *= 2;
|
||||
}
|
||||
int32_t threads_x = cuda::ceil_div(total, threads_y);
|
||||
auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
|
||||
if (keys.flags().row_contiguous) {
|
||||
cu::rbitsc<<<grid, block, 0, stream>>>(
|
||||
keys.data<uint32_t>(),
|
||||
out.data<uint8_t>(),
|
||||
grid_dims,
|
||||
odd,
|
||||
bytes_per_key);
|
||||
} else {
|
||||
cu::rbits<<<grid, block, 0, stream>>>(
|
||||
keys.data<uint32_t>(),
|
||||
out.data<uint8_t>(),
|
||||
grid_dims,
|
||||
odd,
|
||||
bytes_per_key,
|
||||
keys.ndim(),
|
||||
const_param(keys.shape()),
|
||||
const_param(keys.strides()));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
82
mlx/backend/cuda/reduce.cu
Normal file
82
mlx/backend/cuda/reduce.cu
Normal file
@ -0,0 +1,82 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/fill.h>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Reduce::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
array in = inputs[0];
|
||||
|
||||
// Make sure no identity reductions trickle down here.
|
||||
assert(!axes_.empty());
|
||||
assert(out.size() != in.size());
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
// Fill out with init value.
|
||||
if (in.size() == 0) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type_, OP, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||
thrust::fill_n(
|
||||
cu::thrust_policy(stream),
|
||||
thrust::device_pointer_cast(out.data<OutType>()),
|
||||
out.data_size(),
|
||||
cu::ReduceInit<OP, InType>::value());
|
||||
});
|
||||
});
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Reduce.
|
||||
ReductionPlan plan = get_reduction_plan(in, axes_);
|
||||
|
||||
// If it is a general reduce then copy the input to a contiguous array and
|
||||
// recompute the plan.
|
||||
if (plan.type == GeneralReduce) {
|
||||
array in_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy_gpu(in, in_copy, CopyType::General, s);
|
||||
encoder.add_temporary(in_copy);
|
||||
in = in_copy;
|
||||
plan = get_reduction_plan(in, axes_);
|
||||
}
|
||||
|
||||
if ((plan.type == ContiguousAllReduce) ||
|
||||
(plan.type == ContiguousReduce && plan.shape.size() == 1)) {
|
||||
segmented_reduce(encoder, in, out, reduce_type_, axes_, plan);
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) {
|
||||
row_reduce(encoder, in, out, reduce_type_, axes_, plan);
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == ContiguousStridedReduce ||
|
||||
plan.type == GeneralStridedReduce) {
|
||||
col_reduce(encoder, in, out, reduce_type_, axes_, plan);
|
||||
return;
|
||||
}
|
||||
|
||||
throw std::runtime_error("No plan reached in reduce.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
278
mlx/backend/cuda/reduce/col_reduce.cu
Normal file
278
mlx/backend/cuda/reduce/col_reduce.cu
Normal file
@ -0,0 +1,278 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <cub/block/block_load.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
struct ColReduceArgs {
|
||||
// The size of the contiguous column reduction.
|
||||
size_t reduction_size;
|
||||
int64_t reduction_stride;
|
||||
|
||||
// Input shape and strides excluding the reduction axes.
|
||||
Shape shape;
|
||||
Strides strides;
|
||||
int ndim;
|
||||
|
||||
// Input shape and strides of the reduction axes (including last dimension).
|
||||
Shape reduce_shape;
|
||||
Strides reduce_strides;
|
||||
int reduce_ndim;
|
||||
|
||||
// The number of column we are reducing. Namely prod(reduce_shape).
|
||||
size_t non_col_reductions;
|
||||
|
||||
ColReduceArgs(
|
||||
const array& in,
|
||||
const ReductionPlan& plan,
|
||||
const std::vector<int>& axes) {
|
||||
assert(!plan.shape.empty());
|
||||
reduction_size = plan.shape.back();
|
||||
reduction_stride = plan.strides.back();
|
||||
|
||||
int64_t stride_back = 1;
|
||||
auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
|
||||
while (!shape_vec.empty() && stride_back < reduction_stride) {
|
||||
stride_back *= shape_vec.back();
|
||||
shape_vec.pop_back();
|
||||
strides_vec.pop_back();
|
||||
}
|
||||
std::tie(shape_vec, strides_vec) =
|
||||
collapse_contiguous_dims(shape_vec, strides_vec);
|
||||
shape = const_param(shape_vec);
|
||||
strides = const_param(strides_vec);
|
||||
ndim = shape_vec.size();
|
||||
|
||||
reduce_shape = const_param(plan.shape);
|
||||
reduce_strides = const_param(plan.strides);
|
||||
reduce_ndim = plan.shape.size();
|
||||
|
||||
non_col_reductions = 1;
|
||||
for (int i = 0; i < reduce_ndim - 1; i++) {
|
||||
non_col_reductions *= reduce_shape[i];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
||||
__global__ void col_reduce_small(
|
||||
const T* in,
|
||||
U* out,
|
||||
const __grid_constant__ ColReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
int column =
|
||||
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||
if (column * N_READS >= args.reduction_stride) {
|
||||
return;
|
||||
}
|
||||
|
||||
int out_idx = grid.block_rank() / grid.dim_blocks().x;
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
Op op;
|
||||
U totals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = ReduceInit<Op, T>::value();
|
||||
}
|
||||
|
||||
// Read input to local.
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
loop.next(
|
||||
block.thread_index().y,
|
||||
args.reduce_shape.data(),
|
||||
args.reduce_strides.data());
|
||||
for (size_t r = block.thread_index().y;
|
||||
r < args.non_col_reductions * args.reduction_size;
|
||||
r += block.dim_threads().y) {
|
||||
U vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
column,
|
||||
make_cast_iterator<U>(in + loop.location()),
|
||||
vals,
|
||||
args.reduction_stride,
|
||||
ReduceInit<Op, T>::value());
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(vals[i], totals[i]);
|
||||
}
|
||||
loop.next(
|
||||
block.dim_threads().y,
|
||||
args.reduce_shape.data(),
|
||||
args.reduce_strides.data());
|
||||
}
|
||||
|
||||
// Do block reduce when each column has more than 1 element to reduce.
|
||||
if (block.dim_threads().y > 1) {
|
||||
__shared__ U shared_vals[32 * 8 * N_READS];
|
||||
size_t col =
|
||||
block.thread_index().y * block.dim_threads().x + block.thread_index().x;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
shared_vals[col * N_READS + i] = totals[i];
|
||||
}
|
||||
block.sync();
|
||||
if (block.thread_index().y == 0) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = shared_vals[block.thread_index().x * N_READS + i];
|
||||
}
|
||||
for (int j = 1; j < block.dim_threads().y; j++) {
|
||||
col = j * block.dim_threads().x + block.thread_index().x;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(shared_vals[col * N_READS + i], totals[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write result.
|
||||
if (block.thread_index().y == 0) {
|
||||
cub::StoreDirectBlocked(
|
||||
column,
|
||||
out + out_idx * args.reduction_stride,
|
||||
totals,
|
||||
args.reduction_stride);
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIM,
|
||||
int BM,
|
||||
int BN,
|
||||
int N_READS = 4>
|
||||
__global__ void col_reduce_looped(
|
||||
const T* in,
|
||||
U* out,
|
||||
const __grid_constant__ ColReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
constexpr int n_warps = BN / N_READS;
|
||||
|
||||
int out_idx = grid.block_rank() / grid.dim_blocks().x;
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
Op op;
|
||||
U totals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = ReduceInit<Op, T>::value();
|
||||
}
|
||||
|
||||
// Read input to local.
|
||||
int r = block.thread_rank() / n_warps;
|
||||
int column = block.thread_rank() % n_warps;
|
||||
int in_offset = grid.block_index().x * BN;
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
loop.next(r, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
for (; r < args.non_col_reductions * args.reduction_size; r += BM) {
|
||||
U vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
column,
|
||||
make_cast_iterator<U>(in + loop.location() + in_offset),
|
||||
vals,
|
||||
args.reduction_stride - in_offset,
|
||||
ReduceInit<Op, T>::value());
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(vals[i], totals[i]);
|
||||
}
|
||||
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
// Do warp reduce for each output.
|
||||
constexpr int n_outputs = BN / n_warps;
|
||||
static_assert(BM == 32 && n_outputs == N_READS);
|
||||
__shared__ U shared_vals[BM * BN];
|
||||
size_t col = block.thread_index().y * BN + block.thread_index().x * N_READS;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
shared_vals[col + i] = totals[i];
|
||||
}
|
||||
block.sync();
|
||||
col = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
|
||||
for (int i = 0; i < n_outputs; i++) {
|
||||
totals[i] = cg::reduce(warp, shared_vals[col + i], op);
|
||||
}
|
||||
|
||||
// Write result.
|
||||
if (warp.thread_rank() == 0) {
|
||||
size_t out_offset = grid.block_index().x * BN;
|
||||
cub::StoreDirectBlocked(
|
||||
warp.meta_group_rank(),
|
||||
out + out_idx * args.reduction_stride + out_offset,
|
||||
totals,
|
||||
args.reduction_stride - out_offset);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
inline auto output_grid_for_col_reduce(
|
||||
const array& out,
|
||||
const cu::ColReduceArgs& args) {
|
||||
auto out_shape = out.shape();
|
||||
auto out_strides = out.strides();
|
||||
while (!out_shape.empty() && out_strides.back() < args.reduction_stride) {
|
||||
out_shape.pop_back();
|
||||
out_strides.pop_back();
|
||||
}
|
||||
return get_2d_grid_dims(out_shape, out_strides);
|
||||
}
|
||||
|
||||
void col_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
cu::ColReduceArgs args(in, plan, axes);
|
||||
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
|
||||
constexpr int N_READS = 4;
|
||||
dim3 block_dims;
|
||||
dim3 num_blocks = output_grid_for_col_reduce(out, args);
|
||||
num_blocks.z = num_blocks.y;
|
||||
num_blocks.y = num_blocks.x;
|
||||
auto kernel =
|
||||
cu::col_reduce_small<InType, OutType, OP, NDIM, N_READS>;
|
||||
size_t total = args.non_col_reductions * args.reduction_size;
|
||||
if (total < 32) {
|
||||
size_t stride_blocks =
|
||||
cuda::ceil_div(args.reduction_stride, N_READS);
|
||||
block_dims.x = std::min(stride_blocks, 32ul);
|
||||
block_dims.y = std::min(total, 8ul);
|
||||
num_blocks.x = cuda::ceil_div(stride_blocks, block_dims.x);
|
||||
} else {
|
||||
constexpr int BM = 32;
|
||||
constexpr int BN = 32;
|
||||
block_dims.x = BM * BN / N_READS;
|
||||
num_blocks.x = cuda::ceil_div(args.reduction_stride, BN);
|
||||
kernel = cu::
|
||||
col_reduce_looped<InType, OutType, OP, NDIM, BM, BN, N_READS>;
|
||||
}
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in.data<InType>(), out.data<OutType>(), args);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
74
mlx/backend/cuda/reduce/reduce.cuh
Normal file
74
mlx/backend/cuda/reduce/reduce.cuh
Normal file
@ -0,0 +1,74 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce_ops.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Dispatch dynamic ndim to constexpr.
|
||||
// The behavior follows get_kernel_reduce_ndim in metal/reduce.cpp file.
|
||||
#define MLX_SWITCH_REDUCE_NDIM(ndim, NDIM, ...) \
|
||||
if (ndim == 1) { \
|
||||
constexpr uint32_t NDIM = 1; \
|
||||
__VA_ARGS__; \
|
||||
} else if (ndim == 2) { \
|
||||
constexpr uint32_t NDIM = 2; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr uint32_t NDIM = 5; \
|
||||
__VA_ARGS__; \
|
||||
}
|
||||
|
||||
// Dispatch reduce ops to constexpr.
|
||||
#define MLX_SWITCH_REDUCE_OPS(REDUCE, OP, ...) \
|
||||
if (REDUCE == Reduce::ReduceType::And) { \
|
||||
using OP = cu::And; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Or) { \
|
||||
using OP = cu::Or; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Sum) { \
|
||||
using OP = cu::Sum; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Prod) { \
|
||||
using OP = cu::Prod; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Max) { \
|
||||
using OP = cu::Max; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Min) { \
|
||||
using OP = cu::Min; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
throw std::invalid_argument("Unknown reduce type."); \
|
||||
}
|
||||
|
||||
void segmented_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan);
|
||||
|
||||
void row_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan);
|
||||
|
||||
void col_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan);
|
||||
|
||||
} // namespace mlx::core
|
144
mlx/backend/cuda/reduce/reduce_ops.cuh
Normal file
144
mlx/backend/cuda/reduce/reduce_ops.cuh
Normal file
@ -0,0 +1,144 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// Reduce ops.
|
||||
struct And {
|
||||
__device__ bool operator()(bool a, bool b) {
|
||||
return a && b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Or {
|
||||
__device__ bool operator()(bool a, bool b) {
|
||||
return a || b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Sum {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Prod {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a * b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Min {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Max {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
};
|
||||
|
||||
// Traits to get the result type of reduce op.
|
||||
template <typename Op, typename T>
|
||||
struct ReduceResult;
|
||||
|
||||
template <typename T>
|
||||
struct ReduceResult<And, T> {
|
||||
using type = bool;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceResult<Or, T> {
|
||||
using type = bool;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceResult<Sum, T> {
|
||||
using type = cuda::std::conditional_t<
|
||||
(cuda::std::is_integral_v<T> && sizeof(T) <= 4),
|
||||
int32_t,
|
||||
T>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceResult<Prod, T> {
|
||||
using type = cuda::std::conditional_t<
|
||||
(cuda::std::is_integral_v<T> && sizeof(T) <= 4),
|
||||
int32_t,
|
||||
T>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceResult<Min, T> {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceResult<Max, T> {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
// Traits to get the init value of reduce op.
|
||||
template <typename Op, typename T>
|
||||
struct ReduceInit;
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<And, T> {
|
||||
static constexpr __host__ __device__ bool value() {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<Or, T> {
|
||||
static constexpr __host__ __device__ bool value() {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<Sum, T> {
|
||||
static constexpr __host__ __device__ auto value() {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return T{0, 0};
|
||||
} else {
|
||||
return typename ReduceResult<Sum, T>::type{0};
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<Prod, T> {
|
||||
static constexpr __host__ __device__ auto value() {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return T{1, 1};
|
||||
} else {
|
||||
return typename ReduceResult<Prod, T>::type{1};
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<Min, T> {
|
||||
static constexpr __host__ __device__ T value() {
|
||||
return Limits<T>::max();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<Max, T> {
|
||||
static constexpr __host__ __device__ T value() {
|
||||
return Limits<T>::min();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
250
mlx/backend/cuda/reduce/row_reduce.cu
Normal file
250
mlx/backend/cuda/reduce/row_reduce.cu
Normal file
@ -0,0 +1,250 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
struct RowReduceArgs {
|
||||
// The size of the row being reduced, i.e. the size of last dimension.
|
||||
int row_size;
|
||||
|
||||
// Input shape and strides excluding the reduction axes.
|
||||
Shape shape;
|
||||
Strides strides;
|
||||
int ndim;
|
||||
|
||||
// Input shape and strides of the reduction axes excluding last dimension.
|
||||
Shape reduce_shape;
|
||||
Strides reduce_strides;
|
||||
int reduce_ndim;
|
||||
|
||||
// The number of rows we are reducing. Namely prod(reduce_shape).
|
||||
size_t non_row_reductions;
|
||||
|
||||
RowReduceArgs(
|
||||
const array& in,
|
||||
const ReductionPlan& plan,
|
||||
const std::vector<int>& axes) {
|
||||
assert(!plan.shape.empty());
|
||||
row_size = plan.shape.back();
|
||||
|
||||
auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
|
||||
std::tie(shape_vec, strides_vec) =
|
||||
collapse_contiguous_dims(shape_vec, strides_vec);
|
||||
shape = const_param(shape_vec);
|
||||
strides = const_param(strides_vec);
|
||||
ndim = shape_vec.size();
|
||||
|
||||
reduce_shape = const_param(plan.shape);
|
||||
reduce_strides = const_param(plan.strides);
|
||||
reduce_ndim = plan.shape.size() - 1;
|
||||
|
||||
non_row_reductions = 1;
|
||||
for (int i = 0; i < reduce_ndim; i++) {
|
||||
non_row_reductions *= reduce_shape[i];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
||||
__global__ void row_reduce_small(
|
||||
const T* in,
|
||||
U* out,
|
||||
size_t out_size,
|
||||
const __grid_constant__ RowReduceArgs args) {
|
||||
size_t out_idx = cg::this_grid().thread_rank();
|
||||
if (out_idx >= out_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
Op op;
|
||||
|
||||
U total_val = ReduceInit<Op, T>::value();
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
|
||||
U vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
r,
|
||||
make_cast_iterator<U>(in + loop.location()),
|
||||
vals,
|
||||
args.row_size,
|
||||
ReduceInit<Op, T>::value());
|
||||
total_val = op(total_val, cub::ThreadReduce(vals, op));
|
||||
}
|
||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
out[out_idx] = total_val;
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
||||
__global__ void row_reduce_small_warp(
|
||||
const T* in,
|
||||
U* out,
|
||||
size_t out_size,
|
||||
const __grid_constant__ RowReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
size_t out_idx = grid.thread_rank() / WARP_SIZE;
|
||||
if (out_idx >= out_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
Op op;
|
||||
|
||||
U total_val = ReduceInit<Op, T>::value();
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
for (size_t n = warp.thread_rank(); n < args.non_row_reductions;
|
||||
n += WARP_SIZE) {
|
||||
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
|
||||
U vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
r,
|
||||
make_cast_iterator<U>(in + loop.location()),
|
||||
vals,
|
||||
args.row_size,
|
||||
ReduceInit<Op, T>::value());
|
||||
total_val = op(total_val, cub::ThreadReduce(vals, op));
|
||||
}
|
||||
loop.next(WARP_SIZE, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
total_val = cg::reduce(warp, total_val, op);
|
||||
|
||||
if (warp.thread_rank() == 0) {
|
||||
out[out_idx] = total_val;
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIM,
|
||||
int BLOCK_DIM_X,
|
||||
int N_READS = 4>
|
||||
__global__ void row_reduce_looped(
|
||||
const T* in,
|
||||
U* out,
|
||||
size_t out_size,
|
||||
const __grid_constant__ RowReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
size_t out_idx = grid.thread_rank() / BLOCK_DIM_X;
|
||||
if (out_idx >= out_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
Op op;
|
||||
|
||||
U total_val = ReduceInit<Op, T>::value();
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||
for (size_t r = 0; r < cuda::ceil_div(args.row_size, BLOCK_DIM_X * N_READS);
|
||||
r++) {
|
||||
U vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
r * BLOCK_DIM_X + block.thread_index().x,
|
||||
make_cast_iterator<U>(in + loop.location()),
|
||||
vals,
|
||||
args.row_size,
|
||||
ReduceInit<Op, T>::value());
|
||||
total_val = op(total_val, cub::ThreadReduce(vals, op));
|
||||
}
|
||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
typedef cub::BlockReduce<U, BLOCK_DIM_X> BlockReduceT;
|
||||
__shared__ typename BlockReduceT::TempStorage temp;
|
||||
|
||||
total_val = BlockReduceT(temp).Reduce(total_val, op);
|
||||
|
||||
if (block.thread_rank() == 0) {
|
||||
out[out_idx] = total_val;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void row_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
cu::RowReduceArgs args(in, plan, axes);
|
||||
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
|
||||
constexpr size_t N_READS = 4;
|
||||
dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides());
|
||||
dim3 block_dims, num_blocks;
|
||||
auto kernel =
|
||||
cu::row_reduce_small<InType, OutType, OP, NDIM, N_READS>;
|
||||
if (args.row_size <= 64) {
|
||||
if ((args.non_row_reductions < 32 && args.row_size <= 8) ||
|
||||
(args.non_row_reductions <= 8)) {
|
||||
block_dims.x = std::min(out_dims.x, 1024u);
|
||||
num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x);
|
||||
num_blocks.y = out_dims.y;
|
||||
} else {
|
||||
block_dims.x = WARP_SIZE;
|
||||
num_blocks.y = out_dims.x;
|
||||
num_blocks.z = out_dims.y;
|
||||
kernel =
|
||||
cu::row_reduce_small_warp<InType, OutType, OP, NDIM, N_READS>;
|
||||
}
|
||||
} else {
|
||||
size_t num_threads = cuda::ceil_div(args.row_size, N_READS);
|
||||
num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE;
|
||||
MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, {
|
||||
num_blocks.y = out_dims.x;
|
||||
num_blocks.z = out_dims.y;
|
||||
block_dims.x = BLOCK_DIM_X;
|
||||
kernel = cu::row_reduce_looped<
|
||||
InType,
|
||||
OutType,
|
||||
OP,
|
||||
NDIM,
|
||||
BLOCK_DIM_X,
|
||||
N_READS>;
|
||||
});
|
||||
}
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in.data<InType>(), out.data<OutType>(), out.size(), args);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
84
mlx/backend/cuda/reduce/segmented_reduce.cu
Normal file
84
mlx/backend/cuda/reduce/segmented_reduce.cu
Normal file
@ -0,0 +1,84 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <cub/device/device_reduce.cuh>
|
||||
#include <cub/device/device_segmented_reduce.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename... Args>
|
||||
void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) {
|
||||
// Allocate temporary storage.
|
||||
size_t size;
|
||||
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...));
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
encoder.add_temporary(temp);
|
||||
// Run op.
|
||||
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(temp.data<void>(), size, args...));
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) {
|
||||
// Allocate temporary storage.
|
||||
size_t size;
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...));
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
encoder.add_temporary(temp);
|
||||
// Run op.
|
||||
CHECK_CUDA_ERROR(
|
||||
cub::DeviceSegmentedReduce::Reduce(temp.data<void>(), size, args...));
|
||||
}
|
||||
|
||||
struct MultiplyOp {
|
||||
int factor;
|
||||
__device__ int operator()(int i) {
|
||||
return i * factor;
|
||||
}
|
||||
};
|
||||
|
||||
void segmented_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||
auto in_iter = cu::make_cast_iterator<OutType>(
|
||||
thrust::device_pointer_cast(in.data<InType>()));
|
||||
auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
|
||||
auto init = cu::ReduceInit<OP, InType>::value();
|
||||
|
||||
if (plan.type == ContiguousAllReduce) {
|
||||
cub_all_reduce(
|
||||
encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream);
|
||||
} else if (plan.type == ContiguousReduce) {
|
||||
auto offsets = thrust::make_transform_iterator(
|
||||
thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()});
|
||||
cub_segmented_reduce(
|
||||
encoder,
|
||||
in_iter,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
offsets,
|
||||
offsets + 1,
|
||||
OP(),
|
||||
init,
|
||||
stream);
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported plan in segmented_reduce.");
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
343
mlx/backend/cuda/rms_norm.cu
Normal file
343
mlx/backend/cuda/rms_norm.cu
Normal file
@ -0,0 +1,343 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
inline __device__ float2 plus_f2(const float2& a, const float2& b) {
|
||||
return {a.x + b.x, a.y + b.y};
|
||||
}
|
||||
|
||||
// Similar to cub::BlockReduce, but result is broadcasted to every thread.
|
||||
template <typename T, int BLOCK_DIM>
|
||||
struct BlockBroadcastReduce {
|
||||
static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
|
||||
static_assert(BLOCK_DIM % WARP_SIZE == 0);
|
||||
using TempStorage = T[BLOCK_DIM / WARP_SIZE];
|
||||
|
||||
cg::thread_block& block;
|
||||
TempStorage& temp;
|
||||
|
||||
template <typename Op>
|
||||
__device__ T Reduce(const T& input, const Op& op, const T& init_value) {
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
T x = cg::reduce(warp, input, op);
|
||||
if (warp.thread_rank() == 0) {
|
||||
temp[warp.meta_group_rank()] = x;
|
||||
}
|
||||
block.sync();
|
||||
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
|
||||
: init_value;
|
||||
return cg::reduce(warp, x, op);
|
||||
}
|
||||
|
||||
__device__ T Sum(const T& input) {
|
||||
return Reduce(input, cg::plus<T>{}, T{});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void rms_norm(
|
||||
const T* x,
|
||||
const T* w,
|
||||
T* out,
|
||||
float eps,
|
||||
int32_t axis_size,
|
||||
int64_t w_stride) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM>;
|
||||
__shared__ typename BlockReduceT::TempStorage temp;
|
||||
|
||||
x += grid.block_rank() * axis_size;
|
||||
out += grid.block_rank() * axis_size;
|
||||
|
||||
// Normalizer.
|
||||
float normalizer = 0;
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size, 0);
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float t = static_cast<float>(xn[i]);
|
||||
normalizer += t * t;
|
||||
}
|
||||
}
|
||||
normalizer = BlockReduceT{block, temp}.Sum(normalizer);
|
||||
normalizer = rsqrt(normalizer / axis_size + eps);
|
||||
|
||||
// Outputs.
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
T wn[N_READS];
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float norm = static_cast<float>(xn[i]) * normalizer;
|
||||
xn[i] = wn[i] * static_cast<T>(norm);
|
||||
}
|
||||
cub::StoreDirectBlocked(index, out, xn, axis_size);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void rms_norm_vjp(
|
||||
const T* x,
|
||||
const T* w,
|
||||
const T* g,
|
||||
T* gx,
|
||||
T* gw,
|
||||
float eps,
|
||||
int32_t axis_size,
|
||||
int64_t w_stride) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
|
||||
using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM>;
|
||||
__shared__ union {
|
||||
typename BlockReduceF::TempStorage f;
|
||||
typename BlockReduceF2::TempStorage f2;
|
||||
} temp;
|
||||
|
||||
x += grid.block_rank() * axis_size;
|
||||
g += grid.block_rank() * axis_size;
|
||||
gx += grid.block_rank() * axis_size;
|
||||
gw += grid.block_rank() * axis_size;
|
||||
|
||||
// Normalizer.
|
||||
float2 factors = {};
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
T xn[N_READS];
|
||||
T wn[N_READS] = {};
|
||||
T gn[N_READS] = {};
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size, 0);
|
||||
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float t = static_cast<float>(xn[i]);
|
||||
float wi = wn[i];
|
||||
float gi = gn[i];
|
||||
float wg = wi * gi;
|
||||
factors = plus_f2(factors, {wg * t, t * t});
|
||||
}
|
||||
}
|
||||
factors = BlockReduceF2{block, temp.f2}.Reduce(factors, plus_f2, {});
|
||||
float meangwx = factors.x / axis_size;
|
||||
float normalizer = rsqrt(factors.y / axis_size + eps);
|
||||
float normalizer3 = normalizer * normalizer * normalizer;
|
||||
|
||||
// Outputs.
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
T wn[N_READS];
|
||||
T gn[N_READS];
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = xn[i];
|
||||
float wi = wn[i];
|
||||
float gi = gn[i];
|
||||
xn[i] = static_cast<T>(normalizer * wi * gi - xi * meangwx * normalizer3);
|
||||
if constexpr (HAS_W) {
|
||||
wn[i] = static_cast<T>(gi * xi * normalizer);
|
||||
}
|
||||
}
|
||||
cub::StoreDirectBlocked(index, gx, xn, axis_size);
|
||||
if constexpr (HAS_W) {
|
||||
cub::StoreDirectBlocked(index, gw, wn, axis_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
namespace fast {
|
||||
|
||||
bool RMSNorm::use_fallback(Stream s) {
|
||||
return s.device == Device::cpu;
|
||||
}
|
||||
|
||||
// TODO: There are duplicate code with backend/metal/normalization.cpp
|
||||
void RMSNorm::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("RMSNorm::eval_gpu");
|
||||
auto& s = stream();
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Make sure that the last dimension is contiguous.
|
||||
auto set_output = [&s, &out](const array& x) {
|
||||
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
|
||||
if (no_copy && x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
}
|
||||
if (no_copy) {
|
||||
if (x.is_donatable()) {
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
}
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
|
||||
const array x = set_output(inputs[0]);
|
||||
const array& w = inputs[1];
|
||||
|
||||
int32_t axis_size = x.shape().back();
|
||||
int32_t n_rows = x.data_size() / axis_size;
|
||||
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "rms_norm", CTYPE, {
|
||||
using DataType = cuda_type_t<CTYPE>;
|
||||
constexpr uint32_t N_READS = 4;
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = cu::rms_norm<DataType, BLOCK_DIM, N_READS>;
|
||||
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void RMSNormVJP::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("RMSNormVJP::eval_gpu");
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
// Ensure row contiguity. We could relax this step by checking that the array
|
||||
// is contiguous (no broadcasts or holes) and that the input strides are the
|
||||
// same as the cotangent strides but for now this is simpler.
|
||||
auto check_input = [&s](const array& x) -> std::pair<array, bool> {
|
||||
if (x.flags().row_contiguous) {
|
||||
return {x, false};
|
||||
}
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
return {x_copy, true};
|
||||
};
|
||||
bool donate_x = inputs[0].is_donatable();
|
||||
bool donate_g = inputs[2].is_donatable();
|
||||
auto [x, copied] = check_input(inputs[0]);
|
||||
donate_x |= copied;
|
||||
const array& w = inputs[1];
|
||||
auto [g, g_copied] = check_input(inputs[2]);
|
||||
donate_g |= g_copied;
|
||||
array& gx = outputs[0];
|
||||
array& gw = outputs[1];
|
||||
|
||||
// Check whether we had a weight.
|
||||
bool has_w = w.ndim() != 0;
|
||||
|
||||
// Allocate space for the outputs.
|
||||
bool g_in_gx = false;
|
||||
if (donate_x) {
|
||||
gx.copy_shared_buffer(x);
|
||||
} else if (donate_g) {
|
||||
gx.copy_shared_buffer(g);
|
||||
g_in_gx = true;
|
||||
} else {
|
||||
gx.set_data(allocator::malloc(gx.nbytes()));
|
||||
}
|
||||
if (g_copied && !g_in_gx) {
|
||||
encoder.add_temporary(g);
|
||||
}
|
||||
|
||||
int32_t axis_size = x.shape().back();
|
||||
int32_t n_rows = x.data_size() / axis_size;
|
||||
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||
|
||||
// Allocate a temporary to store the gradients for w and allocate the output
|
||||
// gradient accumulators.
|
||||
array gw_temp =
|
||||
(has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
|
||||
if (has_w) {
|
||||
if (!g_in_gx && donate_g) {
|
||||
gw_temp.copy_shared_buffer(g);
|
||||
} else {
|
||||
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
|
||||
encoder.add_temporary(gw_temp);
|
||||
}
|
||||
}
|
||||
gw.set_data(allocator::malloc(gw.nbytes()));
|
||||
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(g);
|
||||
encoder.set_output_array(gx);
|
||||
encoder.set_output_array(gw_temp);
|
||||
encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "rms_norm_vjp", CTYPE, {
|
||||
using DataType = cuda_type_t<CTYPE>;
|
||||
constexpr int N_READS = 4;
|
||||
MLX_SWITCH_BOOL(has_w, HAS_W, {
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = cu::rms_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>;
|
||||
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
g.data<DataType>(),
|
||||
gx.data<DataType>(),
|
||||
gw_temp.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
if (has_w) {
|
||||
ReductionPlan plan(
|
||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||
col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
385
mlx/backend/cuda/rope.cu
Normal file
385
mlx/backend/cuda/rope.cu
Normal file
@ -0,0 +1,385 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
__device__ void rope_single_impl(
|
||||
const T* in,
|
||||
T* out,
|
||||
int32_t offset,
|
||||
float inv_freq,
|
||||
float scale,
|
||||
int64_t stride,
|
||||
uint2 pos,
|
||||
uint2 dims) {
|
||||
float L = scale * static_cast<float>(offset);
|
||||
|
||||
// Compute costheta, sintheta
|
||||
float theta = L * inv_freq;
|
||||
float costheta = cos(theta);
|
||||
float sintheta = sin(theta);
|
||||
|
||||
// Compute the input and output indices
|
||||
uint index_1, index_2;
|
||||
if (traditional) {
|
||||
index_1 = 2 * pos.x + pos.y * stride;
|
||||
index_2 = index_1 + 1;
|
||||
} else {
|
||||
index_1 = pos.x + pos.y * stride;
|
||||
index_2 = index_1 + dims.x;
|
||||
}
|
||||
|
||||
// Read and write the output
|
||||
float x1 = static_cast<float>(in[index_1]);
|
||||
float x2 = static_cast<float>(in[index_2]);
|
||||
float rx1;
|
||||
float rx2;
|
||||
if (forward) {
|
||||
rx1 = x1 * costheta - x2 * sintheta;
|
||||
rx2 = x1 * sintheta + x2 * costheta;
|
||||
} else {
|
||||
rx1 = x2 * sintheta + x1 * costheta;
|
||||
rx2 = x2 * costheta - x1 * sintheta;
|
||||
}
|
||||
out[index_1] = static_cast<T>(rx1);
|
||||
out[index_2] = static_cast<T>(rx2);
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
__global__ void rope_single(
|
||||
const T* in,
|
||||
T* out,
|
||||
const int32_t* offset,
|
||||
float scale,
|
||||
float base,
|
||||
int64_t stride,
|
||||
uint2 dims) {
|
||||
uint2 pos = make_uint2(
|
||||
blockIdx.x * blockDim.x + threadIdx.x,
|
||||
blockIdx.y * blockDim.y + threadIdx.y);
|
||||
if (pos.x >= dims.x || pos.y >= dims.y) {
|
||||
return;
|
||||
}
|
||||
|
||||
float d = static_cast<float>(pos.x) / static_cast<float>(dims.x);
|
||||
float inv_freq = exp2(-d * base);
|
||||
rope_single_impl<T, traditional, forward>(
|
||||
in, out, *offset, inv_freq, scale, stride, pos, dims);
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
__global__ void rope_single_freqs(
|
||||
const T* in,
|
||||
T* out,
|
||||
const int32_t* offset,
|
||||
const float* freqs,
|
||||
float scale,
|
||||
int64_t stride,
|
||||
uint2 dims,
|
||||
int64_t freq_stride) {
|
||||
uint2 pos = make_uint2(
|
||||
blockIdx.x * blockDim.x + threadIdx.x,
|
||||
blockIdx.y * blockDim.y + threadIdx.y);
|
||||
if (pos.x >= dims.x || pos.y >= dims.y) {
|
||||
return;
|
||||
}
|
||||
|
||||
float inv_freq = 1.0 / freqs[freq_stride * pos.x];
|
||||
rope_single_impl<T, traditional, forward>(
|
||||
in, out, *offset, inv_freq, scale, stride, pos, dims);
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward, int N = 4>
|
||||
__device__ void rope_impl(
|
||||
const T* in,
|
||||
T* out,
|
||||
int offset,
|
||||
float inv_freq,
|
||||
float scale,
|
||||
const cuda::std::array<int64_t, 3> strides,
|
||||
const cuda::std::array<int64_t, 3> out_strides,
|
||||
int64_t n_batch,
|
||||
uint3 pos,
|
||||
uint3 dims) {
|
||||
float L = scale * static_cast<float>(pos.y + offset);
|
||||
|
||||
// Compute costheta, sintheta
|
||||
float theta = L * inv_freq;
|
||||
float costheta = cos(theta);
|
||||
float sintheta = sin(theta);
|
||||
|
||||
// Compute the input and output indices
|
||||
size_t in_index_1, in_index_2;
|
||||
size_t out_index_1, out_index_2;
|
||||
if (traditional) {
|
||||
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
N * pos.z * out_strides[0];
|
||||
out_index_2 = out_index_1 + 1;
|
||||
in_index_1 =
|
||||
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
|
||||
in_index_2 = in_index_1 + strides[2];
|
||||
} else {
|
||||
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
N * pos.z * out_strides[0];
|
||||
out_index_2 = out_index_1 + dims.x * out_strides[2];
|
||||
in_index_1 =
|
||||
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
|
||||
in_index_2 = in_index_1 + dims.x * strides[2];
|
||||
}
|
||||
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
|
||||
// Read and write the output
|
||||
float x1 = static_cast<float>(in[in_index_1]);
|
||||
float x2 = static_cast<float>(in[in_index_2]);
|
||||
float rx1;
|
||||
float rx2;
|
||||
if (forward) {
|
||||
rx1 = x1 * costheta - x2 * sintheta;
|
||||
rx2 = x1 * sintheta + x2 * costheta;
|
||||
} else {
|
||||
rx1 = x2 * sintheta + x1 * costheta;
|
||||
rx2 = x2 * costheta - x1 * sintheta;
|
||||
}
|
||||
out[out_index_1] = static_cast<T>(rx1);
|
||||
out[out_index_2] = static_cast<T>(rx2);
|
||||
in_index_1 += strides[0];
|
||||
in_index_2 += strides[0];
|
||||
out_index_1 += out_strides[0];
|
||||
out_index_2 += out_strides[0];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
__global__ void rope(
|
||||
const T* in,
|
||||
T* out,
|
||||
const int32_t* offset,
|
||||
float scale,
|
||||
float base,
|
||||
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
|
||||
int64_t n_batch,
|
||||
uint3 dims) {
|
||||
uint3 pos = make_uint3(
|
||||
blockIdx.x * blockDim.x + threadIdx.x,
|
||||
blockIdx.y * blockDim.y + threadIdx.y,
|
||||
blockIdx.z * blockDim.z + threadIdx.z);
|
||||
if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) {
|
||||
return;
|
||||
}
|
||||
|
||||
float d = static_cast<float>(pos.x) / static_cast<float>(dims.x);
|
||||
float inv_freq = exp2(-d * base);
|
||||
rope_impl<T, traditional, forward>(
|
||||
in,
|
||||
out,
|
||||
*offset,
|
||||
inv_freq,
|
||||
scale,
|
||||
strides,
|
||||
out_strides,
|
||||
n_batch,
|
||||
pos,
|
||||
dims);
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
__global__ void rope_freqs(
|
||||
const T* in,
|
||||
T* out,
|
||||
const int32_t* offset,
|
||||
const float* freqs,
|
||||
float scale,
|
||||
float base,
|
||||
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
|
||||
int64_t n_batch,
|
||||
uint3 dims,
|
||||
int64_t freq_stride) {
|
||||
uint3 pos = make_uint3(
|
||||
blockIdx.x * blockDim.x + threadIdx.x,
|
||||
blockIdx.y * blockDim.y + threadIdx.y,
|
||||
blockIdx.z * blockDim.z + threadIdx.z);
|
||||
if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) {
|
||||
return;
|
||||
}
|
||||
|
||||
float inv_freq = 1.0 / freqs[freq_stride * pos.x];
|
||||
rope_impl<T, traditional, forward>(
|
||||
in,
|
||||
out,
|
||||
*offset,
|
||||
inv_freq,
|
||||
scale,
|
||||
strides,
|
||||
out_strides,
|
||||
n_batch,
|
||||
pos,
|
||||
dims);
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
namespace fast {
|
||||
|
||||
bool RoPE::use_fallback(Stream s) {
|
||||
return s.device == Device::cpu;
|
||||
}
|
||||
|
||||
void RoPE::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("RoPE::eval_gpu");
|
||||
|
||||
auto& s = stream();
|
||||
auto& in = inputs[0];
|
||||
auto& offset = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
if (in.ndim() < 3) {
|
||||
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
|
||||
}
|
||||
|
||||
cuda::std::array<int64_t, 3> strides;
|
||||
cuda::std::array<int64_t, 3> out_strides;
|
||||
bool donated = false;
|
||||
int ndim = in.ndim();
|
||||
int dispatch_ndim = in.ndim();
|
||||
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
|
||||
dispatch_ndim--;
|
||||
}
|
||||
size_t mat_size = in.shape(-2) * in.shape(-1);
|
||||
|
||||
// We apply rope to less that the whole vector so copy to output and then
|
||||
// apply in-place.
|
||||
if (dims_ < in.shape(-1)) {
|
||||
donated = true;
|
||||
auto ctype =
|
||||
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
|
||||
copy_gpu(in, out, ctype, s);
|
||||
strides[0] = mat_size;
|
||||
strides[1] = out.strides()[ndim - 2];
|
||||
strides[2] = out.strides()[ndim - 1];
|
||||
}
|
||||
|
||||
// Either copy or apply in-place
|
||||
else if (in.flags().row_contiguous) {
|
||||
if (in.is_donatable()) {
|
||||
donated = true;
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
strides[0] = mat_size;
|
||||
strides[1] = in.strides()[ndim - 2];
|
||||
strides[2] = in.strides()[ndim - 1];
|
||||
} else if (dispatch_ndim == 3) {
|
||||
// Handle non-contiguous 3D inputs
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
strides[0] = in.strides()[ndim - 3];
|
||||
strides[1] = in.strides()[ndim - 2];
|
||||
strides[2] = in.strides()[ndim - 1];
|
||||
} else {
|
||||
// Copy non-contiguous > 3D inputs into the output and treat
|
||||
// input as donated
|
||||
donated = true;
|
||||
copy_gpu(in, out, CopyType::General, s);
|
||||
strides[0] = mat_size;
|
||||
strides[1] = out.strides()[ndim - 2];
|
||||
strides[2] = out.strides()[ndim - 1];
|
||||
}
|
||||
out_strides[0] = mat_size;
|
||||
out_strides[1] = out.strides()[ndim - 2];
|
||||
out_strides[2] = out.strides()[ndim - 1];
|
||||
|
||||
// Some flags to help us dispatch below
|
||||
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
|
||||
bool with_freqs = inputs.size() == 3;
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(donated ? out : in);
|
||||
encoder.set_input_array(offset);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(in.dtype(), "rope", CTYPE, {
|
||||
using DataType = cuda_type_t<CTYPE>;
|
||||
MLX_SWITCH_BOOL(traditional_, TRADITIONAL, {
|
||||
MLX_SWITCH_BOOL(forward_, FORWARD, {
|
||||
if (single && !with_freqs) {
|
||||
auto kernel = cu::rope_single<DataType, TRADITIONAL, FORWARD>;
|
||||
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
|
||||
kernel<<<grid, block, 0, stream>>>(
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
scale_,
|
||||
std::log2(base_),
|
||||
mat_size,
|
||||
dims);
|
||||
} else if (single) {
|
||||
auto kernel = cu::rope_single_freqs<DataType, TRADITIONAL, FORWARD>;
|
||||
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
|
||||
kernel<<<grid, block, 0, stream>>>(
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
inputs[2].data<float>(),
|
||||
scale_,
|
||||
mat_size,
|
||||
dims,
|
||||
inputs[2].strides(0));
|
||||
} else if (with_freqs) {
|
||||
auto kernel = cu::rope_freqs<DataType, TRADITIONAL, FORWARD>;
|
||||
uint3 dims =
|
||||
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
|
||||
dims.z = (dims.z + 3) / 4;
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
|
||||
kernel<<<grid, block, 0, stream>>>(
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
inputs[2].data<float>(),
|
||||
scale_,
|
||||
std::log2(base_),
|
||||
strides,
|
||||
out_strides,
|
||||
in.size() / mat_size,
|
||||
dims,
|
||||
inputs[2].strides(0));
|
||||
} else {
|
||||
auto kernel = cu::rope<DataType, TRADITIONAL, FORWARD>;
|
||||
uint3 dims =
|
||||
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
|
||||
dims.z = (dims.z + 3) / 4;
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
|
||||
kernel<<<grid, block, 0, stream>>>(
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
scale_,
|
||||
std::log2(base_),
|
||||
strides,
|
||||
out_strides,
|
||||
in.size() / mat_size,
|
||||
dims);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
@ -1,7 +1,11 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/slicing.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/gpu/slicing.h"
|
||||
|
||||
#include <numeric>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void concatenate_gpu(
|
||||
@ -9,7 +13,29 @@ void concatenate_gpu(
|
||||
array& out,
|
||||
int axis,
|
||||
const Stream& s) {
|
||||
throw std::runtime_error("concatenate_gpu not implemented in CUDA backend.");
|
||||
std::vector<int> sizes;
|
||||
sizes.push_back(0);
|
||||
for (auto& p : inputs) {
|
||||
sizes.push_back(p.shape(axis));
|
||||
}
|
||||
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto strides = out.strides();
|
||||
auto flags = out.flags();
|
||||
flags.row_contiguous = false;
|
||||
flags.col_contiguous = false;
|
||||
flags.contiguous = false;
|
||||
// TODO: Handle concurrent outputs:
|
||||
// https://github.com/ml-explore/mlx/pull/2145#discussion_r2070753816
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
|
||||
size_t data_offset = strides[axis] * sizes[i];
|
||||
out_slice.copy_shared_buffer(
|
||||
out, strides, flags, out_slice.size(), data_offset);
|
||||
copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, s);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
160
mlx/backend/cuda/softmax.cu
Normal file
160
mlx/backend/cuda/softmax.cu
Normal file
@ -0,0 +1,160 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <cub/block/block_load.cuh>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T softmax_exp(T x) {
|
||||
// Softmax doesn't need high precision exponential cause x is gonna be in
|
||||
// (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
|
||||
return __expf(x);
|
||||
}
|
||||
|
||||
template <typename T, typename AccT, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void softmax(const T* in, T* out, int axis_size) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
in += grid.block_rank() * axis_size;
|
||||
out += grid.block_rank() * axis_size;
|
||||
|
||||
cg::greater<AccT> max_op;
|
||||
cg::plus<AccT> plus_op;
|
||||
|
||||
// Thread reduce.
|
||||
AccT prevmax;
|
||||
AccT maxval = Limits<AccT>::finite_min();
|
||||
AccT normalizer = 0;
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
|
||||
AccT vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
r * BLOCK_DIM + block.thread_rank(),
|
||||
make_cast_iterator<AccT>(in),
|
||||
vals,
|
||||
axis_size,
|
||||
Limits<AccT>::finite_min());
|
||||
prevmax = maxval;
|
||||
maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
|
||||
// Online normalizer calculation for softmax:
|
||||
// https://github.com/NVIDIA/online-softmax
|
||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
normalizer = normalizer + softmax_exp(vals[i] - maxval);
|
||||
}
|
||||
}
|
||||
|
||||
// First warp reduce.
|
||||
prevmax = maxval;
|
||||
maxval = cg::reduce(warp, maxval, max_op);
|
||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||
normalizer = cg::reduce(warp, normalizer, plus_op);
|
||||
|
||||
__shared__ AccT local_max[WARP_SIZE];
|
||||
__shared__ AccT local_normalizer[WARP_SIZE];
|
||||
|
||||
// Write to shared memory and do second warp reduce.
|
||||
prevmax = maxval;
|
||||
if (warp.thread_rank() == 0) {
|
||||
local_max[warp.meta_group_rank()] = maxval;
|
||||
}
|
||||
block.sync();
|
||||
maxval = warp.thread_rank() < warp.meta_group_size()
|
||||
? local_max[warp.thread_rank()]
|
||||
: Limits<AccT>::finite_min();
|
||||
maxval = cg::reduce(warp, maxval, max_op);
|
||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||
if (warp.thread_rank() == 0) {
|
||||
local_normalizer[warp.meta_group_rank()] = normalizer;
|
||||
}
|
||||
block.sync();
|
||||
normalizer = warp.thread_rank() < warp.meta_group_size()
|
||||
? local_normalizer[warp.thread_rank()]
|
||||
: AccT{};
|
||||
normalizer = cg::reduce(warp, normalizer, plus_op);
|
||||
normalizer = 1 / normalizer;
|
||||
|
||||
// Write output.
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlocked(index, in, vals, axis_size);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = softmax_exp(static_cast<AccT>(vals[i]) - maxval) * normalizer;
|
||||
}
|
||||
cub::StoreDirectBlocked(index, out, vals, axis_size);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Softmax::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
auto& s = stream();
|
||||
|
||||
// Make sure that the last dimension is contiguous.
|
||||
auto set_output = [&s, &out](const array& x) {
|
||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||
if (x.is_donatable()) {
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
}
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
|
||||
array in = set_output(inputs[0]);
|
||||
bool precise = in.dtype() != float32 && precise_;
|
||||
|
||||
int axis_size = in.shape().back();
|
||||
int n_rows = in.data_size() / axis_size;
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "softmax", CTYPE, {
|
||||
using DataType = cuda_type_t<CTYPE>;
|
||||
constexpr int N_READS = 4;
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = cu::softmax<DataType, DataType, BLOCK_DIM, N_READS>;
|
||||
if (precise) {
|
||||
kernel = cu::softmax<DataType, float, BLOCK_DIM, N_READS>;
|
||||
}
|
||||
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||
in.data<DataType>(), out.data<DataType>(), axis_size);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
192
mlx/backend/cuda/sort.cu
Normal file
192
mlx/backend/cuda/sort.cu
Normal file
@ -0,0 +1,192 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/transform.h>
|
||||
#include <cub/device/device_segmented_sort.cuh>
|
||||
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
struct ModOp {
|
||||
T divisor;
|
||||
__device__ T operator()(T x) {
|
||||
return x % divisor;
|
||||
}
|
||||
};
|
||||
|
||||
// We can not use any op in eval, make an utility.
|
||||
array swapaxes_in_eval(const array& in, int axis1, int axis2) {
|
||||
std::vector<int> axes(in.ndim());
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
std::swap(axes[axis1], axes[axis2]);
|
||||
// TODO: Share the code with Transpose::eval.
|
||||
Shape shape(axes.size());
|
||||
Strides strides(in.ndim());
|
||||
for (size_t ax = 0; ax < axes.size(); ++ax) {
|
||||
shape[ax] = in.shape()[axes[ax]];
|
||||
strides[ax] = in.strides()[axes[ax]];
|
||||
}
|
||||
auto flags = in.flags();
|
||||
if (flags.contiguous) {
|
||||
auto [_, row_contiguous, col_contiguous] = check_contiguity(shape, strides);
|
||||
flags.row_contiguous = row_contiguous;
|
||||
flags.col_contiguous = col_contiguous;
|
||||
}
|
||||
array out(shape, in.dtype(), nullptr, {});
|
||||
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||
return out;
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
void segmented_sort_pairs(cu::CommandEncoder& encoder, Args&&... args) {
|
||||
// Allocate temporary storage.
|
||||
size_t size;
|
||||
CHECK_CUDA_ERROR(
|
||||
cub::DeviceSegmentedSort::StableSortPairs(nullptr, size, args...));
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
encoder.add_temporary(temp);
|
||||
// Run op.
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
|
||||
temp.data<void>(), size, args...));
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
|
||||
// Allocate temporary storage.
|
||||
size_t size;
|
||||
CHECK_CUDA_ERROR(
|
||||
cub::DeviceSegmentedSort::StableSortKeys(nullptr, size, args...));
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
encoder.add_temporary(temp);
|
||||
// Run op.
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
|
||||
temp.data<void>(), size, args...));
|
||||
}
|
||||
|
||||
void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
array out = out_;
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
if (axis < 0) {
|
||||
axis += in.ndim();
|
||||
}
|
||||
int nsort = in.shape(axis);
|
||||
int last_dim = in.ndim() - 1;
|
||||
|
||||
// If we are not sorting the innermost dimension of a contiguous array,
|
||||
// transpose and make a copy.
|
||||
bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1;
|
||||
if (!is_segmented_sort) {
|
||||
array trans = swapaxes_in_eval(in, axis, last_dim);
|
||||
in = array(trans.shape(), trans.dtype(), nullptr, {});
|
||||
copy_gpu(trans, in, CopyType::General, s);
|
||||
encoder.add_temporary(in);
|
||||
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
|
||||
encoder.add_temporary(out);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
|
||||
using Type = cuda_type_t<CTYPE>;
|
||||
auto offsets = thrust::make_transform_iterator(
|
||||
thrust::make_counting_iterator(0),
|
||||
[nsort] __device__(int i) { return i * nsort; });
|
||||
if (argsort) {
|
||||
// Indices in the sorted dimension.
|
||||
array indices(
|
||||
allocator::malloc(out.nbytes()), in.shape(), out.dtype());
|
||||
encoder.add_temporary(indices);
|
||||
thrust::transform(
|
||||
cu::thrust_policy(stream),
|
||||
thrust::counting_iterator<uint32_t>(0),
|
||||
thrust::counting_iterator<uint32_t>(indices.data_size()),
|
||||
thrust::device_pointer_cast(indices.data<uint32_t>()),
|
||||
ModOp<uint32_t>{static_cast<uint32_t>(nsort)});
|
||||
|
||||
// In argsort though we don't need the result of sorted values, the
|
||||
// API requires us to provide an array to store it.
|
||||
array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype());
|
||||
encoder.add_temporary(discard);
|
||||
|
||||
segmented_sort_pairs(
|
||||
encoder,
|
||||
in.data<Type>(),
|
||||
discard.data<Type>(),
|
||||
indices.data<uint32_t>(),
|
||||
out.data<uint32_t>(),
|
||||
in.data_size(),
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
stream);
|
||||
} else {
|
||||
segmented_sort(
|
||||
encoder,
|
||||
in.data<Type>(),
|
||||
out.data<Type>(),
|
||||
in.data_size(),
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
stream);
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"CUDA backend does not support sorting complex numbers");
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
if (!is_segmented_sort) {
|
||||
// Swap the sorted axis back.
|
||||
// TODO: Do in-place transpose instead of using a temporary out array.
|
||||
copy_gpu(swapaxes_in_eval(out, axis, last_dim), out_, CopyType::General, s);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("ArgSort::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
gpu_sort(stream(), inputs[0], out, axis_, true);
|
||||
}
|
||||
|
||||
void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Sort::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
gpu_sort(stream(), inputs[0], out, axis_, false);
|
||||
}
|
||||
|
||||
void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("ArgPartition::eval_gpu");
|
||||
gpu_sort(stream(), inputs[0], out, axis_, true);
|
||||
}
|
||||
|
||||
void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Partition::eval_gpu");
|
||||
gpu_sort(stream(), inputs[0], out, axis_, false);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
178
mlx/backend/cuda/ternary.cu
Normal file
178
mlx/backend/cuda/ternary.cu
Normal file
@ -0,0 +1,178 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
#include "mlx/backend/common/ternary.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/ternary_ops.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename Op, typename T, typename IdxT>
|
||||
__global__ void
|
||||
ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
out[index] = Op{}(a[index], b[index], c[index]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename T, typename IdxT, int NDIM>
|
||||
__global__ void ternary_g_nd(
|
||||
const bool* a,
|
||||
const T* b,
|
||||
const T* c,
|
||||
T* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> c_strides) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx, c_idx] = elem_to_loc_nd<NDIM>(
|
||||
index,
|
||||
shape.data(),
|
||||
a_strides.data(),
|
||||
b_strides.data(),
|
||||
c_strides.data());
|
||||
out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename T, typename IdxT>
|
||||
__global__ void ternary_g(
|
||||
const bool* a,
|
||||
const T* b,
|
||||
const T* c,
|
||||
T* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides a_strides,
|
||||
const __grid_constant__ Strides b_strides,
|
||||
const __grid_constant__ Strides c_strides,
|
||||
int ndim) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx, c_idx] = elem_to_loc_4d(
|
||||
index,
|
||||
shape.data(),
|
||||
a_strides.data(),
|
||||
b_strides.data(),
|
||||
c_strides.data(),
|
||||
ndim);
|
||||
out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
template <typename Op>
|
||||
void ternary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const Stream& s) {
|
||||
const auto& a = inputs[0];
|
||||
const auto& b = inputs[1];
|
||||
const auto& c = inputs[2];
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_input_array(c);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE, {
|
||||
using DType = cuda_type_t<CTYPE>;
|
||||
|
||||
auto topt = get_ternary_op_type(a, b, c);
|
||||
if (topt == TernaryOpType::General) {
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
auto& c_strides = strides[2];
|
||||
bool large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
|
||||
c.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
auto kernel = cu::ternary_g_nd<Op, DType, IdxT, NDIM>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
c.data<DType>(),
|
||||
out.data<DType>(),
|
||||
out.size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(a_strides),
|
||||
const_param<NDIM>(b_strides),
|
||||
const_param<NDIM>(c_strides));
|
||||
});
|
||||
} else {
|
||||
auto kernel = cu::ternary_g<Op, DType, IdxT>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
c.data<DType>(),
|
||||
out.data<DType>(),
|
||||
out.data_size(),
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
const_param(b_strides),
|
||||
const_param(c_strides),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
auto kernel = cu::ternary_v<Op, DType, IdxT>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
c.data<DType>(),
|
||||
out.data<DType>(),
|
||||
out.data_size());
|
||||
});
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void ternary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto& c = inputs[2];
|
||||
auto topt = get_ternary_op_type(a, b, c);
|
||||
set_ternary_op_output_data(a, b, c, out, topt);
|
||||
ternary_op_gpu_inplace<Op>(inputs, out, s);
|
||||
}
|
||||
|
||||
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("select::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
ternary_op_gpu<cu::Select>(inputs, out, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
198
mlx/backend/cuda/unary.cu
Normal file
198
mlx/backend/cuda/unary.cu
Normal file
@ -0,0 +1,198 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
||||
#include "mlx/backend/cuda/iterators/general_iterator.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/transform.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
template <typename Op, typename In, typename Out>
|
||||
constexpr bool supports_unary_op() {
|
||||
if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> ||
|
||||
std::is_same_v<Op, Sign>) {
|
||||
return std::is_same_v<In, Out>;
|
||||
}
|
||||
if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcCosh> ||
|
||||
std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> ||
|
||||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
|
||||
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
|
||||
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Sigmoid> ||
|
||||
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
|
||||
return std::is_same_v<In, Out> && is_floating_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
||||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p>) {
|
||||
return std::is_same_v<In, Out> && is_inexact_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, BitwiseInvert>) {
|
||||
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
|
||||
!std::is_same_v<In, bool>;
|
||||
}
|
||||
if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor> ||
|
||||
std::is_same_v<Op, Square>) {
|
||||
return std::is_same_v<In, Out> && !std::is_same_v<In, complex64_t>;
|
||||
}
|
||||
if (std::is_same_v<Op, Conjugate>) {
|
||||
return std::is_same_v<In, Out> && std::is_same_v<In, complex64_t>;
|
||||
}
|
||||
if (std::is_same_v<Op, Cos> || std::is_same_v<Op, Cosh> ||
|
||||
std::is_same_v<Op, Exp> || std::is_same_v<Op, Round> ||
|
||||
std::is_same_v<Op, Sin> || std::is_same_v<Op, Sinh> ||
|
||||
std::is_same_v<Op, Tan> || std::is_same_v<Op, Tanh>) {
|
||||
return std::is_same_v<In, Out> &&
|
||||
(is_floating_v<In> || std::is_same_v<In, complex64_t>);
|
||||
}
|
||||
if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
|
||||
return std::is_same_v<In, complex64_t> && std::is_same_v<Out, float>;
|
||||
}
|
||||
if (std::is_same_v<Op, LogicalNot>) {
|
||||
return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
template <typename Op>
|
||||
void unary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string& op,
|
||||
const Stream& s) {
|
||||
auto& in = inputs[0];
|
||||
if (in.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, {
|
||||
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
|
||||
if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||
using InType = cuda_type_t<CTYPE_IN>;
|
||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||
auto policy = cu::thrust_policy(stream);
|
||||
auto in_ptr = thrust::device_pointer_cast(in.data<InType>());
|
||||
auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
|
||||
if (in.flags().contiguous) {
|
||||
thrust::transform(
|
||||
policy, in_ptr, in_ptr + in.data_size(), out_ptr, Op());
|
||||
} else {
|
||||
auto [shape, strides] = collapse_contiguous_dims(in);
|
||||
auto [in_begin, in_end] = cu::make_general_iterators<int64_t>(
|
||||
in_ptr, in.size(), shape, strides);
|
||||
thrust::transform(policy, in_begin, in_end, out_ptr, Op());
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Can not do unary op {} on input of {} with output of {}.",
|
||||
op,
|
||||
dtype_to_string(in.dtype()),
|
||||
dtype_to_string(out.dtype())));
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void unary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string& op,
|
||||
const Stream& s) {
|
||||
set_unary_output_data(inputs[0], out);
|
||||
unary_op_gpu_inplace<Op>(inputs, out, op, s);
|
||||
}
|
||||
|
||||
#define UNARY_GPU(func) \
|
||||
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
||||
nvtx3::scoped_range r(#func "::eval_gpu"); \
|
||||
auto& s = out.primitive().stream(); \
|
||||
unary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
|
||||
}
|
||||
|
||||
UNARY_GPU(Abs)
|
||||
UNARY_GPU(ArcCos)
|
||||
UNARY_GPU(ArcCosh)
|
||||
UNARY_GPU(ArcSin)
|
||||
UNARY_GPU(ArcSinh)
|
||||
UNARY_GPU(ArcTan)
|
||||
UNARY_GPU(ArcTanh)
|
||||
UNARY_GPU(BitwiseInvert)
|
||||
UNARY_GPU(Ceil)
|
||||
UNARY_GPU(Conjugate)
|
||||
UNARY_GPU(Cos)
|
||||
UNARY_GPU(Cosh)
|
||||
UNARY_GPU(Erf)
|
||||
UNARY_GPU(ErfInv)
|
||||
UNARY_GPU(Exp)
|
||||
UNARY_GPU(Expm1)
|
||||
UNARY_GPU(Floor)
|
||||
UNARY_GPU(Imag)
|
||||
UNARY_GPU(Log1p)
|
||||
UNARY_GPU(LogicalNot)
|
||||
UNARY_GPU(Negative)
|
||||
UNARY_GPU(Real)
|
||||
UNARY_GPU(Sigmoid)
|
||||
UNARY_GPU(Sign)
|
||||
UNARY_GPU(Sin)
|
||||
UNARY_GPU(Sinh)
|
||||
UNARY_GPU(Square)
|
||||
UNARY_GPU(Tan)
|
||||
UNARY_GPU(Tanh)
|
||||
|
||||
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Log::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
auto op = get_primitive_string(this);
|
||||
switch (base_) {
|
||||
case Base::e:
|
||||
unary_op_gpu<cu::Log>(inputs, out, op, s);
|
||||
break;
|
||||
case Base::two:
|
||||
unary_op_gpu<cu::Log2>(inputs, out, op, s);
|
||||
break;
|
||||
case Base::ten:
|
||||
unary_op_gpu<cu::Log10>(inputs, out, op, s);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Round::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
auto& s = out.primitive().stream();
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_op_gpu<cu::Round>(inputs, out, get_primitive_string(this), s);
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Sort::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
if (recip_) {
|
||||
unary_op_gpu<cu::Rsqrt>(inputs, out, "Rsqrt", s);
|
||||
} else {
|
||||
unary_op_gpu<cu::Sqrt>(inputs, out, "Sqrt", s);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
@ -23,4 +24,23 @@ void check_cuda_error(const char* name, cudaError_t err) {
|
||||
}
|
||||
}
|
||||
|
||||
const char* dtype_to_cuda_type(const Dtype& dtype) {
|
||||
if (dtype == float16) {
|
||||
return "__half";
|
||||
}
|
||||
if (dtype == bfloat16) {
|
||||
return "__nv_bfloat16";
|
||||
}
|
||||
if (dtype == complex64) {
|
||||
return "cuComplex";
|
||||
}
|
||||
#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \
|
||||
if (dtype == DTYPE) { \
|
||||
return #CPP_TYPE; \
|
||||
}
|
||||
MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString)
|
||||
#undef SPECIALIZE_DtypeToString
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -12,6 +12,8 @@ namespace cu {
|
||||
class Device;
|
||||
}
|
||||
|
||||
struct Dtype;
|
||||
|
||||
// Cuda stream managed with RAII.
|
||||
class CudaStream {
|
||||
public:
|
||||
@ -35,4 +37,7 @@ void check_cuda_error(const char* name, cudaError_t err);
|
||||
// The macro version that prints the command that failed.
|
||||
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
|
||||
|
||||
// Convert Dtype to CUDA C++ types.
|
||||
const char* dtype_to_cuda_type(const Dtype& dtype);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -80,7 +80,9 @@ void Worker::thread_fn() {
|
||||
}
|
||||
worker_tasks_.erase(worker_tasks_.begin(), end);
|
||||
}
|
||||
for (auto& task : tasks) {
|
||||
// Make sure tasks are cleared before the next wait
|
||||
for (int i = 0; i < tasks.size(); ++i) {
|
||||
auto task = std::move(tasks[i]);
|
||||
task();
|
||||
}
|
||||
worker_event_.wait(batch + 1);
|
||||
|
@ -1,13 +1,22 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/backend/common/slicing.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/gpu/slicing.h"
|
||||
|
||||
#if defined(MLX_USE_CUDA)
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#endif
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#if defined(MLX_USE_CUDA)
|
||||
#define MLX_PROFILER_RANGE(message) nvtx3::scoped_range r(message)
|
||||
#else
|
||||
#define MLX_PROFILER_RANGE(message)
|
||||
#endif
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@ -162,6 +171,41 @@ void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
slice_gpu(in, out, start_indices_, strides_, stream());
|
||||
}
|
||||
|
||||
void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
auto& in = inputs[0];
|
||||
auto& upd = inputs[1];
|
||||
|
||||
if (upd.size() == 0) {
|
||||
out.copy_shared_buffer(in);
|
||||
return;
|
||||
}
|
||||
|
||||
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
||||
? CopyType::Vector
|
||||
: CopyType::General;
|
||||
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
||||
auto [data_offset, out_strides] =
|
||||
prepare_slice(out, start_indices_, strides_);
|
||||
|
||||
// Do copy
|
||||
copy_gpu_inplace(
|
||||
/* const array& src = */ upd,
|
||||
/* array& dst = */ out,
|
||||
/* const Shape& data_shape = */ upd.shape(),
|
||||
/* const Strides& i_strides = */ upd.strides(),
|
||||
/* const Strides& o_strides = */ out_strides,
|
||||
/* int64_t i_offset = */ 0,
|
||||
/* int64_t o_offset = */ data_offset,
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral,
|
||||
/* const Stream& s = */ stream());
|
||||
}
|
||||
|
||||
void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Squeeze::eval_gpu");
|
||||
eval(inputs, out);
|
||||
|
@ -31,13 +31,13 @@ std::string get_kernel_name(
|
||||
kname = "ss";
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
kname = (large ? "sv2" : "sv");
|
||||
kname = "sv";
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
kname = (large ? "vs2" : "vs");
|
||||
kname = "vs";
|
||||
break;
|
||||
case BinaryOpType::VectorVector:
|
||||
kname = (large ? "vv2" : "vv");
|
||||
kname = "vv";
|
||||
break;
|
||||
case BinaryOpType::General:
|
||||
kname = "g";
|
||||
@ -51,6 +51,13 @@ std::string get_kernel_name(
|
||||
}
|
||||
break;
|
||||
}
|
||||
if (bopt != BinaryOpType::General && bopt != BinaryOpType::ScalarScalar) {
|
||||
if (large) {
|
||||
kname += "2";
|
||||
} else if (work_per_thread > 1) {
|
||||
kname += "n";
|
||||
}
|
||||
}
|
||||
concatenate(kname, "_", op, type_to_name(a));
|
||||
return kname;
|
||||
}
|
||||
@ -90,7 +97,7 @@ void binary_op_gpu_inplace(
|
||||
work_per_thread = large ? 4 : 2;
|
||||
} else {
|
||||
large = out.data_size() > UINT32_MAX;
|
||||
work_per_thread = get_work_per_thread(a.dtype());
|
||||
work_per_thread = get_work_per_thread(a.dtype(), out.data_size());
|
||||
}
|
||||
std::string kernel_name =
|
||||
get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);
|
||||
|
@ -11,8 +11,6 @@
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
using namespace fmt::literals;
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
inline void build_kernel(
|
||||
@ -21,21 +19,12 @@ inline void build_kernel(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::unordered_set<uintptr_t>& constant_ids,
|
||||
const std::function<bool(size_t)>& is_constant,
|
||||
bool contiguous,
|
||||
int ndim,
|
||||
bool dynamic_dims,
|
||||
bool use_big_index = false,
|
||||
int work_per_thread = 1) {
|
||||
// All outputs should have the exact same shape and will be row contiguous
|
||||
auto output_shape = outputs[0].shape();
|
||||
auto output_strides = outputs[0].strides();
|
||||
|
||||
// Constants are scalars that are captured by value and cannot change
|
||||
auto is_constant = [&constant_ids](const array& x) {
|
||||
return constant_ids.find(x.id()) != constant_ids.end();
|
||||
};
|
||||
|
||||
NodeNamer namer;
|
||||
bool add_indices = false;
|
||||
int cnt = 0;
|
||||
@ -45,14 +34,15 @@ inline void build_kernel(
|
||||
"[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);
|
||||
|
||||
// Add the input arguments
|
||||
for (auto& x : inputs) {
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
// Skip constants from the input list
|
||||
if (is_constant(x)) {
|
||||
if (is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto& x = inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
// Scalars and contiguous need no strides
|
||||
if (!is_scalar(x) && !contiguous) {
|
||||
add_indices = true;
|
||||
@ -80,8 +70,6 @@ inline void build_kernel(
|
||||
}
|
||||
// Add output strides and shape to extract the indices.
|
||||
if (!contiguous) {
|
||||
os += fmt::format(
|
||||
" constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
|
||||
os += fmt::format(
|
||||
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
|
||||
} else {
|
||||
@ -125,7 +113,7 @@ inline void build_kernel(
|
||||
auto& x = inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
if (is_constant(x)) {
|
||||
if (is_constant(i)) {
|
||||
auto type_str = get_type_string(x.dtype());
|
||||
std::ostringstream ss;
|
||||
print_constant(ss, x);
|
||||
@ -271,11 +259,6 @@ inline void build_kernel(
|
||||
void Compiled::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
// Make the name for the kernel library
|
||||
if (kernel_lib_.empty()) {
|
||||
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
|
||||
}
|
||||
|
||||
// Get the kernel if someone else built it already
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
@ -290,19 +273,33 @@ void Compiled::eval_gpu(
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
is_constant_,
|
||||
/* contiguous = */ true,
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ false,
|
||||
/* use_big_index = */ false,
|
||||
/* work_per_thread = */ work_per_thread);
|
||||
/* work_per_thread = */ 1);
|
||||
if (work_per_thread > 1) {
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_lib_ + "_contiguous_n",
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
is_constant_,
|
||||
/* contiguous = */ true,
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ false,
|
||||
/* use_big_index = */ false,
|
||||
/* work_per_thread = */ work_per_thread);
|
||||
}
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_lib_ + "_contiguous_large",
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
is_constant_,
|
||||
/* contiguous = */ true,
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ false,
|
||||
@ -315,7 +312,7 @@ void Compiled::eval_gpu(
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
is_constant_,
|
||||
/* contiguous = */ false,
|
||||
/* ndim = */ i,
|
||||
/* dynamic_dims = */ false,
|
||||
@ -328,7 +325,7 @@ void Compiled::eval_gpu(
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
is_constant_,
|
||||
/* contiguous = */ false,
|
||||
/* ndim = */ i,
|
||||
/* dynamic_dims = */ false,
|
||||
@ -342,7 +339,7 @@ void Compiled::eval_gpu(
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
is_constant_,
|
||||
/* contiguous = */ false,
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ true,
|
||||
@ -354,7 +351,7 @@ void Compiled::eval_gpu(
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
is_constant_,
|
||||
/* contiguous = */ false,
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ true,
|
||||
@ -363,81 +360,32 @@ void Compiled::eval_gpu(
|
||||
return kernel;
|
||||
});
|
||||
|
||||
// Figure out which kernel we are using
|
||||
auto& output_shape = outputs[0].shape();
|
||||
auto contiguous = compiled_check_contiguity(inputs, output_shape);
|
||||
|
||||
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
||||
// handle all broadcasting.
|
||||
std::vector<Strides> initial_strides;
|
||||
initial_strides.push_back(outputs[0].strides());
|
||||
Shape shape;
|
||||
std::vector<Strides> strides;
|
||||
if (!contiguous) {
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
// Skip constants.
|
||||
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
|
||||
continue;
|
||||
}
|
||||
auto& x = inputs[i];
|
||||
auto [contiguous, shape, strides] =
|
||||
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
|
||||
|
||||
// Skip scalar inputs.
|
||||
if (is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Broadcast the inputs to the output shape.
|
||||
Strides xstrides;
|
||||
int j = 0;
|
||||
for (; j < output_shape.size() - x.ndim(); j++) {
|
||||
if (output_shape[j] == 1) {
|
||||
xstrides.push_back(outputs[0].strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < x.ndim(); i++, j++) {
|
||||
if (x.shape(i) == 1) {
|
||||
if (output_shape[j] == 1) {
|
||||
xstrides.push_back(outputs[0].strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
} else {
|
||||
xstrides.push_back(x.strides()[i]);
|
||||
}
|
||||
}
|
||||
initial_strides.push_back(std::move(xstrides));
|
||||
}
|
||||
std::tie(shape, strides) =
|
||||
collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX);
|
||||
}
|
||||
|
||||
bool large;
|
||||
if (contiguous) {
|
||||
size_t max_size = 0;
|
||||
for (auto& in : inputs) {
|
||||
max_size = std::max(max_size, in.data_size());
|
||||
}
|
||||
large = (max_size > UINT32_MAX);
|
||||
} else {
|
||||
size_t max_size = 0;
|
||||
for (auto& o : outputs) {
|
||||
max_size = std::max(max_size, o.size());
|
||||
}
|
||||
large = (max_size > UINT32_MAX);
|
||||
}
|
||||
// Whether to use large index.
|
||||
bool large = compiled_use_large_index(inputs, outputs, contiguous);
|
||||
|
||||
// Get the kernel from the lib
|
||||
int ndim = shape.size();
|
||||
bool dynamic = ndim >= 8;
|
||||
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
|
||||
int work_per_thread = 1;
|
||||
if (!contiguous) {
|
||||
if (dynamic) {
|
||||
kernel_name += "dynamic";
|
||||
} else {
|
||||
kernel_name += std::to_string(shape.size());
|
||||
}
|
||||
work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
|
||||
} else {
|
||||
work_per_thread =
|
||||
get_work_per_thread(outputs[0].dtype(), outputs[0].data_size());
|
||||
if (work_per_thread > 1 && !large) {
|
||||
kernel_name += "_n";
|
||||
}
|
||||
}
|
||||
if (large) {
|
||||
kernel_name += "_large";
|
||||
@ -451,7 +399,7 @@ void Compiled::eval_gpu(
|
||||
int stride_idx = 1; // idx 0 is the output strides
|
||||
Strides in_strides;
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
|
||||
if (is_constant_(i)) {
|
||||
continue;
|
||||
}
|
||||
auto& x = inputs[i];
|
||||
@ -468,8 +416,7 @@ void Compiled::eval_gpu(
|
||||
compute_encoder.set_vector_bytes(in_strides, cnt++);
|
||||
}
|
||||
|
||||
compiled_allocate_outputs(
|
||||
inputs, outputs, inputs_, constant_ids_, contiguous);
|
||||
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
|
||||
|
||||
// Put the outputs in
|
||||
for (auto& x : outputs) {
|
||||
@ -478,7 +425,6 @@ void Compiled::eval_gpu(
|
||||
|
||||
// Put the output shape and strides in
|
||||
if (!contiguous) {
|
||||
compute_encoder.set_vector_bytes(strides[0], cnt++);
|
||||
compute_encoder.set_vector_bytes(shape, cnt++);
|
||||
} else {
|
||||
auto size = outputs[0].data_size();
|
||||
@ -496,7 +442,6 @@ void Compiled::eval_gpu(
|
||||
|
||||
// Launch the kernel
|
||||
if (contiguous) {
|
||||
int work_per_thread = get_work_per_thread(outputs[0].dtype());
|
||||
size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread);
|
||||
MTL::Size group_dims(
|
||||
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
|
||||
@ -509,7 +454,6 @@ void Compiled::eval_gpu(
|
||||
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||
size_t rest = outputs[0].size() / (dim0 * dim1);
|
||||
int work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
|
||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
int pow2;
|
||||
|
@ -155,26 +155,26 @@ void explicit_gemm_conv_group_ND_gpu(
|
||||
// Perform gemm
|
||||
std::vector<array> copies = {in_unfolded, wt_transpose};
|
||||
return steel_matmul_regular(
|
||||
s,
|
||||
d,
|
||||
/* a = */ in_unfolded,
|
||||
/* b = */ wt_transpose,
|
||||
/* c = */ out,
|
||||
/* M = */ implicit_M,
|
||||
/* N = */ implicit_N,
|
||||
/* K = */ implicit_K,
|
||||
/* batch_size_out = */ groups,
|
||||
/* a_cols = */ implicit_K * groups,
|
||||
/* b_cols = */ implicit_K,
|
||||
/* out_cols = */ implicit_N * groups,
|
||||
/* a_transposed = */ false,
|
||||
/* b_transposed = */ true,
|
||||
/* batch_shape = */ {1},
|
||||
/* batch_strides = */ {0},
|
||||
/* A_batch_strides = */ size_t(implicit_K),
|
||||
/* B_batch_strides = */ size_t(implicit_N) * implicit_K,
|
||||
/* matrix_stride_out = */ size_t(implicit_N),
|
||||
/*copies = */ copies);
|
||||
/* const Stream& s = */ s,
|
||||
/* Device& d = */ d,
|
||||
/* const array& a = */ in_unfolded,
|
||||
/* const array& b = */ wt_transpose,
|
||||
/* array& c = */ out,
|
||||
/* int M = */ implicit_M,
|
||||
/* int N = */ implicit_N,
|
||||
/* int K = */ implicit_K,
|
||||
/* int batch_size_out = */ groups,
|
||||
/* int lda = */ implicit_K * groups,
|
||||
/* int ldb = */ implicit_K,
|
||||
/* int ldd = */ implicit_N * groups,
|
||||
/* bool transpose_a = */ false,
|
||||
/* bool transpose_b = */ true,
|
||||
/* std::vector<array>& copies = */ copies,
|
||||
/* Shape batch_shape = */ {1},
|
||||
/* Strides batch_strides = */ {0},
|
||||
/* int64_t A_batch_strides = */ int64_t(implicit_K),
|
||||
/* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K,
|
||||
/* int64_t matrix_stride_out = */ int64_t(implicit_N));
|
||||
}
|
||||
|
||||
void implicit_gemm_conv_2D_gpu(
|
||||
@ -391,6 +391,7 @@ void implicit_gemm_conv_2D_general_gpu(
|
||||
// Get channel iteration info
|
||||
int channel_k_iters = ((conv_params.C + bk - 1) / bk);
|
||||
int gemm_k_iters = channel_k_iters;
|
||||
bool align_C = conv_params.C % bk == 0;
|
||||
|
||||
// Fix host side helper params
|
||||
int sign = (conv_params.flip ? -1 : 1);
|
||||
@ -419,14 +420,33 @@ void implicit_gemm_conv_2D_general_gpu(
|
||||
/* const int swizzle_log = */ swizzle_log};
|
||||
|
||||
// Determine kernel
|
||||
std::ostringstream kname;
|
||||
kname << "implicit_gemm_conv_2d_general_" << type_to_name(out) << "_bm" << bm
|
||||
<< "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
concatenate(
|
||||
kname,
|
||||
"implicit_gemm_conv_2d_general_",
|
||||
type_to_name(out),
|
||||
"_bm",
|
||||
bm,
|
||||
"_bn",
|
||||
bn,
|
||||
"_bk",
|
||||
bk,
|
||||
"_wm",
|
||||
wm,
|
||||
"_wn",
|
||||
wn);
|
||||
std::string hash_name;
|
||||
hash_name.reserve(64);
|
||||
concatenate(hash_name, kname, "_alC_", align_C);
|
||||
metal::MTLFCList func_consts = {
|
||||
{&align_C, MTL::DataType::DataTypeBool, 200},
|
||||
};
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel =
|
||||
get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
|
||||
auto kernel = get_steel_conv_general_kernel(
|
||||
d, kname, hash_name, func_consts, out, bm, bn, bk, wm, wn);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Deduce grid launch dimensions
|
||||
@ -677,7 +697,7 @@ void depthwise_conv_2D_gpu(
|
||||
std::string hash_name = kname.str();
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
|
||||
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
@ -728,8 +748,10 @@ void dispatch_conv_2D_gpu(
|
||||
|
||||
// Direct to winograd conv
|
||||
bool inp_large =
|
||||
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
|
||||
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 4096;
|
||||
bool channels_large = (conv_params.C + conv_params.O) >= 256;
|
||||
bool out_large =
|
||||
(conv_params.N * conv_params.oS[0] * conv_params.oS[1]) >= 256;
|
||||
if (!conv_params.flip && is_stride_one && is_kdil_one && is_idil_one &&
|
||||
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
|
||||
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
|
||||
@ -743,7 +765,7 @@ void dispatch_conv_2D_gpu(
|
||||
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
|
||||
else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
|
||||
else if ((conv_params.C % 16 == 0 && conv_params.O % 16 == 0) || out_large) {
|
||||
return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
|
||||
|
@ -55,10 +55,10 @@ void copy_gpu_inplace(
|
||||
std::string kernel_name;
|
||||
switch (ctype) {
|
||||
case CopyType::Scalar:
|
||||
kernel_name = (large ? "s2" : "s");
|
||||
kernel_name = large ? "s2" : "s";
|
||||
break;
|
||||
case CopyType::Vector:
|
||||
kernel_name = (large ? "v2" : "v");
|
||||
kernel_name = large ? "v2" : "v";
|
||||
break;
|
||||
case CopyType::General:
|
||||
kernel_name = "g";
|
||||
@ -85,7 +85,10 @@ void copy_gpu_inplace(
|
||||
}
|
||||
}
|
||||
} else {
|
||||
work_per_thread = get_work_per_thread(in.dtype());
|
||||
work_per_thread = get_work_per_thread(out.dtype(), out.data_size());
|
||||
if (work_per_thread > 1) {
|
||||
kernel_name += "n";
|
||||
}
|
||||
}
|
||||
concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
|
||||
auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out)
|
||||
@ -170,9 +173,10 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
|
||||
}
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
bool large = out.data_size() > UINT32_MAX;
|
||||
int work_per_thread = get_work_per_thread(out.dtype(), out.data_size());
|
||||
auto& d = metal::device(s.device);
|
||||
std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +
|
||||
type_to_name(val) + type_to_name(out);
|
||||
std::string kernel_name = large ? "s2" : (work_per_thread > 1 ? "sn" : "s");
|
||||
concatenate(kernel_name, "_copy", type_to_name(val), type_to_name(out));
|
||||
auto kernel = get_copy_kernel(d, kernel_name, val, out);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
@ -180,7 +184,6 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
|
||||
compute_encoder.set_input_array(val, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
int work_per_thread = get_work_per_thread(val.dtype());
|
||||
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
|
||||
if (thread_group_size > nthreads) {
|
||||
|
@ -1,12 +1,326 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <iostream>
|
||||
#include <regex>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::fast {
|
||||
|
||||
struct CustomKernelCache {
|
||||
std::unordered_map<std::string, std::string> libraries;
|
||||
};
|
||||
|
||||
static CustomKernelCache& cache() {
|
||||
static CustomKernelCache cache_;
|
||||
return cache_;
|
||||
};
|
||||
|
||||
std::string write_signature(
|
||||
std::string func_name,
|
||||
const std::string& header,
|
||||
const std::string& source,
|
||||
const std::vector<std::string>& input_names,
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::vector<Dtype>& output_dtypes,
|
||||
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
|
||||
const std::vector<std::string>& attributes,
|
||||
const std::vector<CustomKernelShapeInfo>& shape_infos,
|
||||
bool atomic_outputs) {
|
||||
std::string kernel_source;
|
||||
kernel_source.reserve(header.size() + source.size() + 16384);
|
||||
kernel_source += header;
|
||||
// Auto-generate a function signature based on `template_args`
|
||||
// and the dtype/shape of the arrays passed as `inputs`.
|
||||
if (!template_args.empty()) {
|
||||
kernel_source += "template <";
|
||||
int i = 0;
|
||||
for (const auto& [name, arg] : template_args) {
|
||||
std::string param_type;
|
||||
if (std::holds_alternative<int>(arg)) {
|
||||
param_type = "int";
|
||||
} else if (std::holds_alternative<bool>(arg)) {
|
||||
param_type = "bool";
|
||||
} else if (std::holds_alternative<Dtype>(arg)) {
|
||||
param_type = "typename";
|
||||
}
|
||||
if (i > 0) {
|
||||
kernel_source += ", ";
|
||||
}
|
||||
kernel_source += param_type;
|
||||
kernel_source += " ";
|
||||
kernel_source += name;
|
||||
i++;
|
||||
}
|
||||
kernel_source += ">\n";
|
||||
}
|
||||
kernel_source += "[[kernel]] void ";
|
||||
kernel_source += func_name;
|
||||
kernel_source += "(\n";
|
||||
|
||||
int index = 0;
|
||||
constexpr int max_constant_array_size = 8;
|
||||
// Add inputs
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
const auto& name = input_names[i];
|
||||
const auto& arr = inputs[i];
|
||||
auto dtype = get_type_string(arr.dtype());
|
||||
std::string location =
|
||||
arr.size() < max_constant_array_size ? "constant" : "device";
|
||||
std::string ref = arr.ndim() == 0 ? "&" : "*";
|
||||
kernel_source += " const ";
|
||||
kernel_source += location;
|
||||
kernel_source += " ";
|
||||
kernel_source += dtype;
|
||||
kernel_source += ref;
|
||||
kernel_source += " ";
|
||||
kernel_source += name;
|
||||
kernel_source += " [[buffer(";
|
||||
kernel_source += std::to_string(index);
|
||||
kernel_source += ")]],\n";
|
||||
index++;
|
||||
// Add input shape, strides and ndim if present in the source
|
||||
if (arr.ndim() > 0) {
|
||||
if (shape_infos[i].shape) {
|
||||
kernel_source +=
|
||||
(" const constant int* " + name + "_shape [[buffer(" +
|
||||
std::to_string(index) + ")]],\n");
|
||||
index++;
|
||||
}
|
||||
if (shape_infos[i].strides) {
|
||||
kernel_source +=
|
||||
(" const constant int64_t* " + name + "_strides [[buffer(" +
|
||||
std::to_string(index) + ")]],\n");
|
||||
index++;
|
||||
}
|
||||
if (shape_infos[i].ndim) {
|
||||
kernel_source +=
|
||||
(" const constant int& " + name + "_ndim [[buffer(" +
|
||||
std::to_string(index) + ")]],\n");
|
||||
index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Add outputs
|
||||
for (int i = 0; i < output_names.size(); ++i) {
|
||||
const auto& name = output_names[i];
|
||||
const auto& dtype = output_dtypes[i];
|
||||
kernel_source += " device ";
|
||||
auto type_string = get_type_string(dtype);
|
||||
if (atomic_outputs) {
|
||||
kernel_source += "atomic<";
|
||||
}
|
||||
kernel_source += type_string;
|
||||
if (atomic_outputs) {
|
||||
kernel_source += ">";
|
||||
}
|
||||
kernel_source += "* ";
|
||||
kernel_source += name;
|
||||
kernel_source += " [[buffer(";
|
||||
kernel_source += std::to_string(index);
|
||||
kernel_source += ")]]";
|
||||
if (index < inputs.size() + output_names.size() - 1 ||
|
||||
attributes.size() > 0) {
|
||||
kernel_source += ",\n";
|
||||
} else {
|
||||
kernel_source += ") {\n";
|
||||
}
|
||||
index++;
|
||||
}
|
||||
|
||||
index = 0;
|
||||
for (const auto& attr : attributes) {
|
||||
kernel_source += attr;
|
||||
if (index < attributes.size() - 1) {
|
||||
kernel_source += ",\n";
|
||||
} else {
|
||||
kernel_source += ") {\n";
|
||||
}
|
||||
index++;
|
||||
}
|
||||
kernel_source += source;
|
||||
kernel_source += "\n}\n";
|
||||
return kernel_source;
|
||||
}
|
||||
|
||||
std::string write_template(
|
||||
const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
|
||||
std::ostringstream template_def;
|
||||
template_def << "<";
|
||||
int i = 0;
|
||||
for (const auto& [name, arg] : template_args) {
|
||||
if (i > 0) {
|
||||
template_def << ", ";
|
||||
}
|
||||
if (std::holds_alternative<int>(arg)) {
|
||||
template_def << std::get<int>(arg);
|
||||
} else if (std::holds_alternative<bool>(arg)) {
|
||||
template_def << std::get<bool>(arg);
|
||||
} else if (std::holds_alternative<Dtype>(arg)) {
|
||||
template_def << get_type_string(std::get<Dtype>(arg));
|
||||
}
|
||||
i++;
|
||||
}
|
||||
template_def << ">";
|
||||
return template_def.str();
|
||||
}
|
||||
|
||||
MetalKernelFunction metal_kernel(
|
||||
const std::string& name,
|
||||
const std::vector<std::string>& input_names,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::string& source,
|
||||
const std::string& header /* = "" */,
|
||||
bool ensure_row_contiguous /* = true */,
|
||||
bool atomic_outputs /* = false */) {
|
||||
if (output_names.empty()) {
|
||||
throw std::invalid_argument(
|
||||
"[metal_kernel] Must specify at least one output.");
|
||||
}
|
||||
std::vector<CustomKernelShapeInfo> shape_infos;
|
||||
for (auto& n : input_names) {
|
||||
CustomKernelShapeInfo shape_info;
|
||||
shape_info.shape = source.find(n + "_shape") != std::string::npos;
|
||||
shape_info.strides = source.find(n + "_strides") != std::string::npos;
|
||||
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
|
||||
shape_infos.push_back(shape_info);
|
||||
}
|
||||
const std::vector<std::pair<std::string, std::string>> metal_attributes = {
|
||||
{"dispatch_quadgroups_per_threadgroup", "uint"},
|
||||
{"dispatch_simdgroups_per_threadgroup", "uint"},
|
||||
{"dispatch_threads_per_threadgroup", "uint3"},
|
||||
{"grid_origin", "uint3"},
|
||||
{"grid_size", "uint3"},
|
||||
{"quadgroup_index_in_threadgroup", "uint"},
|
||||
{"quadgroups_per_threadgroup", "uint"},
|
||||
{"simdgroup_index_in_threadgroup", "uint"},
|
||||
{"simdgroups_per_threadgroup", "uint"},
|
||||
{"thread_execution_width", "uint"},
|
||||
{"thread_index_in_quadgroup", "uint"},
|
||||
{"thread_index_in_simdgroup", "uint"},
|
||||
{"thread_index_in_threadgroup", "uint"},
|
||||
{"thread_position_in_grid", "uint3"},
|
||||
{"thread_position_in_threadgroup", "uint3"},
|
||||
{"threadgroup_position_in_grid", "uint3"},
|
||||
{"threadgroups_per_grid", "uint3"},
|
||||
{"threads_per_grid", "uint3"},
|
||||
{"threads_per_simdgroup", "uint"},
|
||||
{"threads_per_threadgroup", "uint3"},
|
||||
};
|
||||
|
||||
std::vector<std::string> attributes;
|
||||
for (const auto& [attr, dtype] : metal_attributes) {
|
||||
if (source.find(attr) != std::string::npos) {
|
||||
attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]");
|
||||
}
|
||||
}
|
||||
|
||||
return [=,
|
||||
shape_infos = std::move(shape_infos),
|
||||
attributes = std::move(attributes)](
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<Shape>& output_shapes,
|
||||
const std::vector<Dtype>& output_dtypes,
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
const std::vector<std::pair<std::string, TemplateArg>>&
|
||||
template_args = {},
|
||||
std::optional<float> init_value = std::nullopt,
|
||||
bool verbose = false,
|
||||
StreamOrDevice s_ = {}) {
|
||||
if (inputs.size() != input_names.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal_kernel] Expected `inputs` to have size "
|
||||
<< input_names.size() << " but got size " << inputs.size() << "."
|
||||
<< std::endl;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (output_shapes.size() != output_names.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal_kernel] Expected `output_shapes` to have size "
|
||||
<< output_names.size() << " but got size " << output_shapes.size()
|
||||
<< "." << std::endl;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (output_dtypes.size() != output_names.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal_kernel] Expected `output_dtypes` to have size "
|
||||
<< output_names.size() << " but got size " << output_dtypes.size()
|
||||
<< "." << std::endl;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto s = to_stream(s_);
|
||||
if (s.device != Device::gpu) {
|
||||
throw std::invalid_argument("[metal_kernel] Only supports the GPU.");
|
||||
}
|
||||
|
||||
std::string kernel_name = "custom_kernel_" + name;
|
||||
std::string template_def = "";
|
||||
if (!template_args.empty()) {
|
||||
std::regex disallowed_chars("\\<|\\>|(, )");
|
||||
template_def = write_template(template_args);
|
||||
auto template_hash =
|
||||
std::regex_replace(template_def, disallowed_chars, "_");
|
||||
template_hash.pop_back();
|
||||
kernel_name += "_";
|
||||
kernel_name += template_hash;
|
||||
}
|
||||
|
||||
std::string kernel_source = write_signature(
|
||||
kernel_name,
|
||||
header,
|
||||
source,
|
||||
input_names,
|
||||
inputs,
|
||||
output_names,
|
||||
output_dtypes,
|
||||
template_args,
|
||||
attributes,
|
||||
shape_infos,
|
||||
atomic_outputs);
|
||||
|
||||
if (!template_args.empty()) {
|
||||
template_def = kernel_name + template_def;
|
||||
kernel_source += "\ntemplate [[host_name(\"";
|
||||
kernel_source += kernel_name;
|
||||
kernel_source += "\")]] [[kernel]] decltype(";
|
||||
kernel_source += template_def;
|
||||
kernel_source += ") ";
|
||||
kernel_source += template_def;
|
||||
kernel_source += ";\n";
|
||||
}
|
||||
|
||||
if (verbose) {
|
||||
std::cout << "Generated source code for `" << name << "`:" << std::endl
|
||||
<< "```" << std::endl
|
||||
<< kernel_source << std::endl
|
||||
<< "```" << std::endl;
|
||||
}
|
||||
|
||||
return array::make_arrays(
|
||||
std::move(output_shapes),
|
||||
std::move(output_dtypes),
|
||||
std::make_shared<CustomKernel>(
|
||||
s,
|
||||
std::move(kernel_name),
|
||||
std::move(kernel_source),
|
||||
grid,
|
||||
threadgroup,
|
||||
shape_infos,
|
||||
ensure_row_contiguous,
|
||||
init_value),
|
||||
std::move(inputs));
|
||||
};
|
||||
}
|
||||
|
||||
void CustomKernel::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
@ -39,9 +353,23 @@ void CustomKernel::eval_gpu(
|
||||
}
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
const auto& lib_name = name_;
|
||||
auto lib =
|
||||
d.get_library(lib_name, [this] { return metal::utils() + source_; });
|
||||
|
||||
{
|
||||
// Clear kernels from the device library cache if needed
|
||||
auto& kernel_cache = cache();
|
||||
if (auto it = kernel_cache.libraries.find(name_);
|
||||
it != kernel_cache.libraries.end()) {
|
||||
if (it->second != source_) {
|
||||
auto& d = metal::device(s.device);
|
||||
d.clear_library(name_);
|
||||
it->second = source_;
|
||||
}
|
||||
} else {
|
||||
kernel_cache.libraries.emplace(name_, source_);
|
||||
}
|
||||
}
|
||||
|
||||
auto lib = d.get_library(name_, [this] { return metal::utils() + source_; });
|
||||
auto kernel = d.get_kernel(name_, lib);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
@ -73,6 +401,16 @@ void CustomKernel::eval_gpu(
|
||||
}
|
||||
|
||||
const auto [tx, ty, tz] = threadgroup_;
|
||||
auto tg_size = tx * ty * tz;
|
||||
auto max_tg_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (tg_size > max_tg_size) {
|
||||
std::ostringstream msg;
|
||||
msg << "Thread group size (" << tg_size << ") is greater than "
|
||||
<< " the maximum allowed threads per threadgroup (" << max_tg_size
|
||||
<< ").";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
const auto [gx, gy, gz] = grid_;
|
||||
MTL::Size group_dims =
|
||||
MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
|
||||
|
@ -295,8 +295,11 @@ void CommandEncoder::barrier() {
|
||||
Device::Device() {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
device_ = load_device();
|
||||
library_map_ = {{"mlx", load_default_library(device_)}};
|
||||
default_library_ = load_default_library(device_);
|
||||
arch_ = std::string(device_->architecture()->name()->utf8String());
|
||||
int ag_tens = arch_[arch_.size() - 3] - '0';
|
||||
int ag_ones = arch_[arch_.size() - 2] - '0';
|
||||
arch_gen_ = ag_tens * 10 + ag_ones;
|
||||
auto arch = arch_.back();
|
||||
switch (arch) {
|
||||
case 'p': // phone
|
||||
@ -326,11 +329,11 @@ Device::Device() {
|
||||
|
||||
Device::~Device() {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
for (auto& k : kernel_map_) {
|
||||
k.second->release();
|
||||
}
|
||||
for (auto& l : library_map_) {
|
||||
l.second->release();
|
||||
for (auto& [l, kernel_map] : library_kernels_) {
|
||||
l->release();
|
||||
for (auto& [_, k] : kernel_map) {
|
||||
k->release();
|
||||
}
|
||||
}
|
||||
stream_map_.clear();
|
||||
device_->release();
|
||||
@ -474,13 +477,24 @@ CommandEncoder& Device::get_command_encoder(int index) {
|
||||
return *stream.encoder;
|
||||
}
|
||||
|
||||
void Device::register_library(
|
||||
const std::string& lib_name,
|
||||
const std::string& lib_path) {
|
||||
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
|
||||
auto new_lib = load_library(device_, lib_name, lib_path.c_str());
|
||||
library_map_.insert({lib_name, new_lib});
|
||||
MTL::Library* Device::get_library(
|
||||
const std::string& name,
|
||||
const std::string& path /* = "" */) {
|
||||
{
|
||||
std::shared_lock rlock(library_mtx_);
|
||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_lock wlock(library_mtx_);
|
||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
auto new_lib = load_library(device_, name, path.c_str());
|
||||
library_map_.insert({name, new_lib});
|
||||
return new_lib;
|
||||
}
|
||||
|
||||
MTL::Library* Device::build_library_(const std::string& source_string) {
|
||||
@ -649,6 +663,19 @@ MTL::Library* Device::get_library(
|
||||
return mtl_lib;
|
||||
}
|
||||
|
||||
void Device::clear_library(const std::string& name) {
|
||||
std::unique_lock wlock(library_mtx_);
|
||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||
auto kernel_map_it = library_kernels_.find(it->second);
|
||||
for (auto& [_, kernel] : kernel_map_it->second) {
|
||||
kernel->release();
|
||||
}
|
||||
library_kernels_.erase(kernel_map_it);
|
||||
it->second->release();
|
||||
library_map_.erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
MTL::LinkedFunctions* Device::get_linked_functions_(
|
||||
const std::vector<MTL::Function*>& funcs) {
|
||||
if (funcs.empty()) {
|
||||
@ -679,6 +706,7 @@ MTL::ComputePipelineState* Device::get_kernel_(
|
||||
std::unique_lock wlock(kernel_mtx_);
|
||||
|
||||
// Try loading again to avoid loading twice
|
||||
auto& kernel_map_ = library_kernels_[mtl_lib];
|
||||
if (auto it = kernel_map_.find(hash_name); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
@ -713,6 +741,7 @@ MTL::ComputePipelineState* Device::get_kernel(
|
||||
std::shared_lock lock(kernel_mtx_);
|
||||
|
||||
// Look for cached kernel
|
||||
auto& kernel_map_ = library_kernels_[mtl_lib];
|
||||
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
@ -722,23 +751,11 @@ MTL::ComputePipelineState* Device::get_kernel(
|
||||
|
||||
MTL::ComputePipelineState* Device::get_kernel(
|
||||
const std::string& base_name,
|
||||
const std::string& lib_name /* = "mlx" */,
|
||||
const std::string& hash_name /* = "" */,
|
||||
const MTLFCList& func_consts /* = {} */,
|
||||
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
|
||||
const auto& kname = hash_name.size() == 0 ? base_name : hash_name;
|
||||
{
|
||||
// Multiple readers allowed
|
||||
std::shared_lock lock(kernel_mtx_);
|
||||
|
||||
// Look for cached kernel
|
||||
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
// Search for cached metal lib
|
||||
MTL::Library* mtl_lib = get_library_(lib_name);
|
||||
return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
|
||||
return get_kernel(
|
||||
base_name, default_library_, hash_name, func_consts, linked_functions);
|
||||
}
|
||||
|
||||
void Device::set_residency_set(const MTL::ResidencySet* residency_set) {
|
||||
|
@ -177,6 +177,10 @@ class Device {
|
||||
return arch_;
|
||||
}
|
||||
|
||||
int get_architecture_gen() const {
|
||||
return arch_gen_;
|
||||
}
|
||||
|
||||
void new_queue(int index);
|
||||
|
||||
MTL::CommandQueue* get_queue(Stream stream);
|
||||
@ -187,14 +191,16 @@ class Device {
|
||||
CommandEncoder& get_command_encoder(int index);
|
||||
void end_encoding(int index);
|
||||
|
||||
void register_library(
|
||||
const std::string& lib_name,
|
||||
const std::string& lib_path = "");
|
||||
MTL::Library* get_library(
|
||||
const std::string& name,
|
||||
const std::string& path = "");
|
||||
|
||||
MTL::Library* get_library(
|
||||
const std::string& name,
|
||||
const std::function<std::string(void)>& builder);
|
||||
|
||||
void clear_library(const std::string& name);
|
||||
|
||||
MTL::ComputePipelineState* get_kernel(
|
||||
const std::string& base_name,
|
||||
MTL::Library* mtl_lib,
|
||||
@ -204,7 +210,6 @@ class Device {
|
||||
|
||||
MTL::ComputePipelineState* get_kernel(
|
||||
const std::string& base_name,
|
||||
const std::string& lib_name = "mlx",
|
||||
const std::string& hash_name = "",
|
||||
const MTLFCList& func_consts = {},
|
||||
const std::vector<MTL::Function*>& linked_functions = {});
|
||||
@ -258,12 +263,16 @@ class Device {
|
||||
std::unordered_map<int32_t, DeviceStream> stream_map_;
|
||||
|
||||
std::shared_mutex kernel_mtx_;
|
||||
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
|
||||
|
||||
std::shared_mutex library_mtx_;
|
||||
std::unordered_map<std::string, MTL::Library*> library_map_;
|
||||
MTL::Library* default_library_;
|
||||
std::unordered_map<
|
||||
MTL::Library*,
|
||||
std::unordered_map<std::string, MTL::ComputePipelineState*>>
|
||||
library_kernels_;
|
||||
const MTL::ResidencySet* residency_set_{nullptr};
|
||||
std::string arch_;
|
||||
int arch_gen_;
|
||||
int max_ops_per_buffer_;
|
||||
int max_mb_per_buffer_;
|
||||
};
|
||||
|
@ -41,7 +41,11 @@ MTL::ComputePipelineState* get_unary_kernel(
|
||||
std::string kernel_source = metal::utils();
|
||||
concatenate(kernel_source, metal::unary_ops(), metal::unary());
|
||||
kernel_source +=
|
||||
get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op);
|
||||
get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op, 1);
|
||||
if (get_work_per_thread(in_type) > 1) {
|
||||
kernel_source +=
|
||||
get_template_definition("vn_" + lib_name, "unary_v", in_t, out_t, op);
|
||||
}
|
||||
kernel_source +=
|
||||
get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op);
|
||||
kernel_source += get_template_definition(
|
||||
@ -59,11 +63,8 @@ void append_binary_kernels(
|
||||
Dtype out_type,
|
||||
const std::string op,
|
||||
std::string& kernel_source) {
|
||||
const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{
|
||||
const std::array<std::pair<std::string, std::string>, 7> kernel_types = {{
|
||||
{"ss", "binary_ss"},
|
||||
{"vs", "binary_vs"},
|
||||
{"sv", "binary_sv"},
|
||||
{"vv", "binary_vv"},
|
||||
{"vs2", "binary_vs2"},
|
||||
{"sv2", "binary_sv2"},
|
||||
{"vv2", "binary_vv2"},
|
||||
@ -78,6 +79,22 @@ void append_binary_kernels(
|
||||
kernel_source +=
|
||||
get_template_definition(name + "_" + lib_name, func, in_t, out_t, op);
|
||||
}
|
||||
kernel_source += get_template_definition(
|
||||
"vs_" + lib_name, "binary_vs", in_t, out_t, op, 1);
|
||||
kernel_source += get_template_definition(
|
||||
"sv_" + lib_name, "binary_sv", in_t, out_t, op, 1);
|
||||
kernel_source += get_template_definition(
|
||||
"vv_" + lib_name, "binary_vv", in_t, out_t, op, 1);
|
||||
|
||||
if (get_work_per_thread(in_type) > 1) {
|
||||
kernel_source += get_template_definition(
|
||||
"vsn_" + lib_name, "binary_vs", in_t, out_t, op);
|
||||
kernel_source += get_template_definition(
|
||||
"svn_" + lib_name, "binary_sv", in_t, out_t, op);
|
||||
kernel_source += get_template_definition(
|
||||
"vvn_" + lib_name, "binary_vv", in_t, out_t, op);
|
||||
}
|
||||
|
||||
kernel_source += get_template_definition(
|
||||
"g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
@ -133,8 +150,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
||||
auto t_str = get_type_string(type);
|
||||
std::string kernel_source = metal::utils();
|
||||
concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
|
||||
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
|
||||
{"v", "ternary_v"},
|
||||
const std::array<std::pair<std::string, std::string>, 4> kernel_types = {{
|
||||
{"v2", "ternary_v2"},
|
||||
{"g1large", "ternary_g_nd1"},
|
||||
{"g2large", "ternary_g_nd2"},
|
||||
@ -144,6 +160,13 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
||||
kernel_source +=
|
||||
get_template_definition(name + "_" + lib_name, func, t_str, op);
|
||||
}
|
||||
if (get_work_per_thread(type) > 1) {
|
||||
kernel_source +=
|
||||
get_template_definition("vn_" + lib_name, "ternary_v", t_str, op);
|
||||
}
|
||||
|
||||
kernel_source +=
|
||||
get_template_definition("v_" + lib_name, "ternary_v", t_str, op, 1);
|
||||
kernel_source += get_template_definition(
|
||||
"g1_" + lib_name, "ternary_g_nd1", t_str, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
@ -170,15 +193,22 @@ MTL::ComputePipelineState* get_copy_kernel(
|
||||
kernel_source += metal::copy();
|
||||
auto in_type = get_type_string(in.dtype());
|
||||
auto out_type = get_type_string(out.dtype());
|
||||
kernel_source +=
|
||||
get_template_definition("s_" + lib_name, "copy_s", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"s_" + lib_name, "copy_s", in_type, out_type, 1);
|
||||
kernel_source +=
|
||||
get_template_definition("s2_" + lib_name, "copy_s2", in_type, out_type);
|
||||
kernel_source +=
|
||||
get_template_definition("v_" + lib_name, "copy_v", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"v_" + lib_name, "copy_v", in_type, out_type, 1);
|
||||
kernel_source +=
|
||||
get_template_definition("v2_" + lib_name, "copy_v2", in_type, out_type);
|
||||
|
||||
if (get_work_per_thread(out.dtype()) > 1) {
|
||||
kernel_source += get_template_definition(
|
||||
"sn_" + lib_name, "copy_s", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"vn_" + lib_name, "copy_v", in_type, out_type);
|
||||
}
|
||||
|
||||
kernel_source += get_template_definition(
|
||||
"g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
@ -697,6 +727,8 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
|
||||
MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const array& out,
|
||||
int bm,
|
||||
int bn,
|
||||
@ -719,7 +751,7 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
||||
wn);
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_fft_kernel(
|
||||
|
@ -205,6 +205,8 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
|
||||
MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const array& out,
|
||||
int bm,
|
||||
int bn,
|
||||
|
@ -17,8 +17,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
c[index + i] = Op()(a[0], b[index + i]);
|
||||
if (N > 1 && index + N > size) {
|
||||
for (int i = 0; index + i < size; ++i) {
|
||||
c[index + i] = Op()(a[0], b[index + i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
c[index + i] = Op()(a[0], b[index + i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -30,8 +36,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
c[index + i] = Op()(a[index + i], b[0]);
|
||||
if (N > 1 && index + N > size) {
|
||||
for (int i = 0; index + i < size; ++i) {
|
||||
c[index + i] = Op()(a[index + i], b[0]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
c[index + i] = Op()(a[index + i], b[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -43,8 +55,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
c[index + i] = Op()(a[index + i], b[index + i]);
|
||||
if (N > 1 && index + N > size) {
|
||||
for (int i = 0; index + i < size; ++i) {
|
||||
c[index + i] = Op()(a[index + i], b[index + i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
c[index + i] = Op()(a[index + i], b[index + i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -57,8 +75,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
c[offset + i] = Op()(a[0], b[offset + i]);
|
||||
if (N > 1 && offset + N > size) {
|
||||
for (int i = 0; offset + i < size; ++i) {
|
||||
c[offset + i] = Op()(a[0], b[offset + i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
c[offset + i] = Op()(a[0], b[offset + i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -71,8 +95,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
c[offset + i] = Op()(a[offset + i], b[0]);
|
||||
if (N > 1 && offset + N > size) {
|
||||
for (int i = 0; offset + i < size; ++i) {
|
||||
c[offset + i] = Op()(a[offset + i], b[0]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
c[offset + i] = Op()(a[offset + i], b[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -85,8 +115,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
c[offset + i] = Op()(a[offset + i], b[offset + i]);
|
||||
if (N > 1 && offset + N > size) {
|
||||
for (int i = 0; offset + i < size; ++i) {
|
||||
c[offset + i] = Op()(a[offset + i], b[offset + i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
c[offset + i] = Op()(a[offset + i], b[offset + i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -9,11 +9,16 @@
|
||||
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/binary.h"
|
||||
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
#define instantiate_binary_work_per_thread(op, tname, itype, otype) \
|
||||
instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op) \
|
||||
|
||||
#define instantiate_binary_base(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
@ -26,15 +31,19 @@
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
|
||||
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
|
||||
|
||||
#define instantiate_binary_integer(op) \
|
||||
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
|
||||
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
|
||||
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
|
||||
instantiate_binary_all(op, uint64, uint64_t, uint64_t) \
|
||||
instantiate_binary_all(op, int8, int8_t, int8_t) \
|
||||
instantiate_binary_all(op, int16, int16_t, int16_t) \
|
||||
instantiate_binary_all(op, int32, int32_t, int32_t) \
|
||||
instantiate_binary_all(op, int64, int64_t, int64_t)
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_binary_base(op, tname, itype, otype) \
|
||||
instantiate_binary_work_per_thread(op, tname, itype, otype)
|
||||
|
||||
#define instantiate_binary_integer(op) \
|
||||
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
|
||||
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
|
||||
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
|
||||
instantiate_binary_base(op, uint64, uint64_t, uint64_t) \
|
||||
instantiate_binary_all(op, int8, int8_t, int8_t) \
|
||||
instantiate_binary_all(op, int16, int16_t, int16_t) \
|
||||
instantiate_binary_all(op, int32, int32_t, int32_t) \
|
||||
instantiate_binary_base(op, int64, int64_t, int64_t)
|
||||
|
||||
#define instantiate_binary_float(op) \
|
||||
instantiate_binary_all(op, float16, half, half) \
|
||||
@ -44,7 +53,7 @@
|
||||
#define instantiate_binary_types(op) \
|
||||
instantiate_binary_all(op, bool_, bool, bool) \
|
||||
instantiate_binary_integer(op) \
|
||||
instantiate_binary_all(op, complex64, complex64_t, complex64_t) \
|
||||
instantiate_binary_base(op, complex64, complex64_t, complex64_t)\
|
||||
instantiate_binary_float(op)
|
||||
|
||||
#define instantiate_binary_types_bool(op) \
|
||||
@ -52,15 +61,15 @@
|
||||
instantiate_binary_all(op, uint8, uint8_t, bool) \
|
||||
instantiate_binary_all(op, uint16, uint16_t, bool) \
|
||||
instantiate_binary_all(op, uint32, uint32_t, bool) \
|
||||
instantiate_binary_all(op, uint64, uint64_t, bool) \
|
||||
instantiate_binary_base(op, uint64, uint64_t, bool) \
|
||||
instantiate_binary_all(op, int8, int8_t, bool) \
|
||||
instantiate_binary_all(op, int16, int16_t, bool) \
|
||||
instantiate_binary_all(op, int32, int32_t, bool) \
|
||||
instantiate_binary_all(op, int64, int64_t, bool) \
|
||||
instantiate_binary_base(op, int64, int64_t, bool) \
|
||||
instantiate_binary_all(op, float16, half, bool) \
|
||||
instantiate_binary_all(op, float32, float, bool) \
|
||||
instantiate_binary_all(op, bfloat16, bfloat16_t, bool) \
|
||||
instantiate_binary_all(op, complex64, complex64_t, bool)
|
||||
instantiate_binary_base(op, complex64, complex64_t, bool)
|
||||
|
||||
instantiate_binary_types(Add)
|
||||
instantiate_binary_types(Divide)
|
||||
@ -71,7 +80,7 @@ instantiate_binary_types_bool(Less)
|
||||
instantiate_binary_types_bool(LessEqual)
|
||||
instantiate_binary_types_bool(NotEqual)
|
||||
instantiate_binary_float(LogAddExp)
|
||||
instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t)
|
||||
instantiate_binary_base(LogAddExp, complex64, complex64_t, complex64_t)
|
||||
instantiate_binary_types(Maximum)
|
||||
instantiate_binary_types(Minimum)
|
||||
instantiate_binary_types(Multiply)
|
||||
@ -84,7 +93,7 @@ instantiate_binary_float(ArcTan2)
|
||||
instantiate_binary_all(NaNEqual, float16, half, bool)
|
||||
instantiate_binary_all(NaNEqual, float32, float, bool)
|
||||
instantiate_binary_all(NaNEqual, bfloat16, bfloat16_t, bool)
|
||||
instantiate_binary_all(NaNEqual, complex64, complex64_t, bool)
|
||||
instantiate_binary_base(NaNEqual, complex64, complex64_t, bool)
|
||||
|
||||
instantiate_binary_all(LogicalOr, bool_, bool, bool)
|
||||
instantiate_binary_all(LogicalAnd, bool_, bool, bool)
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user