mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-10 07:01:39 +08:00
Compare commits
94 Commits
v0.25.0
...
c35f4d089a
Author | SHA1 | Date | |
---|---|---|---|
![]() |
c35f4d089a | ||
![]() |
8590c0941e | ||
![]() |
095163b8d1 | ||
![]() |
99c33d011d | ||
![]() |
62fecf3e13 | ||
![]() |
7c4eb5d03e | ||
![]() |
bae9a6b404 | ||
![]() |
004c1d8ef2 | ||
![]() |
7ebb2e0193 | ||
![]() |
9ce77798b1 | ||
![]() |
f8bad60609 | ||
![]() |
5866b3857b | ||
![]() |
1ca616844b | ||
![]() |
2e8cf0b450 | ||
![]() |
24f89173d1 | ||
![]() |
c6a20b427a | ||
![]() |
a5ac9244c4 | ||
![]() |
c763fe1be0 | ||
![]() |
52dc8c8cd5 | ||
![]() |
aede70e81d | ||
![]() |
85a8beb5e4 | ||
![]() |
0bb89e9e5f | ||
![]() |
5685ceb3c7 | ||
![]() |
0408ba0a76 | ||
![]() |
cbad6c3093 | ||
![]() |
1b021f6984 | ||
![]() |
95b7551d65 | ||
![]() |
db5a7c6192 | ||
![]() |
6ef2f67e7f | ||
![]() |
f76ee1ffd2 | ||
![]() |
54a71f270a | ||
![]() |
55b4062dd8 | ||
![]() |
79071bfba4 | ||
![]() |
7774b87cbd | ||
![]() |
35c87741cf | ||
![]() |
4cbe605214 | ||
![]() |
ab8883dd55 | ||
![]() |
eebe73001a | ||
![]() |
0359bf02c9 | ||
![]() |
237f9e58a8 | ||
![]() |
8576e6fe36 | ||
![]() |
0654543dcc | ||
![]() |
48ef3e74e2 | ||
![]() |
7d4b378952 | ||
![]() |
7ff5c41e06 | ||
![]() |
602f43e3d1 | ||
![]() |
a2cadb8218 | ||
![]() |
c1eb9d05d9 | ||
![]() |
cf6c939e86 | ||
![]() |
130df35e1b | ||
![]() |
0751263dec | ||
![]() |
eca2f3eb97 | ||
![]() |
3aa9cf3f9e | ||
![]() |
8f3d208dce | ||
![]() |
caaa3f1f8c | ||
![]() |
659a51919f | ||
![]() |
6661387066 | ||
![]() |
a7fae8a176 | ||
![]() |
0cae0bdac8 | ||
![]() |
5a1a5d5ed1 | ||
![]() |
1683975acf | ||
![]() |
af705590ac | ||
![]() |
825124af8f | ||
![]() |
9c5e7da507 | ||
![]() |
481349495b | ||
![]() |
9daa6b003f | ||
![]() |
a3a632d567 | ||
![]() |
e496c5a4b4 | ||
![]() |
ea890d8710 | ||
![]() |
aa5d84f102 | ||
![]() |
f1606486d2 | ||
![]() |
87720a8908 | ||
![]() |
bb6565ef14 | ||
![]() |
7bb063bcb3 | ||
![]() |
b36dd472bb | ||
![]() |
167b759a38 | ||
![]() |
99b9868859 | ||
![]() |
6b2d5448f2 | ||
![]() |
eaf709b83e | ||
![]() |
f0e70afff0 | ||
![]() |
86984cad68 | ||
![]() |
fbc89e3ced | ||
![]() |
38c1e720c2 | ||
![]() |
600e87e03c | ||
![]() |
3836445241 | ||
![]() |
1d2c9d6a07 | ||
![]() |
e8ac6bd2f5 | ||
![]() |
fdadc4f22c | ||
![]() |
79b527f45f | ||
![]() |
dc4eada7f0 | ||
![]() |
70ebc3b598 | ||
![]() |
b13f2aed16 | ||
![]() |
5f04c0f818 | ||
![]() |
55935ccae7 |
@@ -212,6 +212,29 @@ jobs:
|
||||
METAL_DEBUG_ERROR_MODE=0 \
|
||||
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
|
||||
|
||||
cuda_build_and_test:
|
||||
machine:
|
||||
image: linux-cuda-12:default
|
||||
resource_class: gpu.nvidia.small.gen2
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
|
||||
python -m venv env
|
||||
source env/bin/activate
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||
pip install -e ".[dev]"
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
source env/bin/activate
|
||||
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
|
||||
|
||||
build_release:
|
||||
parameters:
|
||||
python_version:
|
||||
@@ -348,6 +371,7 @@ workflows:
|
||||
parameters:
|
||||
macosx_deployment_target: ["13.5", "14.0"]
|
||||
- linux_build_and_test
|
||||
- cuda_build_and_test
|
||||
- build_documentation
|
||||
|
||||
build_pypi_release:
|
||||
@@ -455,6 +479,8 @@ workflows:
|
||||
macosx_deployment_target: ["13.5", "14.0"]
|
||||
- linux_build_and_test:
|
||||
requires: [ hold ]
|
||||
- cuda_build_and_test:
|
||||
requires: [ hold ]
|
||||
nightly_build:
|
||||
when:
|
||||
and:
|
||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@@ -36,6 +36,7 @@ share/python-wheels/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
uv.lock
|
||||
|
||||
# vim
|
||||
*.swp
|
||||
|
@@ -34,6 +34,7 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
|
||||
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
||||
option(MLX_BUILD_METAL "Build metal backend" ON)
|
||||
option(MLX_BUILD_CPU "Build cpu backend" ON)
|
||||
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
|
||||
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
||||
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
||||
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
|
||||
@@ -83,6 +84,10 @@ if(MLX_BUILD_METAL)
|
||||
set(QUARTZ_LIB "-framework QuartzCore")
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_CUDA)
|
||||
enable_language(CUDA)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_METAL AND NOT METAL_LIB)
|
||||
message(STATUS "Metal not found. Unable to build GPU")
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
@@ -226,6 +231,9 @@ target_include_directories(
|
||||
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
|
||||
$<INSTALL_INTERFACE:include>)
|
||||
|
||||
# Do not add mlx_EXPORTS define for shared library.
|
||||
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
|
||||
|
||||
FetchContent_Declare(
|
||||
fmt
|
||||
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
||||
|
@@ -1,4 +1,6 @@
|
||||
include CMakeLists.txt
|
||||
include mlx.pc.in
|
||||
recursive-include mlx/ *
|
||||
include cmake/*
|
||||
include python/src/*
|
||||
include python/mlx/py.typed # support type hinting as in PEP-561
|
||||
|
@@ -1,5 +1,6 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
|
107
benchmarks/python/conv_unaligned_bench.py
Normal file
107
benchmarks/python/conv_unaligned_bench.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import math
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
N_warmup = 10
|
||||
N_iter_bench = 100
|
||||
N_iter_func = 5
|
||||
|
||||
|
||||
def bench(f, a, b):
|
||||
for i in range(N_warmup):
|
||||
f(a, b)
|
||||
torch.mps.synchronize()
|
||||
|
||||
s = time.perf_counter_ns()
|
||||
for i in range(N_iter_bench):
|
||||
f(a, b)
|
||||
e = time.perf_counter_ns()
|
||||
return (e - s) * 1e-9
|
||||
|
||||
|
||||
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
def mx_conv_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
mx.eval(ys)
|
||||
return ys
|
||||
|
||||
return mx_conv_2D
|
||||
|
||||
|
||||
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
@torch.no_grad()
|
||||
def pt_conv_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
torch.mps.synchronize()
|
||||
return ys
|
||||
|
||||
return pt_conv_2D
|
||||
|
||||
|
||||
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
|
||||
scale = 1.0 / math.sqrt(kH * kH * C)
|
||||
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
||||
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
|
||||
np_dtype
|
||||
)
|
||||
|
||||
a_mx = mx.array(a_np)
|
||||
b_mx = mx.array(b_np)
|
||||
|
||||
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
|
||||
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
|
||||
|
||||
torch.mps.synchronize()
|
||||
|
||||
f_mx = make_mx_conv_2D(strides, padding, groups)
|
||||
f_pt = make_pt_conv_2D(strides, padding, groups)
|
||||
|
||||
time_torch = bench(f_pt, a_pt, b_pt)
|
||||
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||
|
||||
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
|
||||
out_pt = torch.conv2d(
|
||||
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
||||
out_pt = out_pt.numpy(force=True)
|
||||
|
||||
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
||||
|
||||
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||
print(
|
||||
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||
)
|
||||
|
||||
return time_mlx, time_torch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dtype = "float32"
|
||||
shapes = (
|
||||
(4, 32, 32, 21, 3, 3, 128),
|
||||
(4, 32, 32, 21, 3, 3, 37),
|
||||
(4, 32, 32, 370, 3, 3, 370),
|
||||
(4, 32, 32, 370, 7, 7, 128),
|
||||
(2, 320, 640, 21, 7, 7, 21),
|
||||
)
|
||||
for N, H, W, C, kh, kw, O in shapes:
|
||||
time_mlx, time_torch = bench_shape(
|
||||
N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
|
||||
)
|
||||
diff = time_torch / time_mlx - 1.0
|
||||
|
||||
print(
|
||||
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
if time_mlx >= 2.0 * time_torch:
|
||||
print("ATTENTION ^^^^^^^")
|
@@ -1,5 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from functools import partial
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from time_utils import time_fn
|
||||
@@ -18,51 +20,63 @@ def layer_norm(x, w, b, eps):
|
||||
return y
|
||||
|
||||
|
||||
def time_layer_norm():
|
||||
def time_layer_norm(N, dt):
|
||||
L = 1024
|
||||
f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
|
||||
f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
|
||||
g1 = mx.grad(f1, argnums=(0, 1, 2))
|
||||
g2 = mx.grad(f2, argnums=(0, 1, 2))
|
||||
|
||||
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||
w = mx.random.uniform(shape=(N,)).astype(dt)
|
||||
b = mx.random.uniform(shape=(N,)).astype(dt)
|
||||
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||
mx.eval(x, w, b, y)
|
||||
|
||||
def layer_norm_loop(g, x, w, b):
|
||||
def layer_norm_loop(f, x, w, b):
|
||||
for _ in range(32):
|
||||
x = f(x, w, b)
|
||||
return x
|
||||
|
||||
time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b)
|
||||
time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b)
|
||||
|
||||
def layer_norm_grad_loop(g, x, w, b):
|
||||
gx, gw, gb = x, w, b
|
||||
for _ in range(32):
|
||||
gx, gw, gb = g(gx, gw, gb, y)
|
||||
return gx, gw, gb
|
||||
|
||||
time_fn(layer_norm_loop, g1, x, w, b)
|
||||
time_fn(layer_norm_loop, g2, x, w, b)
|
||||
time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
|
||||
time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
|
||||
time_fn(layer_norm_grad_loop, g1, x, w, b)
|
||||
time_fn(layer_norm_grad_loop, g2, x, w, b)
|
||||
time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b)
|
||||
time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b)
|
||||
|
||||
f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
|
||||
f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
|
||||
g1 = mx.grad(f1, argnums=(0,))
|
||||
g2 = mx.grad(f2, argnums=(0,))
|
||||
|
||||
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||
w = mx.random.uniform(shape=(N,)).astype(dt)
|
||||
b = mx.random.uniform(shape=(N,)).astype(dt)
|
||||
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||
mx.eval(x, w, b, y)
|
||||
|
||||
def layer_norm_loop(g, x):
|
||||
def layer_norm_grad_x_loop(g, x):
|
||||
gx = x
|
||||
for _ in range(32):
|
||||
gx = g(gx, y)
|
||||
return gx
|
||||
|
||||
time_fn(layer_norm_loop, g1, x)
|
||||
time_fn(layer_norm_loop, g2, x)
|
||||
time_fn(layer_norm_loop, mx.compile(g1), x)
|
||||
time_fn(layer_norm_loop, mx.compile(g2), x)
|
||||
time_fn(layer_norm_grad_x_loop, g1, x)
|
||||
time_fn(layer_norm_grad_x_loop, g2, x)
|
||||
time_fn(layer_norm_grad_x_loop, mx.compile(g1), x)
|
||||
time_fn(layer_norm_grad_x_loop, mx.compile(g2), x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_layer_norm()
|
||||
for dt in [mx.float32, mx.float16, mx.bfloat16]:
|
||||
for n in [1024, 2048, 4096, 8192, 8192 + 1024]:
|
||||
print(dt, n)
|
||||
time_layer_norm(n, dt)
|
||||
|
@@ -11,13 +11,14 @@ include(CMakeParseArguments)
|
||||
# Args: TARGET: Custom target to be added for the metal library TITLE: Name of
|
||||
# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
|
||||
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
|
||||
# files (like headers)
|
||||
# files (like headers) DEBUG: Boolean, if true, enables debug compile options
|
||||
# for this specific library. If not provided, uses global MLX_METAL_DEBUG.
|
||||
#
|
||||
# clang format on
|
||||
|
||||
macro(mlx_build_metallib)
|
||||
# Parse args
|
||||
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
|
||||
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG)
|
||||
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
|
||||
cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
@@ -26,6 +27,10 @@ macro(mlx_build_metallib)
|
||||
|
||||
# Collect compile options
|
||||
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
|
||||
if(MLX_METAL_DEBUG OR MTLLIB_DEBUG)
|
||||
set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only
|
||||
-frecord-sources)
|
||||
endif()
|
||||
|
||||
# Prepare metallib build command
|
||||
add_custom_command(
|
||||
|
@@ -10,7 +10,7 @@ import mlx.core as mx
|
||||
# -- Project information -----------------------------------------------------
|
||||
|
||||
project = "MLX"
|
||||
copyright = "2023, MLX Contributors"
|
||||
copyright = "2023, Apple"
|
||||
author = "MLX Contributors"
|
||||
version = ".".join(mx.__version__.split(".")[:3])
|
||||
release = version
|
||||
|
@@ -8,23 +8,26 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs.
|
||||
Simple Example
|
||||
--------------
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
Let's write a custom kernel that computes ``exp`` elementwise:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = inp[elem];
|
||||
out[elem] = metal::exp(tmp);
|
||||
"""
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = inp[elem];
|
||||
out[elem] = metal::exp(tmp);
|
||||
"""
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
outputs = kernel(
|
||||
inputs=[a],
|
||||
template=[("T", mx.float32)],
|
||||
@@ -39,8 +42,13 @@ Let's write a custom kernel that computes ``exp`` elementwise:
|
||||
b = exp_elementwise(a)
|
||||
assert mx.allclose(b, mx.exp(a))
|
||||
|
||||
Every time you make a kernel, a new Metal library is created and possibly
|
||||
JIT compiled. To reduce the overhead from that, build the kernel once with
|
||||
:func:`fast.metal_kernel` and then use it many times.
|
||||
|
||||
.. note::
|
||||
We are only required to pass the body of the Metal kernel in ``source``.
|
||||
Only pass the body of the Metal kernel in ``source``. The function
|
||||
signature is generated automatically.
|
||||
|
||||
The full function signature will be generated using:
|
||||
|
||||
@@ -78,44 +86,51 @@ Putting this all together, the generated function signature for ``myexp`` is as
|
||||
|
||||
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
|
||||
|
||||
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
|
||||
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
|
||||
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
|
||||
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
|
||||
<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
|
||||
function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
|
||||
``threadgroup`` size threadgroups. For optimal performance, each thread group
|
||||
dimension should be less than or equal to the corresponding grid dimension.
|
||||
|
||||
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
|
||||
Passing ``verbose=True`` to :func:`ast.metal_kernel.__call__` will print the
|
||||
generated code for debugging purposes.
|
||||
|
||||
Using Shape/Strides
|
||||
-------------------
|
||||
|
||||
``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
|
||||
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
|
||||
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
|
||||
when indexing.
|
||||
:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
|
||||
is ``True`` by default. This will copy the array inputs if needed
|
||||
before the kernel is launched to ensure that the memory layout is row
|
||||
contiguous. Generally this makes writing the kernel easier, since we don't
|
||||
have to worry about gaps or the ordering of the dims when indexing.
|
||||
|
||||
If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
|
||||
input array ``a`` if any are present in ``source``.
|
||||
We can then use MLX's built in indexing utils to fetch the right elements for each thread.
|
||||
If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
|
||||
``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
|
||||
present in ``source``. We can then use MLX's built in indexing utils to fetch
|
||||
the right elements for each thread.
|
||||
|
||||
Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:
|
||||
Let's convert ``myexp`` above to support arbitrarily strided arrays without
|
||||
relying on a copy from ``ensure_row_contiguous``:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
|
||||
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
|
||||
T tmp = inp[loc];
|
||||
// Output arrays are always row contiguous
|
||||
out[elem] = metal::exp(tmp);
|
||||
"""
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp_strided",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source
|
||||
)
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
|
||||
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
|
||||
T tmp = inp[loc];
|
||||
// Output arrays are always row contiguous
|
||||
out[elem] = metal::exp(tmp);
|
||||
"""
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp_strided",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs=[a],
|
||||
template=[("T", mx.float32)],
|
||||
@@ -142,137 +157,139 @@ We'll start with the following MLX implementation using standard ops:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def grid_sample_ref(x, grid):
|
||||
N, H_in, W_in, _ = x.shape
|
||||
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
|
||||
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
|
||||
def grid_sample_ref(x, grid):
|
||||
N, H_in, W_in, _ = x.shape
|
||||
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
|
||||
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
|
||||
|
||||
ix_nw = mx.floor(ix).astype(mx.int32)
|
||||
iy_nw = mx.floor(iy).astype(mx.int32)
|
||||
ix_nw = mx.floor(ix).astype(mx.int32)
|
||||
iy_nw = mx.floor(iy).astype(mx.int32)
|
||||
|
||||
ix_ne = ix_nw + 1
|
||||
iy_ne = iy_nw
|
||||
ix_ne = ix_nw + 1
|
||||
iy_ne = iy_nw
|
||||
|
||||
ix_sw = ix_nw
|
||||
iy_sw = iy_nw + 1
|
||||
ix_sw = ix_nw
|
||||
iy_sw = iy_nw + 1
|
||||
|
||||
ix_se = ix_nw + 1
|
||||
iy_se = iy_nw + 1
|
||||
ix_se = ix_nw + 1
|
||||
iy_se = iy_nw + 1
|
||||
|
||||
nw = (ix_se - ix) * (iy_se - iy)
|
||||
ne = (ix - ix_sw) * (iy_sw - iy)
|
||||
sw = (ix_ne - ix) * (iy - iy_ne)
|
||||
se = (ix - ix_nw) * (iy - iy_nw)
|
||||
nw = (ix_se - ix) * (iy_se - iy)
|
||||
ne = (ix - ix_sw) * (iy_sw - iy)
|
||||
sw = (ix_ne - ix) * (iy - iy_ne)
|
||||
se = (ix - ix_nw) * (iy - iy_nw)
|
||||
|
||||
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
|
||||
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
|
||||
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
|
||||
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
|
||||
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
|
||||
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
|
||||
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
|
||||
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
|
||||
|
||||
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
|
||||
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
|
||||
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
|
||||
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
|
||||
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
|
||||
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
|
||||
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
|
||||
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
|
||||
|
||||
I_nw *= mask_nw[..., None]
|
||||
I_ne *= mask_ne[..., None]
|
||||
I_sw *= mask_sw[..., None]
|
||||
I_se *= mask_se[..., None]
|
||||
I_nw *= mask_nw[..., None]
|
||||
I_ne *= mask_ne[..., None]
|
||||
I_sw *= mask_sw[..., None]
|
||||
I_se *= mask_se[..., None]
|
||||
|
||||
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
|
||||
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
|
||||
|
||||
return output
|
||||
return output
|
||||
|
||||
Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
|
||||
Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
|
||||
to write a fast GPU kernel for both the forward and backward passes.
|
||||
|
||||
First we'll implement the forward pass as a fused kernel:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@mx.custom_function
|
||||
def grid_sample(x, grid):
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
int W = x_shape[2];
|
||||
int C = x_shape[3];
|
||||
int gH = grid_shape[1];
|
||||
int gW = grid_shape[2];
|
||||
|
||||
assert x.ndim == 4, "`x` must be 4D."
|
||||
assert grid.ndim == 4, "`grid` must be 4D."
|
||||
int w_stride = C;
|
||||
int h_stride = W * w_stride;
|
||||
int b_stride = H * h_stride;
|
||||
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
out_shape = (B, gN, gM, C)
|
||||
uint grid_idx = elem / C * 2;
|
||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
int ix_nw = floor(ix);
|
||||
int iy_nw = floor(iy);
|
||||
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
int W = x_shape[2];
|
||||
int C = x_shape[3];
|
||||
int gH = grid_shape[1];
|
||||
int gW = grid_shape[2];
|
||||
int ix_ne = ix_nw + 1;
|
||||
int iy_ne = iy_nw;
|
||||
|
||||
int w_stride = C;
|
||||
int h_stride = W * w_stride;
|
||||
int b_stride = H * h_stride;
|
||||
int ix_sw = ix_nw;
|
||||
int iy_sw = iy_nw + 1;
|
||||
|
||||
uint grid_idx = elem / C * 2;
|
||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||
int ix_se = ix_nw + 1;
|
||||
int iy_se = iy_nw + 1;
|
||||
|
||||
int ix_nw = floor(ix);
|
||||
int iy_nw = floor(iy);
|
||||
T nw = (ix_se - ix) * (iy_se - iy);
|
||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||
|
||||
int ix_ne = ix_nw + 1;
|
||||
int iy_ne = iy_nw;
|
||||
int batch_idx = elem / C / gH / gW * b_stride;
|
||||
int channel_idx = elem % C;
|
||||
int base_idx = batch_idx + channel_idx;
|
||||
|
||||
int ix_sw = ix_nw;
|
||||
int iy_sw = iy_nw + 1;
|
||||
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
|
||||
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
|
||||
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
|
||||
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
|
||||
|
||||
int ix_se = ix_nw + 1;
|
||||
int iy_se = iy_nw + 1;
|
||||
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
|
||||
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
|
||||
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
|
||||
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
|
||||
|
||||
T nw = (ix_se - ix) * (iy_se - iy);
|
||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
|
||||
"""
|
||||
|
||||
int batch_idx = elem / C / gH / gW * b_stride;
|
||||
int channel_idx = elem % C;
|
||||
int base_idx = batch_idx + channel_idx;
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample",
|
||||
input_names=["x", "grid"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
|
||||
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
|
||||
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
|
||||
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
|
||||
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
|
||||
@mx.custom_function
|
||||
def grid_sample(x, grid):
|
||||
|
||||
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
|
||||
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
|
||||
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
|
||||
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
|
||||
assert x.ndim == 4, "`x` must be 4D."
|
||||
assert grid.ndim == 4, "`grid` must be 4D."
|
||||
|
||||
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
|
||||
"""
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample",
|
||||
input_names=["x", "grid"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs=[x, grid],
|
||||
template=[("T", x.dtype)],
|
||||
output_shapes=[out_shape],
|
||||
output_dtypes=[x.dtype],
|
||||
grid=(np.prod(out_shape), 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
)
|
||||
return outputs[0]
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
out_shape = (B, gN, gM, C)
|
||||
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
|
||||
outputs = kernel(
|
||||
inputs=[x, grid],
|
||||
template=[("T", x.dtype)],
|
||||
output_shapes=[out_shape],
|
||||
output_dtypes=[x.dtype],
|
||||
grid=(np.prod(out_shape), 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
)
|
||||
return outputs[0]
|
||||
|
||||
For a reasonably sized input such as:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
x.shape = (8, 1024, 1024, 64)
|
||||
grid.shape = (8, 256, 256, 2)
|
||||
x.shape = (8, 1024, 1024, 64)
|
||||
grid.shape = (8, 256, 256, 2)
|
||||
|
||||
On an M1 Max, we see a big performance improvement:
|
||||
|
||||
@@ -281,11 +298,11 @@ On an M1 Max, we see a big performance improvement:
|
||||
Grid Sample VJP
|
||||
---------------
|
||||
|
||||
Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
|
||||
its custom vjp transform so MLX can differentiate it.
|
||||
Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
|
||||
define its custom vjp transform so MLX can differentiate it.
|
||||
|
||||
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
|
||||
requires a few extra ``mx.fast.metal_kernel`` features:
|
||||
requires a few extra :func:`fast.metal_kernel` features:
|
||||
|
||||
* ``init_value=0``
|
||||
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
|
||||
@@ -299,128 +316,129 @@ We can then implement the backwards pass as follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@grid_sample.vjp
|
||||
def grid_sample_vjp(primals, cotangent, _):
|
||||
x, grid = primals
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
int W = x_shape[2];
|
||||
int C = x_shape[3];
|
||||
// Pad C to the nearest larger simdgroup size multiple
|
||||
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
|
||||
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
int gH = grid_shape[1];
|
||||
int gW = grid_shape[2];
|
||||
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
int W = x_shape[2];
|
||||
int C = x_shape[3];
|
||||
// Pad C to the nearest larger simdgroup size multiple
|
||||
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
|
||||
int w_stride = C;
|
||||
int h_stride = W * w_stride;
|
||||
int b_stride = H * h_stride;
|
||||
|
||||
int gH = grid_shape[1];
|
||||
int gW = grid_shape[2];
|
||||
uint grid_idx = elem / C_padded * 2;
|
||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||
|
||||
int w_stride = C;
|
||||
int h_stride = W * w_stride;
|
||||
int b_stride = H * h_stride;
|
||||
int ix_nw = floor(ix);
|
||||
int iy_nw = floor(iy);
|
||||
|
||||
uint grid_idx = elem / C_padded * 2;
|
||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||
int ix_ne = ix_nw + 1;
|
||||
int iy_ne = iy_nw;
|
||||
|
||||
int ix_nw = floor(ix);
|
||||
int iy_nw = floor(iy);
|
||||
int ix_sw = ix_nw;
|
||||
int iy_sw = iy_nw + 1;
|
||||
|
||||
int ix_ne = ix_nw + 1;
|
||||
int iy_ne = iy_nw;
|
||||
int ix_se = ix_nw + 1;
|
||||
int iy_se = iy_nw + 1;
|
||||
|
||||
int ix_sw = ix_nw;
|
||||
int iy_sw = iy_nw + 1;
|
||||
T nw = (ix_se - ix) * (iy_se - iy);
|
||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||
|
||||
int ix_se = ix_nw + 1;
|
||||
int iy_se = iy_nw + 1;
|
||||
int batch_idx = elem / C_padded / gH / gW * b_stride;
|
||||
int channel_idx = elem % C_padded;
|
||||
int base_idx = batch_idx + channel_idx;
|
||||
|
||||
T nw = (ix_se - ix) * (iy_se - iy);
|
||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||
T gix = T(0);
|
||||
T giy = T(0);
|
||||
if (channel_idx < C) {
|
||||
int cot_index = elem / C_padded * C + channel_idx;
|
||||
T cot = cotangent[cot_index];
|
||||
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
|
||||
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
|
||||
|
||||
int batch_idx = elem / C_padded / gH / gW * b_stride;
|
||||
int channel_idx = elem % C_padded;
|
||||
int base_idx = batch_idx + channel_idx;
|
||||
T I_nw = x[offset];
|
||||
gix -= I_nw * (iy_se - iy) * cot;
|
||||
giy -= I_nw * (ix_se - ix) * cot;
|
||||
}
|
||||
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
|
||||
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
|
||||
|
||||
T gix = T(0);
|
||||
T giy = T(0);
|
||||
if (channel_idx < C) {
|
||||
int cot_index = elem / C_padded * C + channel_idx;
|
||||
T cot = cotangent[cot_index];
|
||||
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
|
||||
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
|
||||
T I_ne = x[offset];
|
||||
gix += I_ne * (iy_sw - iy) * cot;
|
||||
giy -= I_ne * (ix - ix_sw) * cot;
|
||||
}
|
||||
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
|
||||
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
|
||||
|
||||
T I_nw = x[offset];
|
||||
gix -= I_nw * (iy_se - iy) * cot;
|
||||
giy -= I_nw * (ix_se - ix) * cot;
|
||||
}
|
||||
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
|
||||
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
|
||||
T I_sw = x[offset];
|
||||
gix -= I_sw * (iy - iy_ne) * cot;
|
||||
giy += I_sw * (ix_ne - ix) * cot;
|
||||
}
|
||||
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
|
||||
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
|
||||
|
||||
T I_ne = x[offset];
|
||||
gix += I_ne * (iy_sw - iy) * cot;
|
||||
giy -= I_ne * (ix - ix_sw) * cot;
|
||||
}
|
||||
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
|
||||
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
|
||||
T I_se = x[offset];
|
||||
gix += I_se * (iy - iy_nw) * cot;
|
||||
giy += I_se * (ix - ix_nw) * cot;
|
||||
}
|
||||
}
|
||||
|
||||
T I_sw = x[offset];
|
||||
gix -= I_sw * (iy - iy_ne) * cot;
|
||||
giy += I_sw * (ix_ne - ix) * cot;
|
||||
}
|
||||
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
|
||||
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
|
||||
T gix_mult = W / 2;
|
||||
T giy_mult = H / 2;
|
||||
|
||||
T I_se = x[offset];
|
||||
gix += I_se * (iy - iy_nw) * cot;
|
||||
giy += I_se * (ix - ix_nw) * cot;
|
||||
}
|
||||
}
|
||||
// Reduce across each simdgroup first.
|
||||
// This is much faster than relying purely on atomics.
|
||||
gix = simd_sum(gix);
|
||||
giy = simd_sum(giy);
|
||||
|
||||
T gix_mult = W / 2;
|
||||
T giy_mult = H / 2;
|
||||
if (thread_index_in_simdgroup == 0) {
|
||||
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
|
||||
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
|
||||
}
|
||||
"""
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample_grad",
|
||||
input_names=["x", "grid", "cotangent"],
|
||||
output_names=["x_grad", "grid_grad"],
|
||||
source=source,
|
||||
atomic_outputs=True,
|
||||
)
|
||||
|
||||
// Reduce across each simdgroup first.
|
||||
// This is much faster than relying purely on atomics.
|
||||
gix = simd_sum(gix);
|
||||
giy = simd_sum(giy);
|
||||
@grid_sample.vjp
|
||||
def grid_sample_vjp(primals, cotangent, _):
|
||||
x, grid = primals
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
|
||||
if (thread_index_in_simdgroup == 0) {
|
||||
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
|
||||
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
|
||||
}
|
||||
"""
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample_grad",
|
||||
input_names=["x", "grid", "cotangent"],
|
||||
output_names=["x_grad", "grid_grad"],
|
||||
source=source,
|
||||
atomic_outputs=True,
|
||||
)
|
||||
# pad the output channels to simd group size
|
||||
# so that our `simd_sum`s don't overlap.
|
||||
simdgroup_size = 32
|
||||
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
|
||||
grid_size = B * gN * gM * C_padded
|
||||
outputs = kernel(
|
||||
inputs=[x, grid, cotangent],
|
||||
template=[("T", x.dtype)],
|
||||
output_shapes=[x.shape, grid.shape],
|
||||
output_dtypes=[x.dtype, x.dtype],
|
||||
grid=(grid_size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
init_value=0,
|
||||
)
|
||||
return outputs[0], outputs[1]
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
|
||||
# pad the output channels to simd group size
|
||||
# so that our `simd_sum`s don't overlap.
|
||||
simdgroup_size = 32
|
||||
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
|
||||
grid_size = B * gN * gM * C_padded
|
||||
outputs = kernel(
|
||||
inputs=[x, grid, cotangent],
|
||||
template=[("T", x.dtype)],
|
||||
output_shapes=[x.shape, grid.shape],
|
||||
output_dtypes=[x.dtype, x.dtype],
|
||||
grid=(grid_size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
init_value=0,
|
||||
)
|
||||
return outputs[0], outputs[1]
|
||||
|
||||
There's an even larger speed up for the vjp:
|
||||
|
||||
|
@@ -397,11 +397,11 @@ below.
|
||||
std::ostringstream kname;
|
||||
kname << "axpby_" << "general_" << type_to_name(out);
|
||||
|
||||
// Make sure the metal library is available
|
||||
d.register_library("mlx_ext");
|
||||
// Load the metal library
|
||||
auto lib = d.get_library("mlx_ext");
|
||||
|
||||
// Make a kernel from this metal library
|
||||
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||
auto kernel = d.get_kernel(kname.str(), lib);
|
||||
|
||||
// Prepare to encode kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
@@ -19,6 +19,8 @@ Array
|
||||
array.ndim
|
||||
array.shape
|
||||
array.size
|
||||
array.real
|
||||
array.imag
|
||||
array.abs
|
||||
array.all
|
||||
array.any
|
||||
|
@@ -20,3 +20,5 @@ FFT
|
||||
irfft2
|
||||
rfftn
|
||||
irfftn
|
||||
fftshift
|
||||
ifftshift
|
||||
|
@@ -16,6 +16,8 @@ Linear Algebra
|
||||
cross
|
||||
qr
|
||||
svd
|
||||
eigvals
|
||||
eig
|
||||
eigvalsh
|
||||
eigh
|
||||
lu
|
||||
|
@@ -172,11 +172,11 @@ void Axpby::eval_gpu(
|
||||
kname << (contiguous_kernel ? "contiguous_" : "general_");
|
||||
kname << type_to_name(out);
|
||||
|
||||
// Make sure the metal library is available
|
||||
d.register_library("mlx_ext");
|
||||
// Load the metal library
|
||||
auto lib = d.get_library("mlx_ext");
|
||||
|
||||
// Make a kernel from this metal library
|
||||
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||
auto kernel = d.get_kernel(kname.str(), lib);
|
||||
|
||||
// Prepare to encode kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
@@ -5,6 +5,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
||||
@@ -20,7 +21,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
|
||||
|
||||
# Define MLX_VERSION only in the version.cpp file.
|
||||
add_library(mlx_version STATIC ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
|
||||
add_library(mlx_version OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
|
||||
target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
|
||||
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)
|
||||
|
||||
@@ -48,5 +49,19 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
|
||||
if(MLX_BUILD_METAL)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
|
||||
else()
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
|
||||
target_sources(mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_CUDA)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
|
||||
else()
|
||||
target_sources(mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
|
||||
else()
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
|
||||
endif()
|
||||
|
12
mlx/array.h
12
mlx/array.h
@@ -224,6 +224,10 @@ class array {
|
||||
// Not copyable
|
||||
Data(const Data& d) = delete;
|
||||
Data& operator=(const Data& d) = delete;
|
||||
Data(Data&& o) : buffer(o.buffer), d(o.d) {
|
||||
o.buffer = allocator::Buffer(nullptr);
|
||||
o.d = [](allocator::Buffer) {};
|
||||
}
|
||||
~Data() {
|
||||
d(buffer);
|
||||
}
|
||||
@@ -339,11 +343,11 @@ class array {
|
||||
return allocator::allocator().size(buffer());
|
||||
}
|
||||
|
||||
// Return a copy of the shared pointer
|
||||
// to the array::Data struct
|
||||
std::shared_ptr<Data> data_shared_ptr() const {
|
||||
// Return the shared pointer to the array::Data struct
|
||||
const std::shared_ptr<Data>& data_shared_ptr() const {
|
||||
return array_desc_->data;
|
||||
}
|
||||
|
||||
// Return a raw pointer to the arrays data
|
||||
template <typename T>
|
||||
T* data() {
|
||||
@@ -356,7 +360,7 @@ class array {
|
||||
}
|
||||
|
||||
enum Status {
|
||||
// The ouptut of a computation which has not been scheduled.
|
||||
// The output of a computation which has not been scheduled.
|
||||
// For example, the status of `x` in `auto x = a + b`.
|
||||
unscheduled,
|
||||
|
||||
|
157
mlx/backend/common/buffer_cache.h
Normal file
157
mlx/backend/common/buffer_cache.h
Normal file
@@ -0,0 +1,157 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename T>
|
||||
class BufferCache {
|
||||
public:
|
||||
BufferCache(
|
||||
size_t page_size,
|
||||
std::function<size_t(T*)> get_size,
|
||||
std::function<void(T*)> free)
|
||||
: page_size_(page_size),
|
||||
get_size_(std::move(get_size)),
|
||||
free_(std::move(free)) {}
|
||||
|
||||
~BufferCache() {
|
||||
clear();
|
||||
}
|
||||
|
||||
BufferCache(const BufferCache&) = delete;
|
||||
BufferCache& operator=(const BufferCache&) = delete;
|
||||
|
||||
T* reuse_from_cache(size_t size) {
|
||||
// Find the closest buffer in pool.
|
||||
auto it = buffer_pool_.lower_bound(size);
|
||||
if (it == buffer_pool_.end() ||
|
||||
it->first >= std::min(2 * size, size + 2 * page_size_)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Collect from the cache.
|
||||
T* buf = it->second->buf;
|
||||
pool_size_ -= it->first;
|
||||
|
||||
// Remove from record.
|
||||
remove_from_list(it->second);
|
||||
buffer_pool_.erase(it);
|
||||
return buf;
|
||||
}
|
||||
|
||||
void recycle_to_cache(T* buf) {
|
||||
assert(buf);
|
||||
// Add to cache.
|
||||
BufferHolder* bh = new BufferHolder(buf);
|
||||
add_at_head(bh);
|
||||
size_t size = get_size_(buf);
|
||||
pool_size_ += size;
|
||||
buffer_pool_.emplace(size, bh);
|
||||
}
|
||||
|
||||
int release_cached_buffers(size_t min_bytes_to_free) {
|
||||
if (min_bytes_to_free >= 0.9 * pool_size_) {
|
||||
return clear();
|
||||
} else {
|
||||
int n_release = 0;
|
||||
size_t total_bytes_freed = 0;
|
||||
|
||||
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
|
||||
// Release buffer.
|
||||
size_t size = get_size_(tail_->buf);
|
||||
total_bytes_freed += size;
|
||||
free_(tail_->buf);
|
||||
n_release++;
|
||||
|
||||
// Remove from record.
|
||||
auto its = buffer_pool_.equal_range(size);
|
||||
auto it = std::find_if(its.first, its.second, [this](const auto& el) {
|
||||
return el.second == tail_;
|
||||
});
|
||||
assert(it != buffer_pool_.end());
|
||||
buffer_pool_.erase(it);
|
||||
remove_from_list(tail_);
|
||||
}
|
||||
|
||||
pool_size_ -= total_bytes_freed;
|
||||
return n_release;
|
||||
}
|
||||
}
|
||||
|
||||
int clear() {
|
||||
int n_release = 0;
|
||||
for (auto& [size, holder] : buffer_pool_) {
|
||||
free_(holder->buf);
|
||||
n_release++;
|
||||
delete holder;
|
||||
}
|
||||
buffer_pool_.clear();
|
||||
pool_size_ = 0;
|
||||
head_ = nullptr;
|
||||
tail_ = nullptr;
|
||||
return n_release;
|
||||
}
|
||||
|
||||
size_t cache_size() const {
|
||||
return pool_size_;
|
||||
}
|
||||
|
||||
size_t page_size() const {
|
||||
return page_size_;
|
||||
}
|
||||
|
||||
private:
|
||||
struct BufferHolder {
|
||||
public:
|
||||
explicit BufferHolder(T* buf_) : buf(buf_) {}
|
||||
|
||||
BufferHolder* prev{nullptr};
|
||||
BufferHolder* next{nullptr};
|
||||
T* buf;
|
||||
};
|
||||
|
||||
void add_at_head(BufferHolder* to_add) {
|
||||
if (!head_) {
|
||||
head_ = to_add;
|
||||
tail_ = to_add;
|
||||
} else {
|
||||
head_->prev = to_add;
|
||||
to_add->next = head_;
|
||||
head_ = to_add;
|
||||
}
|
||||
}
|
||||
|
||||
void remove_from_list(BufferHolder* to_remove) {
|
||||
if (to_remove->prev && to_remove->next) { // if middle
|
||||
to_remove->prev->next = to_remove->next;
|
||||
to_remove->next->prev = to_remove->prev;
|
||||
} else if (to_remove->prev && to_remove == tail_) { // if tail
|
||||
tail_ = to_remove->prev;
|
||||
tail_->next = nullptr;
|
||||
} else if (to_remove == head_ && to_remove->next) { // if head
|
||||
head_ = to_remove->next;
|
||||
head_->prev = nullptr;
|
||||
} else if (to_remove == head_ && to_remove == tail_) { // if only element
|
||||
head_ = nullptr;
|
||||
tail_ = nullptr;
|
||||
}
|
||||
|
||||
delete to_remove;
|
||||
}
|
||||
|
||||
std::multimap<size_t, BufferHolder*> buffer_pool_;
|
||||
BufferHolder* head_{nullptr};
|
||||
BufferHolder* tail_{nullptr};
|
||||
size_t pool_size_{0};
|
||||
|
||||
const size_t page_size_;
|
||||
std::function<size_t(T*)> get_size_;
|
||||
std::function<void(T*)> free_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,8 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -79,55 +78,6 @@ std::string get_type_string(Dtype d) {
|
||||
}
|
||||
}
|
||||
|
||||
std::string build_lib_name(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::unordered_set<uintptr_t>& constant_ids) {
|
||||
NodeNamer namer;
|
||||
std::ostringstream os;
|
||||
std::ostringstream constant_hasher;
|
||||
|
||||
// Fill the input names. This is not really necessary, I just like having A,
|
||||
// B, C, ... as the inputs.
|
||||
for (auto& x : inputs) {
|
||||
namer.get_name(x);
|
||||
}
|
||||
|
||||
// The primitives describing the tape. For unary and binary primitives this
|
||||
// must be enough to describe the full computation.
|
||||
for (auto& a : tape) {
|
||||
// name and type of output
|
||||
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
|
||||
// computation performed
|
||||
a.primitive().print(os);
|
||||
// name of inputs to the function
|
||||
for (auto& inp : a.inputs()) {
|
||||
os << namer.get_name(inp);
|
||||
}
|
||||
}
|
||||
os << "_";
|
||||
|
||||
for (auto& x : inputs) {
|
||||
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
||||
os << "C";
|
||||
print_constant(constant_hasher, x);
|
||||
} else {
|
||||
os << (is_scalar(x) ? "S" : "V");
|
||||
}
|
||||
}
|
||||
os << "_";
|
||||
for (auto& x : inputs) {
|
||||
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
||||
continue;
|
||||
}
|
||||
os << kindof(x.dtype()) << x.itemsize();
|
||||
}
|
||||
os << "_" << std::hash<std::string>{}(constant_hasher.str());
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
bool compiled_check_contiguity(
|
||||
const std::vector<array>& inputs,
|
||||
const Shape& shape) {
|
||||
@@ -159,8 +109,7 @@ bool compiled_check_contiguity(
|
||||
void compiled_allocate_outputs(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::vector<array>& inputs_,
|
||||
const std::unordered_set<uintptr_t>& constant_ids_,
|
||||
const std::function<bool(size_t)>& is_constant,
|
||||
bool contiguous) {
|
||||
if (contiguous) {
|
||||
int o = 0;
|
||||
@@ -175,8 +124,7 @@ void compiled_allocate_outputs(
|
||||
// - Donatable
|
||||
// - Not a constant
|
||||
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
|
||||
in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
in.is_donatable() && is_constant(i)) {
|
||||
outputs[o++].copy_shared_buffer(in);
|
||||
}
|
||||
// Get representative input flags to properly set non-donated outputs
|
||||
@@ -204,7 +152,7 @@ void compiled_allocate_outputs(
|
||||
// - Not a constant
|
||||
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
|
||||
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
is_constant(i)) {
|
||||
outputs[o].copy_shared_buffer(
|
||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||
o++;
|
||||
@@ -216,4 +164,74 @@ void compiled_allocate_outputs(
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
|
||||
const std::vector<array>& inputs,
|
||||
const array& out,
|
||||
const std::function<bool(size_t)>& is_constant) {
|
||||
const Shape& shape = out.shape();
|
||||
bool contiguous = compiled_check_contiguity(inputs, shape);
|
||||
if (contiguous) {
|
||||
return {true, shape, {}};
|
||||
}
|
||||
|
||||
std::vector<Strides> strides_vec{out.strides()};
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
// Skip constants.
|
||||
if (is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip scalar inputs.
|
||||
const auto& x = inputs[i];
|
||||
if (is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Broadcast the inputs to the output shape.
|
||||
Strides xstrides;
|
||||
size_t j = 0;
|
||||
for (; j < shape.size() - x.ndim(); ++j) {
|
||||
if (shape[j] == 1) {
|
||||
xstrides.push_back(out.strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < x.ndim(); ++i, ++j) {
|
||||
if (x.shape(i) == 1) {
|
||||
if (shape[j] == 1) {
|
||||
xstrides.push_back(out.strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
} else {
|
||||
xstrides.push_back(x.strides()[i]);
|
||||
}
|
||||
}
|
||||
strides_vec.push_back(std::move(xstrides));
|
||||
}
|
||||
|
||||
auto tup = collapse_contiguous_dims(shape, strides_vec, INT32_MAX);
|
||||
return {false, std::move(std::get<0>(tup)), std::move(std::get<1>(tup))};
|
||||
}
|
||||
|
||||
bool compiled_use_large_index(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
bool contiguous) {
|
||||
if (contiguous) {
|
||||
size_t max_size = 0;
|
||||
for (const auto& in : inputs) {
|
||||
max_size = std::max(max_size, in.data_size());
|
||||
}
|
||||
return max_size > UINT32_MAX;
|
||||
} else {
|
||||
size_t max_size = 0;
|
||||
for (const auto& o : outputs) {
|
||||
max_size = std::max(max_size, o.size());
|
||||
}
|
||||
return max_size > UINT32_MAX;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,9 +1,8 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <iomanip>
|
||||
#include <sstream>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/primitives.h"
|
||||
@@ -14,12 +13,6 @@ inline bool is_static_cast(const Primitive& p) {
|
||||
return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
|
||||
}
|
||||
|
||||
std::string build_lib_name(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::unordered_set<uintptr_t>& constant_ids);
|
||||
|
||||
std::string get_type_string(Dtype d);
|
||||
|
||||
template <typename T>
|
||||
@@ -60,8 +53,19 @@ bool compiled_check_contiguity(
|
||||
void compiled_allocate_outputs(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::vector<array>& inputs_,
|
||||
const std::unordered_set<uintptr_t>& constant_ids_,
|
||||
const std::function<bool(size_t)>& is_constant,
|
||||
bool contiguous);
|
||||
|
||||
// Collapse contiguous dims ignoring scalars and constants.
|
||||
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
|
||||
const std::vector<array>& inputs,
|
||||
const array& out,
|
||||
const std::function<bool(size_t)>& is_constant);
|
||||
|
||||
// Return whether the kernel should use large index.
|
||||
bool compiled_use_large_index(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
bool contiguous);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -2,7 +2,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -26,7 +26,7 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
|
||||
if (ctype == CopyType::Vector) {
|
||||
// If the input is donateable, we are doing a vector copy and the types
|
||||
// have the same size, then the input buffer can hold the output.
|
||||
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
|
||||
if (is_donatable(in, out)) {
|
||||
out.copy_shared_buffer(in);
|
||||
return true;
|
||||
} else {
|
||||
|
@@ -99,7 +99,11 @@ inline std::pair<int, int> decompose_hadamard(int n) {
|
||||
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
|
||||
}
|
||||
}
|
||||
if (n > (1 << 26)) {
|
||||
throw std::invalid_argument(
|
||||
"[hadamard] Only supports n = m*2^k where k <= 26");
|
||||
}
|
||||
return {n, m};
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace mlx::core
|
||||
|
78
mlx/backend/common/matmul.h
Normal file
78
mlx/backend/common/matmul.h
Normal file
@@ -0,0 +1,78 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
inline std::tuple<Shape, Strides, Strides> collapse_batches(
|
||||
const array& a,
|
||||
const array& b) {
|
||||
// Get and check the shape for the batched dims
|
||||
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
||||
Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
|
||||
if (A_bshape != B_bshape) {
|
||||
std::ostringstream msg;
|
||||
msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A "
|
||||
<< a.shape() << ", B " << b.shape() << ".";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
|
||||
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
|
||||
|
||||
auto [batch_shape, batch_strides] =
|
||||
collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
|
||||
|
||||
auto a_batch_strides = batch_strides[0];
|
||||
auto b_batch_strides = batch_strides[1];
|
||||
|
||||
if (batch_shape.empty()) {
|
||||
batch_shape.push_back(1);
|
||||
a_batch_strides.push_back(0);
|
||||
b_batch_strides.push_back(0);
|
||||
}
|
||||
|
||||
return std::make_tuple(batch_shape, a_batch_strides, b_batch_strides);
|
||||
}
|
||||
|
||||
inline std::tuple<Shape, Strides, Strides, Strides>
|
||||
collapse_batches(const array& a, const array& b, const array& c) {
|
||||
// Get and check the shape for the batched dims
|
||||
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
||||
Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
|
||||
Shape C_bshape{c.shape().begin(), c.shape().end() - 2};
|
||||
if (A_bshape != B_bshape || A_bshape != C_bshape) {
|
||||
std::ostringstream msg;
|
||||
msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A "
|
||||
<< a.shape() << ", B " << b.shape() << ", B " << c.shape() << ".";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
|
||||
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
|
||||
Strides C_bstride{c.strides().begin(), c.strides().end() - 2};
|
||||
|
||||
auto [batch_shape, batch_strides] = collapse_contiguous_dims(
|
||||
A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
|
||||
|
||||
auto A_batch_stride = batch_strides[0];
|
||||
auto B_batch_stride = batch_strides[1];
|
||||
auto C_batch_stride = batch_strides[2];
|
||||
|
||||
if (batch_shape.empty()) {
|
||||
batch_shape.push_back(1);
|
||||
A_batch_stride.push_back(0);
|
||||
B_batch_stride.push_back(0);
|
||||
C_batch_stride.push_back(0);
|
||||
}
|
||||
|
||||
return std::make_tuple(
|
||||
batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
26
mlx/backend/common/unary.h
Normal file
26
mlx/backend/common/unary.h
Normal file
@@ -0,0 +1,26 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
inline void set_unary_output_data(const array& in, array& out) {
|
||||
if (in.flags().contiguous) {
|
||||
if (is_donatable(in, out)) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,9 +1,16 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
std::string get_primitive_string(Primitive* primitive) {
|
||||
std::ostringstream op_t;
|
||||
primitive->print(op_t);
|
||||
return op_t.str();
|
||||
}
|
||||
|
||||
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
|
||||
const Shape& shape,
|
||||
const std::vector<Strides>& strides,
|
||||
@@ -101,4 +108,105 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
|
||||
return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
|
||||
}
|
||||
|
||||
Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
|
||||
int pows[3] = {0, 0, 0};
|
||||
int sum = 0;
|
||||
while (true) {
|
||||
int presum = sum;
|
||||
// Check all the pows
|
||||
if (dim0 >= (1 << (pows[0] + 1))) {
|
||||
pows[0]++;
|
||||
sum++;
|
||||
}
|
||||
if (sum == 10) {
|
||||
break;
|
||||
}
|
||||
if (dim1 >= (1 << (pows[1] + 1))) {
|
||||
pows[1]++;
|
||||
sum++;
|
||||
}
|
||||
if (sum == 10) {
|
||||
break;
|
||||
}
|
||||
if (dim2 >= (1 << (pows[2] + 1))) {
|
||||
pows[2]++;
|
||||
sum++;
|
||||
}
|
||||
if (sum == presum || sum == pow2) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return std::make_tuple(1ul << pows[0], 1ul << pows[1], 1ul << pows[2]);
|
||||
}
|
||||
|
||||
Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides) {
|
||||
// Dims with strides of 0 are ignored as they
|
||||
// correspond to broadcasted dimensions
|
||||
size_t grid_x = 1;
|
||||
size_t grid_y = 1;
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
if (strides[i] == 0) {
|
||||
continue;
|
||||
}
|
||||
if (grid_x * shape[i] < UINT32_MAX) {
|
||||
grid_x *= shape[i];
|
||||
} else {
|
||||
grid_y *= shape[i];
|
||||
}
|
||||
}
|
||||
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
|
||||
throw std::runtime_error("Unable to safely factor shape.");
|
||||
}
|
||||
if (grid_y > grid_x) {
|
||||
std::swap(grid_x, grid_y);
|
||||
}
|
||||
return std::make_tuple(
|
||||
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
|
||||
}
|
||||
|
||||
Dims get_2d_grid_dims_common(
|
||||
const Shape& shape,
|
||||
const Strides& strides,
|
||||
size_t divisor) {
|
||||
// Compute the 2d grid dimensions such that the total size of the grid is
|
||||
// divided by divisor.
|
||||
size_t grid_x = 1;
|
||||
size_t grid_y = 1;
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
if (strides[i] == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// No need to add this shape we can just remove it from the divisor.
|
||||
if (divisor % shape[i] == 0) {
|
||||
divisor /= shape[i];
|
||||
continue;
|
||||
}
|
||||
|
||||
if (grid_x * shape[i] < UINT32_MAX) {
|
||||
grid_x *= shape[i];
|
||||
} else {
|
||||
grid_y *= shape[i];
|
||||
}
|
||||
|
||||
if (divisor > 1) {
|
||||
if (grid_x % divisor == 0) {
|
||||
grid_x /= divisor;
|
||||
divisor = 1;
|
||||
} else if (grid_y % divisor == 0) {
|
||||
grid_y /= divisor;
|
||||
divisor = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) {
|
||||
throw std::runtime_error("Unable to safely factor shape.");
|
||||
}
|
||||
if (grid_y > grid_x) {
|
||||
std::swap(grid_x, grid_y);
|
||||
}
|
||||
return std::make_tuple(
|
||||
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -2,12 +2,15 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
std::string get_primitive_string(Primitive* primitive);
|
||||
|
||||
inline int64_t
|
||||
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
|
||||
int64_t loc = 0;
|
||||
@@ -70,6 +73,28 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
|
||||
const array& a,
|
||||
int64_t size_cap = std::numeric_limits<int32_t>::max());
|
||||
|
||||
// Compute the thread block dimensions which fit the given
|
||||
// input dimensions.
|
||||
// - The thread block dimensions will be powers of two
|
||||
// - The thread block size will be less than 2^pow2
|
||||
using Dims = std::tuple<uint32_t, uint32_t, uint32_t>;
|
||||
Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 = 10);
|
||||
|
||||
// Computes a 2D grid where each element is < UINT_MAX
|
||||
// Assumes:
|
||||
// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
|
||||
// - shape and strides correspond to a contiguous (no holes) but
|
||||
// possibly broadcasted array
|
||||
Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides);
|
||||
|
||||
// Same as above but we do an implicit division with divisor.
|
||||
// Basically, equivalent to factorizing
|
||||
// Prod(s \forall s in shape if strides[s] > 0) / divisor.
|
||||
Dims get_2d_grid_dims_common(
|
||||
const Shape& shape,
|
||||
const Strides& strides,
|
||||
size_t divisor);
|
||||
|
||||
struct ContiguousIterator {
|
||||
inline void step() {
|
||||
int dims = shape_.size();
|
||||
@@ -165,4 +190,11 @@ void shared_buffer_reshape(
|
||||
const array& in,
|
||||
const Strides& out_strides,
|
||||
array& out);
|
||||
|
||||
template <typename T>
|
||||
inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
|
||||
vec.erase(std::next(vec.begin(), index));
|
||||
return vec;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -40,11 +40,13 @@ add_dependencies(mlx cpu_compiled_preamble)
|
||||
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eig.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
|
@@ -14,10 +14,8 @@ template <typename InT, typename OpT>
|
||||
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
|
||||
auto axis_size = in.shape()[axis];
|
||||
auto axis_stride = in.strides()[axis];
|
||||
Strides strides = in.strides();
|
||||
Shape shape = in.shape();
|
||||
strides.erase(strides.begin() + axis);
|
||||
shape.erase(shape.begin() + axis);
|
||||
Strides strides = remove_index(in.strides(), axis);
|
||||
Shape shape = remove_index(in.shape(), axis);
|
||||
auto in_ptr = in.data<InT>();
|
||||
auto out_ptr = out.data<uint32_t>();
|
||||
|
||||
|
11
mlx/backend/cpu/available.cpp
Normal file
11
mlx/backend/cpu/available.cpp
Normal file
@@ -0,0 +1,11 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cpu/available.h"
|
||||
|
||||
namespace mlx::core::cpu {
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cpu
|
9
mlx/backend/cpu/available.h
Normal file
9
mlx/backend/cpu/available.h
Normal file
@@ -0,0 +1,9 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core::cpu {
|
||||
|
||||
bool is_available();
|
||||
|
||||
} // namespace mlx::core::cpu
|
@@ -172,9 +172,12 @@ void binary_float(
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[binary_float] Only supports non-complex floating point types.");
|
||||
"[binary_float] Only supports floating point types.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@@ -40,7 +40,10 @@ struct CompilerCache {
|
||||
std::shared_mutex mtx;
|
||||
};
|
||||
|
||||
static CompilerCache cache{};
|
||||
static CompilerCache& cache() {
|
||||
static CompilerCache cache_;
|
||||
return cache_;
|
||||
};
|
||||
|
||||
// GPU compile is always available if the GPU is available and since we are in
|
||||
// this file CPU compile is also available.
|
||||
@@ -56,14 +59,16 @@ void* compile(
|
||||
const std::string& kernel_name,
|
||||
const std::function<std::string(void)>& source_builder) {
|
||||
{
|
||||
std::shared_lock lock(cache.mtx);
|
||||
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
|
||||
std::shared_lock lock(cache().mtx);
|
||||
if (auto it = cache().kernels.find(kernel_name);
|
||||
it != cache().kernels.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_lock lock(cache.mtx);
|
||||
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
|
||||
std::unique_lock lock(cache().mtx);
|
||||
if (auto it = cache().kernels.find(kernel_name);
|
||||
it != cache().kernels.end()) {
|
||||
return it->second;
|
||||
}
|
||||
std::string source_code = source_builder();
|
||||
@@ -120,10 +125,10 @@ void* compile(
|
||||
}
|
||||
|
||||
// load library
|
||||
cache.libs.emplace_back(shared_lib_path);
|
||||
cache().libs.emplace_back(shared_lib_path);
|
||||
|
||||
// Load function
|
||||
void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
|
||||
void* fun = dlsym(cache().libs.back().lib, kernel_name.c_str());
|
||||
if (!fun) {
|
||||
std::ostringstream msg;
|
||||
msg << "[Compile::eval_cpu] Failed to load compiled function "
|
||||
@@ -131,7 +136,7 @@ void* compile(
|
||||
<< dlerror();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
cache.kernels.insert({kernel_name, fun});
|
||||
cache().kernels.insert({kernel_name, fun});
|
||||
return fun;
|
||||
}
|
||||
|
||||
@@ -141,18 +146,9 @@ inline void build_kernel(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::unordered_set<uintptr_t>& constant_ids,
|
||||
const std::function<bool(size_t)>& is_constant,
|
||||
bool contiguous,
|
||||
int ndim) {
|
||||
// All outputs should have the exact same shape and will be row contiguous
|
||||
auto output_shape = outputs[0].shape();
|
||||
auto output_strides = outputs[0].strides();
|
||||
|
||||
// Constants are scalars that are captured by value and cannot change
|
||||
auto is_constant = [&constant_ids](const array& x) {
|
||||
return constant_ids.find(x.id()) != constant_ids.end();
|
||||
};
|
||||
|
||||
NodeNamer namer;
|
||||
|
||||
#ifdef _MSC_VER
|
||||
@@ -165,14 +161,15 @@ inline void build_kernel(
|
||||
|
||||
// Add the input arguments
|
||||
int cnt = 0;
|
||||
for (auto& x : inputs) {
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
// Skip constants from the input list
|
||||
if (is_constant(x)) {
|
||||
if (is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto& x = inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
auto tstr = get_type_string(x.dtype());
|
||||
os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
|
||||
<< "];" << std::endl;
|
||||
@@ -206,10 +203,11 @@ inline void build_kernel(
|
||||
}
|
||||
|
||||
// Read the inputs in tmps
|
||||
for (auto& x : inputs) {
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto& x = inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
if (is_constant(x)) {
|
||||
if (is_constant(i)) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
|
||||
print_constant(os, x);
|
||||
os << ";" << std::endl;
|
||||
@@ -259,8 +257,9 @@ inline void build_kernel(
|
||||
} else {
|
||||
for (int d = ndim - 1; d >= 0; --d) {
|
||||
// Update pointers
|
||||
for (auto& x : inputs) {
|
||||
if (is_constant(x) || is_scalar(x)) {
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto& x = inputs[i];
|
||||
if (is_constant(i) || is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
auto& xname = namer.get_name(x);
|
||||
@@ -282,65 +281,37 @@ inline void build_kernel(
|
||||
void Compiled::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
if (kernel_lib_.empty()) {
|
||||
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
|
||||
}
|
||||
|
||||
// Figure out which kernel we are using
|
||||
auto& shape = outputs[0].shape();
|
||||
auto contiguous = compiled_check_contiguity(inputs, shape);
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
|
||||
// Handle all broadcasting and collect function input arguments
|
||||
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
||||
// handle all broadcasting.
|
||||
auto [contiguous, shape, strides] =
|
||||
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
|
||||
|
||||
// Collect function input arguments.
|
||||
std::vector<void*> args;
|
||||
std::vector<std::vector<size_t>> strides;
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
// Skip constants.
|
||||
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
|
||||
int strides_index = 1;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
if (is_constant_(i)) {
|
||||
continue;
|
||||
}
|
||||
auto& x = inputs[i];
|
||||
const auto& x = inputs[i];
|
||||
encoder.set_input_array(x);
|
||||
args.push_back((void*)x.data<void>());
|
||||
|
||||
if (contiguous || is_scalar(x)) {
|
||||
continue;
|
||||
if (!contiguous && !is_scalar(x)) {
|
||||
args.push_back(strides[strides_index++].data());
|
||||
}
|
||||
|
||||
// Broadcast the input to the output shape.
|
||||
std::vector<size_t> xstrides;
|
||||
int j = 0;
|
||||
for (; j < shape.size() - x.ndim(); j++) {
|
||||
if (shape[j] == 1) {
|
||||
xstrides.push_back(outputs[0].strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < x.ndim(); i++, j++) {
|
||||
if (x.shape(i) == 1) {
|
||||
if (shape[j] == 1) {
|
||||
xstrides.push_back(outputs[0].strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
} else {
|
||||
xstrides.push_back(x.strides()[i]);
|
||||
}
|
||||
}
|
||||
strides.push_back(std::move(xstrides));
|
||||
args.push_back(strides.back().data());
|
||||
}
|
||||
|
||||
// Get the kernel name from the lib
|
||||
int ndim = shape.size();
|
||||
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
|
||||
if (!contiguous) {
|
||||
kernel_name += std::to_string(shape.size());
|
||||
kernel_name += std::to_string(ndim);
|
||||
}
|
||||
|
||||
// Get the function
|
||||
auto fn_ptr = compile(kernel_name, [&]() {
|
||||
auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() {
|
||||
std::ostringstream kernel;
|
||||
kernel << get_kernel_preamble() << std::endl;
|
||||
kernel << "extern \"C\" {" << std::endl;
|
||||
@@ -350,7 +321,7 @@ void Compiled::eval_cpu(
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
is_constant_,
|
||||
contiguous,
|
||||
ndim);
|
||||
// Close extern "C"
|
||||
@@ -358,26 +329,22 @@ void Compiled::eval_cpu(
|
||||
return kernel.str();
|
||||
});
|
||||
|
||||
compiled_allocate_outputs(
|
||||
inputs, outputs, inputs_, constant_ids_, contiguous);
|
||||
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
|
||||
|
||||
for (auto& x : outputs) {
|
||||
args.push_back(x.data<void>());
|
||||
encoder.set_output_array(x);
|
||||
}
|
||||
Shape out_shape;
|
||||
if (!contiguous) {
|
||||
out_shape = outputs[0].shape();
|
||||
args.push_back((void*)out_shape.data());
|
||||
args.push_back((void*)shape.data());
|
||||
} else {
|
||||
args.push_back((void*)outputs[0].data_size());
|
||||
}
|
||||
auto fun = (void (*)(void**))fn_ptr;
|
||||
encoder.dispatch(
|
||||
[fun,
|
||||
args = std::move(args),
|
||||
strides = std::move(strides),
|
||||
out_shape = std::move(out_shape)]() mutable { fun(args.data()); });
|
||||
encoder.dispatch([fun,
|
||||
args = std::move(args),
|
||||
strides = std::move(strides),
|
||||
shape = std::move(shape)]() mutable { fun(args.data()); });
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -22,7 +22,8 @@ void slow_conv_1D(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
@@ -60,7 +61,8 @@ void slow_conv_1D(
|
||||
out_stride_O = out.strides()[2],
|
||||
|
||||
flip,
|
||||
padding = padding[0],
|
||||
padding_lo = padding_lo[0],
|
||||
padding_hi = padding_hi[0],
|
||||
wt_stride = wt_strides[0],
|
||||
wt_dilation = wt_dilation[0],
|
||||
in_dilation = in_dilation[0]]() mutable {
|
||||
@@ -77,7 +79,7 @@ void slow_conv_1D(
|
||||
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
|
||||
|
||||
int wh_flip = flip ? (wH - wh - 1) : wh;
|
||||
int ih = oh * wt_stride - padding + wh_flip * wt_dilation;
|
||||
int ih = oh * wt_stride - padding_lo + wh_flip * wt_dilation;
|
||||
|
||||
auto ih_div = std::div(ih, in_dilation);
|
||||
|
||||
@@ -109,7 +111,8 @@ void slow_conv_2D(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
@@ -120,230 +123,235 @@ void slow_conv_2D(
|
||||
encoder.set_input_array(wt);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
encoder.dispatch([st_wt_ptr = wt.data<T>(),
|
||||
st_in_ptr = in.data<T>(),
|
||||
st_out_ptr = out.data<T>(),
|
||||
encoder.dispatch(
|
||||
[st_wt_ptr = wt.data<T>(),
|
||||
st_in_ptr = in.data<T>(),
|
||||
st_out_ptr = out.data<T>(),
|
||||
|
||||
N = in.shape(
|
||||
0), // Batch size, should be the same as out.shape(0)
|
||||
iH = 1 +
|
||||
in_dilation[0] * (in.shape(1) - 1), // Input spatial dim
|
||||
iW = 1 +
|
||||
in_dilation[1] * (in.shape(2) - 1), // Input spatial dim
|
||||
C = in.shape(3), // In channels
|
||||
oH = out.shape(1), // Output spatial dim
|
||||
oW = out.shape(2), // Output spatial dim
|
||||
O = wt.shape(0), // Out channels
|
||||
wH = wt.shape(1), // Weight spatial dim
|
||||
wW = wt.shape(2), // Weight spatial dim
|
||||
N = in.shape(0), // Batch size, should be the same as out.shape(0)
|
||||
iH = 1 + in_dilation[0] * (in.shape(1) - 1), // Input spatial dim
|
||||
iW = 1 + in_dilation[1] * (in.shape(2) - 1), // Input spatial dim
|
||||
C = in.shape(3), // In channels
|
||||
oH = out.shape(1), // Output spatial dim
|
||||
oW = out.shape(2), // Output spatial dim
|
||||
O = wt.shape(0), // Out channels
|
||||
wH = wt.shape(1), // Weight spatial dim
|
||||
wW = wt.shape(2), // Weight spatial dim
|
||||
|
||||
groups = in.shape(3) / wt.shape(3),
|
||||
C_per_group = wt.shape(3),
|
||||
groups = in.shape(3) / wt.shape(3),
|
||||
C_per_group = wt.shape(3),
|
||||
|
||||
in_stride_N = in.strides()[0],
|
||||
in_stride_H = in.strides()[1],
|
||||
in_stride_W = in.strides()[2],
|
||||
in_stride_C = in.strides()[3],
|
||||
in_stride_N = in.strides()[0],
|
||||
in_stride_H = in.strides()[1],
|
||||
in_stride_W = in.strides()[2],
|
||||
in_stride_C = in.strides()[3],
|
||||
|
||||
wt_stride_O = wt.strides()[0],
|
||||
wt_stride_H = wt.strides()[1],
|
||||
wt_stride_W = wt.strides()[2],
|
||||
wt_stride_C = wt.strides()[3],
|
||||
wt_stride_O = wt.strides()[0],
|
||||
wt_stride_H = wt.strides()[1],
|
||||
wt_stride_W = wt.strides()[2],
|
||||
wt_stride_C = wt.strides()[3],
|
||||
|
||||
out_stride_N = out.strides()[0],
|
||||
out_stride_H = out.strides()[1],
|
||||
out_stride_W = out.strides()[2],
|
||||
out_stride_O = out.strides()[3],
|
||||
out_stride_N = out.strides()[0],
|
||||
out_stride_H = out.strides()[1],
|
||||
out_stride_W = out.strides()[2],
|
||||
out_stride_O = out.strides()[3],
|
||||
|
||||
padding,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
flip]() mutable {
|
||||
bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
flip]() mutable {
|
||||
bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;
|
||||
|
||||
const int O_per_group = O / groups;
|
||||
auto pt_conv_no_checks = [&](const T* in_ptr,
|
||||
const T* wt_ptr,
|
||||
T* out_ptr,
|
||||
int oh,
|
||||
int ow) {
|
||||
out_ptr += oh * out_stride_H + ow * out_stride_W;
|
||||
int ih_base = oh * wt_strides[0] - padding[0];
|
||||
int iw_base = ow * wt_strides[1] - padding[1];
|
||||
const int O_per_group = O / groups;
|
||||
auto pt_conv_no_checks =
|
||||
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
|
||||
out_ptr += oh * out_stride_H + ow * out_stride_W;
|
||||
int ih_base = oh * wt_strides[0] - padding_lo[0];
|
||||
int iw_base = ow * wt_strides[1] - padding_lo[1];
|
||||
|
||||
for (int g = 0; g < groups; ++g) {
|
||||
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
|
||||
float r = 0.;
|
||||
for (int g = 0; g < groups; ++g) {
|
||||
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
|
||||
float r = 0.;
|
||||
|
||||
for (int wh = 0; wh < wH; ++wh) {
|
||||
for (int ww = 0; ww < wW; ++ww) {
|
||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||
int ih = ih_base + wh_flip * wt_dilation[0];
|
||||
int iw = iw_base + ww_flip * wt_dilation[1];
|
||||
for (int wh = 0; wh < wH; ++wh) {
|
||||
for (int ww = 0; ww < wW; ++ww) {
|
||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||
int ih = ih_base + wh_flip * wt_dilation[0];
|
||||
int iw = iw_base + ww_flip * wt_dilation[1];
|
||||
|
||||
const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||
const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;
|
||||
const T* wt_ptr_pt =
|
||||
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||
const T* in_ptr_pt =
|
||||
in_ptr + ih * in_stride_H + iw * in_stride_W;
|
||||
|
||||
for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
|
||||
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
|
||||
static_cast<float>(
|
||||
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
|
||||
} // c
|
||||
} // ww
|
||||
} // wh
|
||||
for (int c = g * C_per_group; c < (g + 1) * C_per_group;
|
||||
++c) {
|
||||
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
|
||||
static_cast<float>(
|
||||
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
|
||||
} // c
|
||||
} // ww
|
||||
} // wh
|
||||
|
||||
out_ptr[0] = static_cast<T>(r);
|
||||
out_ptr += out_stride_O;
|
||||
wt_ptr += wt_stride_O;
|
||||
} // o
|
||||
} // g
|
||||
};
|
||||
out_ptr[0] = static_cast<T>(r);
|
||||
out_ptr += out_stride_O;
|
||||
wt_ptr += wt_stride_O;
|
||||
} // o
|
||||
} // g
|
||||
};
|
||||
|
||||
int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
|
||||
int jump_w = flip ? -wt_dilation[1] : wt_dilation[1];
|
||||
int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
|
||||
int jump_w = flip ? -wt_dilation[1] : wt_dilation[1];
|
||||
|
||||
int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0);
|
||||
int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0);
|
||||
int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0);
|
||||
int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0);
|
||||
|
||||
int f_wgt_jump_h =
|
||||
std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
|
||||
int f_wgt_jump_w =
|
||||
std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
|
||||
int f_wgt_jump_h =
|
||||
std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
|
||||
int f_wgt_jump_w =
|
||||
std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
|
||||
|
||||
int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
|
||||
int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
|
||||
int f_out_jump_h =
|
||||
std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
|
||||
int f_out_jump_w =
|
||||
std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
|
||||
|
||||
std::vector<int> base_h(f_out_jump_h);
|
||||
std::vector<int> base_w(f_out_jump_w);
|
||||
std::vector<int> base_h(f_out_jump_h);
|
||||
std::vector<int> base_w(f_out_jump_w);
|
||||
|
||||
for (int i = 0; i < f_out_jump_h; ++i) {
|
||||
int ih_loop = i * wt_strides[0] - padding[0] + init_h;
|
||||
for (int i = 0; i < f_out_jump_h; ++i) {
|
||||
int ih_loop = i * wt_strides[0] - padding_lo[0] + init_h;
|
||||
|
||||
int wh_base = 0;
|
||||
while (wh_base < wH && ih_loop % in_dilation[0] != 0) {
|
||||
wh_base++;
|
||||
ih_loop += jump_h;
|
||||
}
|
||||
int wh_base = 0;
|
||||
while (wh_base < wH && ih_loop % in_dilation[0] != 0) {
|
||||
wh_base++;
|
||||
ih_loop += jump_h;
|
||||
}
|
||||
|
||||
base_h[i] = wh_base;
|
||||
}
|
||||
base_h[i] = wh_base;
|
||||
}
|
||||
|
||||
for (int j = 0; j < f_out_jump_w; ++j) {
|
||||
int iw_loop = j * wt_strides[1] - padding[1] + init_w;
|
||||
for (int j = 0; j < f_out_jump_w; ++j) {
|
||||
int iw_loop = j * wt_strides[1] - padding_lo[1] + init_w;
|
||||
|
||||
int ww_base = 0;
|
||||
while (ww_base < wW && iw_loop % in_dilation[1] != 0) {
|
||||
ww_base++;
|
||||
iw_loop += jump_w;
|
||||
}
|
||||
int ww_base = 0;
|
||||
while (ww_base < wW && iw_loop % in_dilation[1] != 0) {
|
||||
ww_base++;
|
||||
iw_loop += jump_w;
|
||||
}
|
||||
|
||||
base_w[j] = ww_base;
|
||||
}
|
||||
base_w[j] = ww_base;
|
||||
}
|
||||
|
||||
auto pt_conv_all_checks =
|
||||
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
|
||||
out_ptr += oh * out_stride_H + ow * out_stride_W;
|
||||
auto pt_conv_all_checks =
|
||||
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
|
||||
out_ptr += oh * out_stride_H + ow * out_stride_W;
|
||||
|
||||
int ih_base = oh * wt_strides[0] - padding[0];
|
||||
int iw_base = ow * wt_strides[1] - padding[1];
|
||||
int ih_base = oh * wt_strides[0] - padding_lo[0];
|
||||
int iw_base = ow * wt_strides[1] - padding_lo[1];
|
||||
|
||||
int wh_base = base_h[oh % f_out_jump_h];
|
||||
int ww_base = base_w[ow % f_out_jump_w];
|
||||
int wh_base = base_h[oh % f_out_jump_h];
|
||||
int ww_base = base_w[ow % f_out_jump_w];
|
||||
|
||||
for (int g = 0; g < groups; ++g) {
|
||||
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
|
||||
float r = 0.;
|
||||
for (int g = 0; g < groups; ++g) {
|
||||
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
|
||||
float r = 0.;
|
||||
|
||||
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
|
||||
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
|
||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||
int ih = ih_base + wh_flip * wt_dilation[0];
|
||||
int iw = iw_base + ww_flip * wt_dilation[1];
|
||||
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
|
||||
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
|
||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||
int ih = ih_base + wh_flip * wt_dilation[0];
|
||||
int iw = iw_base + ww_flip * wt_dilation[1];
|
||||
|
||||
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
|
||||
const T* wt_ptr_pt =
|
||||
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
|
||||
const T* wt_ptr_pt =
|
||||
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||
|
||||
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
|
||||
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
|
||||
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
|
||||
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
|
||||
|
||||
const T* in_ptr_pt =
|
||||
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;
|
||||
const T* in_ptr_pt = in_ptr + ih_dil * in_stride_H +
|
||||
iw_dil * in_stride_W;
|
||||
|
||||
for (int c = g * C_per_group; c < (g + 1) * C_per_group;
|
||||
++c) {
|
||||
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
|
||||
static_cast<float>(
|
||||
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
|
||||
} // c
|
||||
for (int c = g * C_per_group; c < (g + 1) * C_per_group;
|
||||
++c) {
|
||||
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
|
||||
static_cast<float>(
|
||||
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
|
||||
} // c
|
||||
|
||||
} // ih, iw check
|
||||
} // ww
|
||||
} // wh
|
||||
} // ih, iw check
|
||||
} // ww
|
||||
} // wh
|
||||
|
||||
out_ptr[0] = static_cast<T>(r);
|
||||
out_ptr += out_stride_O;
|
||||
wt_ptr += wt_stride_O;
|
||||
} // o
|
||||
} // g
|
||||
};
|
||||
out_ptr[0] = static_cast<T>(r);
|
||||
out_ptr += out_stride_O;
|
||||
wt_ptr += wt_stride_O;
|
||||
} // o
|
||||
} // g
|
||||
};
|
||||
|
||||
int oH_border_0 = 0;
|
||||
int oH_border_1 =
|
||||
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oH;
|
||||
int oH_border_2 = std::max(
|
||||
oH_border_1, (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]);
|
||||
int oH_border_3 = oH;
|
||||
int oH_border_0 = 0;
|
||||
int oH_border_1 = is_idil_one
|
||||
? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])
|
||||
: oH;
|
||||
int oH_border_2 = std::max(
|
||||
oH_border_1,
|
||||
(iH + padding_lo[0] - wH * wt_dilation[0]) / wt_strides[0]);
|
||||
int oH_border_3 = oH;
|
||||
|
||||
int oW_border_0 = 0;
|
||||
int oW_border_1 =
|
||||
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oW;
|
||||
int oW_border_2 = std::max(
|
||||
oW_border_1, (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]);
|
||||
int oW_border_3 = oW;
|
||||
int oW_border_0 = 0;
|
||||
int oW_border_1 = is_idil_one
|
||||
? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])
|
||||
: oW;
|
||||
int oW_border_2 = std::max(
|
||||
oW_border_1,
|
||||
(iW + padding_lo[1] - wW * wt_dilation[1]) / wt_strides[1]);
|
||||
int oW_border_3 = oW;
|
||||
|
||||
for (int n = 0; n < N; ++n) {
|
||||
// Case 1: oh might put us out of bounds
|
||||
for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
|
||||
for (int ow = 0; ow < oW; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||
} // ow
|
||||
} // oh
|
||||
for (int n = 0; n < N; ++n) {
|
||||
// Case 1: oh might put us out of bounds
|
||||
for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
|
||||
for (int ow = 0; ow < oW; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||
} // ow
|
||||
} // oh
|
||||
|
||||
// Case 2: oh in bounds
|
||||
for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
|
||||
// Case a: ow might put us out of bounds
|
||||
for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||
} // ow
|
||||
// Case 2: oh in bounds
|
||||
for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
|
||||
// Case a: ow might put us out of bounds
|
||||
for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||
} // ow
|
||||
|
||||
// Case b: ow in bounds
|
||||
for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
|
||||
pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||
} // ow
|
||||
// Case b: ow in bounds
|
||||
for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
|
||||
pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||
} // ow
|
||||
|
||||
// Case c: ow might put us out of bounds
|
||||
for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||
} // ow
|
||||
// Case c: ow might put us out of bounds
|
||||
for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||
} // ow
|
||||
|
||||
} // oh
|
||||
} // oh
|
||||
|
||||
// Case 3: oh might put us out of bounds
|
||||
for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
|
||||
for (int ow = 0; ow < oW; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||
} // ow
|
||||
} // oh
|
||||
// Case 3: oh might put us out of bounds
|
||||
for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
|
||||
for (int ow = 0; ow < oW; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||
} // ow
|
||||
} // oh
|
||||
|
||||
st_in_ptr += in_stride_N;
|
||||
st_out_ptr += out_stride_N;
|
||||
st_in_ptr += in_stride_N;
|
||||
st_out_ptr += out_stride_N;
|
||||
|
||||
} // n
|
||||
});
|
||||
} // n
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -351,7 +359,8 @@ void slow_conv_3D(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
@@ -400,7 +409,8 @@ void slow_conv_3D(
|
||||
out_stride_H = out.strides()[2],
|
||||
out_stride_W = out.strides()[3],
|
||||
out_stride_O = out.strides()[4],
|
||||
padding,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
@@ -415,9 +425,9 @@ void slow_conv_3D(
|
||||
int oh,
|
||||
int ow) {
|
||||
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
|
||||
int id_base = od * wt_strides[0] - padding[0];
|
||||
int ih_base = oh * wt_strides[1] - padding[1];
|
||||
int iw_base = ow * wt_strides[2] - padding[2];
|
||||
int id_base = od * wt_strides[0] - padding_lo[0];
|
||||
int ih_base = oh * wt_strides[1] - padding_lo[1];
|
||||
int iw_base = ow * wt_strides[2] - padding_lo[2];
|
||||
|
||||
for (int o = 0; o < O; ++o) {
|
||||
float r = 0.;
|
||||
@@ -478,7 +488,7 @@ void slow_conv_3D(
|
||||
std::vector<int> base_w(f_out_jump_w);
|
||||
|
||||
for (int i = 0; i < f_out_jump_d; ++i) {
|
||||
int id_loop = i * wt_strides[0] - padding[0] + init_d;
|
||||
int id_loop = i * wt_strides[0] - padding_lo[0] + init_d;
|
||||
|
||||
int wd_base = 0;
|
||||
while (wd_base < wD && id_loop % in_dilation[0] != 0) {
|
||||
@@ -490,7 +500,7 @@ void slow_conv_3D(
|
||||
}
|
||||
|
||||
for (int i = 0; i < f_out_jump_h; ++i) {
|
||||
int ih_loop = i * wt_strides[1] - padding[1] + init_h;
|
||||
int ih_loop = i * wt_strides[1] - padding_lo[1] + init_h;
|
||||
|
||||
int wh_base = 0;
|
||||
while (wh_base < wH && ih_loop % in_dilation[1] != 0) {
|
||||
@@ -502,7 +512,7 @@ void slow_conv_3D(
|
||||
}
|
||||
|
||||
for (int j = 0; j < f_out_jump_w; ++j) {
|
||||
int iw_loop = j * wt_strides[2] - padding[2] + init_w;
|
||||
int iw_loop = j * wt_strides[2] - padding_lo[2] + init_w;
|
||||
|
||||
int ww_base = 0;
|
||||
while (ww_base < wW && iw_loop % in_dilation[2] != 0) {
|
||||
@@ -521,9 +531,9 @@ void slow_conv_3D(
|
||||
int ow) {
|
||||
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
|
||||
|
||||
int id_base = od * wt_strides[0] - padding[0];
|
||||
int ih_base = oh * wt_strides[1] - padding[1];
|
||||
int iw_base = ow * wt_strides[2] - padding[2];
|
||||
int id_base = od * wt_strides[0] - padding_lo[0];
|
||||
int ih_base = oh * wt_strides[1] - padding_lo[1];
|
||||
int iw_base = ow * wt_strides[2] - padding_lo[2];
|
||||
|
||||
int wd_base = base_d[od % f_out_jump_d];
|
||||
int wh_base = base_h[oh % f_out_jump_h];
|
||||
@@ -573,24 +583,30 @@ void slow_conv_3D(
|
||||
};
|
||||
|
||||
int oD_border_0 = 0;
|
||||
int oD_border_1 =
|
||||
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD;
|
||||
int oD_border_1 = is_idil_one
|
||||
? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])
|
||||
: oD;
|
||||
int oD_border_2 = std::max(
|
||||
oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]);
|
||||
oD_border_1,
|
||||
(iD + padding_lo[0] - wD * wt_dilation[0]) / wt_strides[0]);
|
||||
int oD_border_3 = oD;
|
||||
|
||||
int oH_border_0 = 0;
|
||||
int oH_border_1 =
|
||||
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH;
|
||||
int oH_border_1 = is_idil_one
|
||||
? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])
|
||||
: oH;
|
||||
int oH_border_2 = std::max(
|
||||
oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]);
|
||||
oH_border_1,
|
||||
(iH + padding_lo[1] - wH * wt_dilation[1]) / wt_strides[1]);
|
||||
int oH_border_3 = oH;
|
||||
|
||||
int oW_border_0 = 0;
|
||||
int oW_border_1 =
|
||||
is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW;
|
||||
int oW_border_1 = is_idil_one
|
||||
? ((padding_lo[2] + wt_strides[2] - 1) / wt_strides[2])
|
||||
: oW;
|
||||
int oW_border_2 = std::max(
|
||||
oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]);
|
||||
oW_border_1,
|
||||
(iW + padding_lo[2] - wW * wt_dilation[2]) / wt_strides[2]);
|
||||
int oW_border_3 = oW;
|
||||
|
||||
for (int n = 0; n < N; ++n) {
|
||||
@@ -658,7 +674,8 @@ void dispatch_slow_conv_1D(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
@@ -669,7 +686,8 @@ void dispatch_slow_conv_1D(
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
@@ -680,7 +698,8 @@ void dispatch_slow_conv_1D(
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
@@ -691,7 +710,8 @@ void dispatch_slow_conv_1D(
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
@@ -707,7 +727,8 @@ void dispatch_slow_conv_2D(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
@@ -718,7 +739,8 @@ void dispatch_slow_conv_2D(
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
@@ -729,7 +751,8 @@ void dispatch_slow_conv_2D(
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
@@ -740,7 +763,8 @@ void dispatch_slow_conv_2D(
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
@@ -756,7 +780,8 @@ void dispatch_slow_conv_3D(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
@@ -767,7 +792,8 @@ void dispatch_slow_conv_3D(
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
@@ -778,7 +804,8 @@ void dispatch_slow_conv_3D(
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
@@ -789,7 +816,8 @@ void dispatch_slow_conv_3D(
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
@@ -829,7 +857,8 @@ void explicit_gemm_conv_1D_cpu(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
Stream stream) {
|
||||
@@ -848,7 +877,7 @@ void explicit_gemm_conv_1D_cpu(
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
|
||||
// Pad input
|
||||
Shape padded_shape = {N, iH + 2 * padding[0], C};
|
||||
Shape padded_shape = {N, iH + padding_lo[0] + padding_hi[0], C};
|
||||
array in_padded(padded_shape, conv_dtype, nullptr, {});
|
||||
|
||||
// Fill with zeros
|
||||
@@ -857,7 +886,7 @@ void explicit_gemm_conv_1D_cpu(
|
||||
copy(temps.back(), in_padded, CopyType::Scalar, stream);
|
||||
|
||||
// Pick input slice from padded
|
||||
size_t data_offset = padding[0] * in_padded.strides()[1];
|
||||
size_t data_offset = padding_lo[0] * in_padded.strides()[1];
|
||||
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
||||
in_padded_slice.copy_shared_buffer(
|
||||
in_padded,
|
||||
@@ -971,7 +1000,8 @@ void explicit_gemm_conv_2D_cpu(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
Stream stream) {
|
||||
@@ -989,7 +1019,11 @@ void explicit_gemm_conv_2D_cpu(
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
|
||||
// Pad input
|
||||
Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C};
|
||||
Shape padded_shape = {
|
||||
N,
|
||||
iH + padding_lo[0] + padding_hi[0],
|
||||
iW + padding_lo[1] + padding_hi[1],
|
||||
C};
|
||||
array in_padded(padded_shape, conv_dtype, nullptr, {});
|
||||
|
||||
// Fill with zeros
|
||||
@@ -998,8 +1032,8 @@ void explicit_gemm_conv_2D_cpu(
|
||||
copy(temps.back(), in_padded, CopyType::Scalar, stream);
|
||||
|
||||
// Pick input slice from padded
|
||||
size_t data_offset =
|
||||
padding[0] * in_padded.strides()[1] + padding[1] * in_padded.strides()[2];
|
||||
size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
|
||||
padding_lo[1] * in_padded.strides()[2];
|
||||
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
||||
in_padded_slice.copy_shared_buffer(
|
||||
in_padded,
|
||||
@@ -1091,7 +1125,8 @@ void explicit_gemm_conv_ND_cpu(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const bool flip,
|
||||
@@ -1114,7 +1149,7 @@ void explicit_gemm_conv_ND_cpu(
|
||||
Shape padded_shape(in.shape().size());
|
||||
padded_shape.front() = N;
|
||||
for (size_t i = 0; i < iDim.size(); i++) {
|
||||
padded_shape[i + 1] = iDim[i] + 2 * padding[i];
|
||||
padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i];
|
||||
}
|
||||
padded_shape.back() = C;
|
||||
array in_padded(padded_shape, conv_dtype, nullptr, {});
|
||||
@@ -1125,9 +1160,10 @@ void explicit_gemm_conv_ND_cpu(
|
||||
|
||||
// Pick input slice from padded
|
||||
size_t data_offset = 0;
|
||||
for (size_t i = 0; i < padding.size(); i++) {
|
||||
data_offset += padding[i] * in_padded.strides()[i + 1];
|
||||
for (size_t i = 0; i < padding_lo.size(); i++) {
|
||||
data_offset += padding_lo[i] * in_padded.strides()[i + 1];
|
||||
}
|
||||
|
||||
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
||||
in_padded_slice.copy_shared_buffer(
|
||||
in_padded,
|
||||
@@ -1261,7 +1297,8 @@ void conv_1D_cpu(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
@@ -1270,22 +1307,40 @@ void conv_1D_cpu(
|
||||
const int groups = in.shape().back() / wt.shape().back();
|
||||
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
|
||||
return explicit_gemm_conv_1D_cpu(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, stream);
|
||||
in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, stream);
|
||||
}
|
||||
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) {
|
||||
return explicit_gemm_conv_ND_cpu(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
flip,
|
||||
stream);
|
||||
}
|
||||
|
||||
return dispatch_slow_conv_1D(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
flip,
|
||||
stream);
|
||||
}
|
||||
|
||||
void conv_2D_cpu(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
@@ -1295,18 +1350,35 @@ void conv_2D_cpu(
|
||||
if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 &&
|
||||
in_dilation[1] == 1 && groups == 1) {
|
||||
return explicit_gemm_conv_ND_cpu(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
flip,
|
||||
stream);
|
||||
}
|
||||
|
||||
return dispatch_slow_conv_2D(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
flip,
|
||||
stream);
|
||||
}
|
||||
|
||||
void conv_3D_cpu(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
@@ -1317,11 +1389,28 @@ void conv_3D_cpu(
|
||||
in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 &&
|
||||
groups == 1) {
|
||||
return explicit_gemm_conv_ND_cpu(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
flip,
|
||||
stream);
|
||||
}
|
||||
|
||||
return dispatch_slow_conv_3D(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
flip,
|
||||
stream);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -1338,7 +1427,8 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_,
|
||||
padding_lo_,
|
||||
padding_hi_,
|
||||
kernel_strides_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
@@ -1351,7 +1441,8 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_,
|
||||
padding_lo_,
|
||||
padding_hi_,
|
||||
kernel_strides_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
@@ -1364,7 +1455,8 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_,
|
||||
padding_lo_,
|
||||
padding_hi_,
|
||||
kernel_strides_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
|
174
mlx/backend/cpu/eig.cpp
Normal file
174
mlx/backend/cpu/eig.cpp
Normal file
@@ -0,0 +1,174 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/lapack.h"
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
void eig_impl(
|
||||
array& a,
|
||||
array& vectors,
|
||||
array& values,
|
||||
bool compute_eigenvectors,
|
||||
Stream stream) {
|
||||
using OT = std::complex<T>;
|
||||
auto a_ptr = a.data<T>();
|
||||
auto eig_ptr = values.data<OT>();
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_output_array(values);
|
||||
OT* vec_ptr = nullptr;
|
||||
if (compute_eigenvectors) {
|
||||
encoder.set_output_array(vectors);
|
||||
vec_ptr = vectors.data<OT>();
|
||||
}
|
||||
encoder.dispatch([a_ptr,
|
||||
vec_ptr,
|
||||
eig_ptr,
|
||||
compute_eigenvectors,
|
||||
N = vectors.shape(-1),
|
||||
size = vectors.size()]() mutable {
|
||||
// Work query
|
||||
char jobr = 'N';
|
||||
char jobl = compute_eigenvectors ? 'V' : 'N';
|
||||
int n_vecs_r = 1;
|
||||
int n_vecs_l = compute_eigenvectors ? N : 1;
|
||||
int lwork = -1;
|
||||
int info;
|
||||
{
|
||||
T work;
|
||||
int iwork;
|
||||
geev<T>(
|
||||
&jobl,
|
||||
&jobr,
|
||||
&N,
|
||||
nullptr,
|
||||
&N,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
&n_vecs_l,
|
||||
nullptr,
|
||||
&n_vecs_r,
|
||||
&work,
|
||||
&lwork,
|
||||
&info);
|
||||
lwork = static_cast<int>(work);
|
||||
}
|
||||
|
||||
auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)};
|
||||
auto vec_tmp_data =
|
||||
array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)};
|
||||
auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
|
||||
auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
|
||||
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
||||
for (size_t i = 0; i < size / (N * N); ++i) {
|
||||
geev<T>(
|
||||
&jobl,
|
||||
&jobr,
|
||||
&N,
|
||||
a_ptr,
|
||||
&N,
|
||||
eig_tmp,
|
||||
eig_tmp + N,
|
||||
vec_tmp,
|
||||
&n_vecs_l,
|
||||
nullptr,
|
||||
&n_vecs_r,
|
||||
static_cast<T*>(work_buf.buffer.raw_ptr()),
|
||||
&lwork,
|
||||
&info);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]};
|
||||
}
|
||||
if (vec_ptr) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
if (eig_ptr[i].imag() != 0) {
|
||||
// This vector and the next are a pair
|
||||
for (int j = 0; j < N; ++j) {
|
||||
vec_ptr[i * N + j] = {
|
||||
vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]};
|
||||
vec_ptr[(i + 1) * N + j] = {
|
||||
vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]};
|
||||
}
|
||||
i += 1;
|
||||
} else {
|
||||
for (int j = 0; j < N; ++j) {
|
||||
vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0};
|
||||
}
|
||||
}
|
||||
}
|
||||
vec_ptr += N * N;
|
||||
}
|
||||
a_ptr += N * N;
|
||||
eig_ptr += N;
|
||||
if (info != 0) {
|
||||
std::stringstream msg;
|
||||
msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
|
||||
<< info;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
});
|
||||
encoder.add_temporary(a);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Eig::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
const auto& a = inputs[0];
|
||||
auto& values = outputs[0];
|
||||
|
||||
auto vectors = compute_eigenvectors_
|
||||
? outputs[1]
|
||||
: array(a.shape(), complex64, nullptr, {});
|
||||
|
||||
auto a_copy = array(a.shape(), a.dtype(), nullptr, {});
|
||||
copy(
|
||||
a,
|
||||
a_copy,
|
||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
stream());
|
||||
|
||||
values.set_data(allocator::malloc(values.nbytes()));
|
||||
|
||||
if (compute_eigenvectors_) {
|
||||
// Set the strides and flags so the eigenvectors
|
||||
// are in the columns of the output
|
||||
auto flags = vectors.flags();
|
||||
auto strides = vectors.strides();
|
||||
auto ndim = a.ndim();
|
||||
std::swap(strides[ndim - 1], strides[ndim - 2]);
|
||||
|
||||
if (a.size() > 1) {
|
||||
flags.row_contiguous = false;
|
||||
if (ndim > 2) {
|
||||
flags.col_contiguous = false;
|
||||
} else {
|
||||
flags.col_contiguous = true;
|
||||
}
|
||||
}
|
||||
vectors.set_data(
|
||||
allocator::malloc(vectors.nbytes()), vectors.size(), strides, flags);
|
||||
}
|
||||
switch (a.dtype()) {
|
||||
case float32:
|
||||
eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("[Eig::eval_cpu] only supports float32.");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -12,6 +12,133 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T, class Enable = void>
|
||||
struct EighWork {};
|
||||
|
||||
template <typename T>
|
||||
struct EighWork<
|
||||
T,
|
||||
typename std::enable_if<std::is_floating_point<T>::value>::type> {
|
||||
using R = T;
|
||||
|
||||
char jobz;
|
||||
char uplo;
|
||||
int N;
|
||||
int lwork;
|
||||
int liwork;
|
||||
int info;
|
||||
std::vector<array::Data> buffers;
|
||||
|
||||
EighWork(char jobz_, char uplo_, int N_)
|
||||
: jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), liwork(-1) {
|
||||
T work;
|
||||
int iwork;
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
nullptr,
|
||||
&N,
|
||||
nullptr,
|
||||
&work,
|
||||
&lwork,
|
||||
&iwork,
|
||||
&liwork,
|
||||
&info);
|
||||
lwork = static_cast<int>(work);
|
||||
liwork = iwork;
|
||||
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
|
||||
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
|
||||
}
|
||||
|
||||
void run(T* vectors, T* values) {
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
vectors,
|
||||
&N,
|
||||
values,
|
||||
static_cast<T*>(buffers[0].buffer.raw_ptr()),
|
||||
&lwork,
|
||||
static_cast<int*>(buffers[1].buffer.raw_ptr()),
|
||||
&liwork,
|
||||
&info);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct EighWork<std::complex<float>> {
|
||||
using T = std::complex<float>;
|
||||
using R = float;
|
||||
|
||||
char jobz;
|
||||
char uplo;
|
||||
int N;
|
||||
int lwork;
|
||||
int lrwork;
|
||||
int liwork;
|
||||
int info;
|
||||
std::vector<array::Data> buffers;
|
||||
|
||||
EighWork(char jobz_, char uplo_, int N_)
|
||||
: jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), lrwork(-1), liwork(-1) {
|
||||
T work;
|
||||
R rwork;
|
||||
int iwork;
|
||||
heevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
nullptr,
|
||||
&N,
|
||||
nullptr,
|
||||
&work,
|
||||
&lwork,
|
||||
&rwork,
|
||||
&lrwork,
|
||||
&iwork,
|
||||
&liwork,
|
||||
&info);
|
||||
lwork = static_cast<int>(work.real());
|
||||
lrwork = static_cast<int>(rwork);
|
||||
liwork = iwork;
|
||||
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
|
||||
buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
|
||||
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
|
||||
}
|
||||
|
||||
void run(T* vectors, R* values) {
|
||||
heevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
vectors,
|
||||
&N,
|
||||
values,
|
||||
static_cast<T*>(buffers[0].buffer.raw_ptr()),
|
||||
&lwork,
|
||||
static_cast<R*>(buffers[1].buffer.raw_ptr()),
|
||||
&lrwork,
|
||||
static_cast<int*>(buffers[2].buffer.raw_ptr()),
|
||||
&liwork,
|
||||
&info);
|
||||
if (jobz == 'V') {
|
||||
// We have pre-transposed the vectors but we also must conjugate them
|
||||
// when they are complex.
|
||||
//
|
||||
// We could vectorize this but it is so fast in comparison to heevd that
|
||||
// it doesn't really matter.
|
||||
for (int i = 0; i < N; i++) {
|
||||
for (int j = 0; j < N; j++) {
|
||||
*vectors = std::conj(*vectors);
|
||||
vectors++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void eigh_impl(
|
||||
array& vectors,
|
||||
@@ -19,8 +146,10 @@ void eigh_impl(
|
||||
const std::string& uplo,
|
||||
bool compute_eigenvectors,
|
||||
Stream stream) {
|
||||
using R = typename EighWork<T>::R;
|
||||
|
||||
auto vec_ptr = vectors.data<T>();
|
||||
auto eig_ptr = values.data<T>();
|
||||
auto eig_ptr = values.data<R>();
|
||||
char jobz = compute_eigenvectors ? 'V' : 'N';
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
@@ -33,49 +162,17 @@ void eigh_impl(
|
||||
N = vectors.shape(-1),
|
||||
size = vectors.size()]() mutable {
|
||||
// Work query
|
||||
int lwork = -1;
|
||||
int liwork = -1;
|
||||
int info;
|
||||
{
|
||||
T work;
|
||||
int iwork;
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
nullptr,
|
||||
&N,
|
||||
nullptr,
|
||||
&work,
|
||||
&lwork,
|
||||
&iwork,
|
||||
&liwork,
|
||||
&info);
|
||||
lwork = static_cast<int>(work);
|
||||
liwork = iwork;
|
||||
}
|
||||
EighWork<T> work(jobz, uplo, N);
|
||||
|
||||
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
||||
auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)};
|
||||
// Work loop
|
||||
for (size_t i = 0; i < size / (N * N); ++i) {
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
vec_ptr,
|
||||
&N,
|
||||
eig_ptr,
|
||||
static_cast<T*>(work_buf.buffer.raw_ptr()),
|
||||
&lwork,
|
||||
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
|
||||
&liwork,
|
||||
&info);
|
||||
work.run(vec_ptr, eig_ptr);
|
||||
vec_ptr += N * N;
|
||||
eig_ptr += N;
|
||||
if (info != 0) {
|
||||
if (work.info != 0) {
|
||||
std::stringstream msg;
|
||||
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
|
||||
<< info;
|
||||
<< work.info;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
@@ -131,6 +228,10 @@ void Eigh::eval_cpu(
|
||||
eigh_impl<double>(
|
||||
vectors, values, uplo_, compute_eigenvectors_, stream());
|
||||
break;
|
||||
case complex64:
|
||||
eigh_impl<std::complex<float>>(
|
||||
vectors, values, uplo_, compute_eigenvectors_, stream());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[Eigh::eval_cpu] only supports float32 or float64.");
|
||||
|
@@ -257,15 +257,11 @@ void gather_axis(
|
||||
const array& ind,
|
||||
array& out,
|
||||
const int axis) {
|
||||
auto strides = ind.strides();
|
||||
strides.erase(strides.begin() + axis);
|
||||
auto shape = ind.shape();
|
||||
shape.erase(shape.begin() + axis);
|
||||
ContiguousIterator ind_it(shape, strides, src.ndim() - 1);
|
||||
|
||||
strides = src.strides();
|
||||
strides.erase(strides.begin() + axis);
|
||||
ContiguousIterator src_it(shape, strides, src.ndim() - 1);
|
||||
auto shape = remove_index(ind.shape(), axis);
|
||||
ContiguousIterator ind_it(
|
||||
shape, remove_index(ind.strides(), axis), src.ndim() - 1);
|
||||
ContiguousIterator src_it(
|
||||
shape, remove_index(src.strides(), axis), src.ndim() - 1);
|
||||
|
||||
auto ind_ptr = ind.data<IdxT>();
|
||||
auto src_ptr = src.data<T>();
|
||||
@@ -585,15 +581,11 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
template <typename T, typename IdxT, typename OpT>
|
||||
void scatter_axis(array& out, const array idx, const array& upd, int axis) {
|
||||
auto strides = idx.strides();
|
||||
strides.erase(strides.begin() + axis);
|
||||
auto shape = idx.shape();
|
||||
shape.erase(shape.begin() + axis);
|
||||
ContiguousIterator idx_it(shape, strides, upd.ndim() - 1);
|
||||
|
||||
strides = upd.strides();
|
||||
strides.erase(strides.begin() + axis);
|
||||
ContiguousIterator upd_it(shape, strides, upd.ndim() - 1);
|
||||
auto shape = remove_index(idx.shape(), axis);
|
||||
ContiguousIterator idx_it(
|
||||
shape, remove_index(idx.strides(), axis), upd.ndim() - 1);
|
||||
ContiguousIterator upd_it(
|
||||
shape, remove_index(upd.strides(), axis), upd.ndim() - 1);
|
||||
|
||||
auto idx_ptr = idx.data<IdxT>();
|
||||
auto upd_ptr = upd.data<T>();
|
||||
|
@@ -2,14 +2,14 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
// Required for Visual Studio.
|
||||
// https://github.com/OpenMathLib/OpenBLAS/blob/develop/docs/install.md
|
||||
#ifdef _MSC_VER
|
||||
#include <complex>
|
||||
#define LAPACK_COMPLEX_CUSTOM
|
||||
#define lapack_complex_float std::complex<float>
|
||||
#define lapack_complex_double std::complex<double>
|
||||
#endif
|
||||
#define lapack_complex_float_real(z) ((z).real())
|
||||
#define lapack_complex_float_imag(z) ((z).imag())
|
||||
#define lapack_complex_double_real(z) ((z).real())
|
||||
#define lapack_complex_double_imag(z) ((z).imag())
|
||||
|
||||
#ifdef MLX_USE_ACCELERATE
|
||||
#include <Accelerate/Accelerate.h>
|
||||
@@ -32,7 +32,7 @@
|
||||
|
||||
#endif
|
||||
|
||||
#define INSTANTIATE_LAPACK_TYPES(FUNC) \
|
||||
#define INSTANTIATE_LAPACK_REAL(FUNC) \
|
||||
template <typename T, typename... Args> \
|
||||
void FUNC(Args... args) { \
|
||||
if constexpr (std::is_same_v<T, float>) { \
|
||||
@@ -42,11 +42,24 @@
|
||||
} \
|
||||
}
|
||||
|
||||
INSTANTIATE_LAPACK_TYPES(geqrf)
|
||||
INSTANTIATE_LAPACK_TYPES(orgqr)
|
||||
INSTANTIATE_LAPACK_TYPES(syevd)
|
||||
INSTANTIATE_LAPACK_TYPES(potrf)
|
||||
INSTANTIATE_LAPACK_TYPES(gesvdx)
|
||||
INSTANTIATE_LAPACK_TYPES(getrf)
|
||||
INSTANTIATE_LAPACK_TYPES(getri)
|
||||
INSTANTIATE_LAPACK_TYPES(trtri)
|
||||
INSTANTIATE_LAPACK_REAL(geqrf)
|
||||
INSTANTIATE_LAPACK_REAL(orgqr)
|
||||
INSTANTIATE_LAPACK_REAL(syevd)
|
||||
INSTANTIATE_LAPACK_REAL(geev)
|
||||
INSTANTIATE_LAPACK_REAL(potrf)
|
||||
INSTANTIATE_LAPACK_REAL(gesvdx)
|
||||
INSTANTIATE_LAPACK_REAL(getrf)
|
||||
INSTANTIATE_LAPACK_REAL(getri)
|
||||
INSTANTIATE_LAPACK_REAL(trtri)
|
||||
|
||||
#define INSTANTIATE_LAPACK_COMPLEX(FUNC) \
|
||||
template <typename T, typename... Args> \
|
||||
void FUNC(Args... args) { \
|
||||
if constexpr (std::is_same_v<T, std::complex<float>>) { \
|
||||
MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...); \
|
||||
} else if constexpr (std::is_same_v<T, std::complex<double>>) { \
|
||||
MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...); \
|
||||
} \
|
||||
}
|
||||
|
||||
INSTANTIATE_LAPACK_COMPLEX(heevd)
|
||||
|
@@ -132,6 +132,10 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
throw std::runtime_error(
|
||||
"[AddMM::eval_cpu] Currently only supports float32.");
|
||||
}
|
||||
if (out.size() == 0) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
return;
|
||||
}
|
||||
|
||||
// Fill output with C
|
||||
auto& c = inputs[2];
|
||||
@@ -139,7 +143,9 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
? CopyType::Scalar
|
||||
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
||||
copy(c, out, ctype, stream());
|
||||
|
||||
if (inputs[0].shape(-1) == 0) {
|
||||
return;
|
||||
}
|
||||
matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
|
||||
}
|
||||
|
||||
|
@@ -13,9 +13,18 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
inline constexpr short get_pack_factor(int bits, int wsize = 8) {
|
||||
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
|
||||
}
|
||||
|
||||
inline constexpr short get_bytes_per_pack(int bits, int wsize = 8) {
|
||||
auto power_of_2_bits = (bits & (bits - 1)) == 0;
|
||||
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
|
||||
}
|
||||
|
||||
template <typename T, int bits>
|
||||
void extract_bits(const uint8_t* w_in, T* w_out) {
|
||||
assert(bits == 3 || bits == 6);
|
||||
static_assert(bits == 3 || bits == 5 || bits == 6);
|
||||
if (bits == 3) {
|
||||
w_out[0] = static_cast<T>(w_in[0] & 0x7);
|
||||
w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
|
||||
@@ -25,6 +34,16 @@ void extract_bits(const uint8_t* w_in, T* w_out) {
|
||||
w_out[5] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1));
|
||||
w_out[6] = static_cast<T>((w_in[2] & 0x1c) >> 2);
|
||||
w_out[7] = static_cast<T>((w_in[2] & 0xe0) >> 5);
|
||||
} else if (bits == 5) {
|
||||
w_out[0] = static_cast<T>(w_in[0] & 0x1f);
|
||||
w_out[1] = static_cast<T>(((w_in[0] & 0xe0) >> 5) + ((w_in[1] & 0x3) << 3));
|
||||
w_out[2] = static_cast<T>((w_in[1] & 0x7c) >> 2);
|
||||
w_out[3] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0xf) << 1));
|
||||
w_out[4] = static_cast<T>(((w_in[2] & 0xf0) >> 4) + ((w_in[3] & 0x1) << 4));
|
||||
w_out[5] = static_cast<T>((w_in[3] & 0x3e) >> 1);
|
||||
w_out[6] = static_cast<T>(((w_in[3] & 0xc0) >> 6) + ((w_in[4] & 0x7) << 2));
|
||||
w_out[7] = static_cast<T>((w_in[4] & 0xf8) >> 3);
|
||||
|
||||
} else if (bits == 6) {
|
||||
w_out[0] = static_cast<T>(w_in[0] & 0x3f);
|
||||
w_out[1] =
|
||||
@@ -46,8 +65,8 @@ void _qmm(
|
||||
int N,
|
||||
int K) {
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
|
||||
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
|
||||
constexpr int pack_factor = get_pack_factor(bits, 8);
|
||||
constexpr int bytes_per_pack = get_bytes_per_pack(bits);
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
@@ -65,7 +84,7 @@ void _qmm(
|
||||
T scale = *scales_local++;
|
||||
T bias = *biases_local++;
|
||||
for (int ng = 0; ng < packs_in_group; ng++) {
|
||||
if (bits == 3 || bits == 6) {
|
||||
if constexpr (bits == 3 || bits == 5 || bits == 6) {
|
||||
T wl[pack_factor];
|
||||
extract_bits<T, bits>(w_local, wl);
|
||||
#pragma clang loop unroll(full)
|
||||
@@ -104,8 +123,9 @@ void _qmm_t(
|
||||
int N,
|
||||
int K) {
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
|
||||
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
|
||||
|
||||
constexpr int pack_factor = get_pack_factor(bits, 8);
|
||||
constexpr int bytes_per_pack = get_bytes_per_pack(bits);
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
@@ -121,7 +141,7 @@ void _qmm_t(
|
||||
T bias = *biases_local++;
|
||||
|
||||
for (int kw = 0; kw < packs_in_group; kw++) {
|
||||
if (bits == 3 || bits == 6) {
|
||||
if constexpr (bits == 3 || bits == 5 || bits == 6) {
|
||||
T wl[pack_factor];
|
||||
extract_bits<T, bits>(w_local, wl);
|
||||
#pragma clang loop unroll(full)
|
||||
@@ -304,6 +324,10 @@ void _qmm_dispatch_typed(
|
||||
_qmm_dispatch_group<T, 4>(
|
||||
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
|
||||
break;
|
||||
case 5:
|
||||
_qmm_dispatch_group<T, 5>(
|
||||
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
|
||||
break;
|
||||
case 6:
|
||||
_qmm_dispatch_group<T, 6>(
|
||||
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
|
||||
@@ -613,9 +637,8 @@ void quantize(
|
||||
float eps = 1e-7;
|
||||
|
||||
bool power_of_2_bits = is_power_of_2(bits);
|
||||
int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
|
||||
// For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
|
||||
int bytes_per_pack = power_of_2_bits ? 1 : 3;
|
||||
int el_per_int = get_pack_factor(bits, 32);
|
||||
int bytes_per_pack = get_bytes_per_pack(bits);
|
||||
int int_per_group = group_size * bytes_per_pack / el_per_int;
|
||||
size_t n_groups = w_size / group_size;
|
||||
|
||||
@@ -640,15 +663,21 @@ void quantize(
|
||||
}
|
||||
size_t out_idx = i * int_per_group;
|
||||
for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
|
||||
uint32_t out_el = 0;
|
||||
uint64_t out_el = 0;
|
||||
for (int k = 0; k < el_per_int; ++k) {
|
||||
float w_el = w[w_idx + j * el_per_int + k];
|
||||
w_el = std::rint((w_el - bias) / scale);
|
||||
w_el = std::min(std::max(w_el, 0.0f), n_bins);
|
||||
out_el |= static_cast<uint32_t>(w_el) << (k * bits);
|
||||
out_el |= static_cast<uint64_t>(w_el) << (k * bits);
|
||||
}
|
||||
if (power_of_2_bits) {
|
||||
out[out_idx + j] = out_el;
|
||||
} else if (bits == 5) {
|
||||
out[out_idx + bytes_per_pack * j] = out_el & 0xff;
|
||||
out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
|
||||
out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
|
||||
out[out_idx + bytes_per_pack * j + 3] = (out_el & 0xff000000) >> 24;
|
||||
out[out_idx + bytes_per_pack * j + 4] = (out_el & 0xff00000000) >> 32;
|
||||
} else {
|
||||
out[out_idx + bytes_per_pack * j] = out_el & 0xff;
|
||||
out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
|
||||
|
@@ -330,7 +330,8 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case complex64:
|
||||
throw std::runtime_error("Scan ops do not support complex types yet");
|
||||
scan_dispatch<complex64_t, complex64_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
}
|
||||
});
|
||||
|
@@ -88,12 +88,33 @@ DEFAULT_UNARY(expm1, std::expm1)
|
||||
DEFAULT_UNARY(floor, std::floor)
|
||||
DEFAULT_UNARY(log, std::log)
|
||||
DEFAULT_UNARY(log10, std::log10)
|
||||
DEFAULT_UNARY(log1p, std::log1p)
|
||||
DEFAULT_UNARY(sinh, std::sinh)
|
||||
DEFAULT_UNARY(sqrt, std::sqrt)
|
||||
DEFAULT_UNARY(tan, std::tan)
|
||||
DEFAULT_UNARY(tanh, std::tanh)
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> log1p(Simd<T, 1> in) {
|
||||
if constexpr (is_complex<T>) {
|
||||
auto x = in.value.real();
|
||||
auto y = in.value.imag();
|
||||
auto zabs = std::abs(in.value);
|
||||
auto theta = std::atan2(y, x + 1);
|
||||
if (zabs < 0.5) {
|
||||
auto r = x * (2 + x) + y * y;
|
||||
if (r == 0) { // handle underflow
|
||||
return Simd<T, 1>{T{x, theta}};
|
||||
}
|
||||
return Simd<T, 1>{T{((typeof(x))(0.5)) * std::log1p(r), theta}};
|
||||
} else {
|
||||
auto z0 = std::hypot(x + 1, y);
|
||||
return Simd<T, 1>{T{std::log(z0), theta}};
|
||||
}
|
||||
} else {
|
||||
return Simd<T, 1>{std::log1p(in.value)};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> log2(Simd<T, 1> in) {
|
||||
if constexpr (is_complex<T>) {
|
||||
|
@@ -2,32 +2,13 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void set_unary_output_data(const array& in, array& out) {
|
||||
if (in.flags().contiguous) {
|
||||
if (is_donatable(in, out)) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
auto size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U = T, typename Op>
|
||||
void unary_op(const T* a, U* out, size_t shape, size_t stride) {
|
||||
for (size_t i = 0; i < shape; i += 1) {
|
||||
|
80
mlx/backend/cuda/CMakeLists.txt
Normal file
80
mlx/backend/cuda/CMakeLists.txt
Normal file
@@ -0,0 +1,80 @@
|
||||
# Filename rules in cuda backend:
|
||||
#
|
||||
# * Use .cu/.cuh if code contains device code, and .cpp/.h if not.
|
||||
# * Device-only kernel code should be put in kernels/ subdir.
|
||||
# * Files in kernels/ subdir should not include files outside.
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
||||
|
||||
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
|
||||
|
||||
# Enable defining device lambda functions.
|
||||
target_compile_options(mlx
|
||||
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
|
||||
|
||||
# CUDA 12.8 emits warning #20280-D for copy kernels which is a false positive.
|
||||
# Explicitly pass this flag to suppress the warning, it is safe to set it to
|
||||
# true but the warning wouldn't be suppressed.
|
||||
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
|
||||
target_compile_options(
|
||||
mlx
|
||||
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--static-global-template-stub=false>")
|
||||
endif()
|
||||
|
||||
# Compute capability 7 is required for synchronization between CPU/GPU with
|
||||
# managed memory. TODO: Add more architectures for potential performance gain.
|
||||
set(MLX_CUDA_ARCHITECTURES
|
||||
"70;80"
|
||||
CACHE STRING "CUDA architectures")
|
||||
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
|
||||
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
|
||||
"${MLX_CUDA_ARCHITECTURES}")
|
||||
|
||||
# Use fixed version of CCCL.
|
||||
FetchContent_Declare(
|
||||
cccl
|
||||
URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip")
|
||||
FetchContent_MakeAvailable(cccl)
|
||||
target_include_directories(mlx BEFORE PRIVATE "${cccl_SOURCE_DIR}/include")
|
||||
|
||||
# Use fixed version of NVTX.
|
||||
FetchContent_Declare(
|
||||
nvtx3
|
||||
GIT_REPOSITORY https://github.com/NVIDIA/NVTX.git
|
||||
GIT_TAG v3.1.1
|
||||
GIT_SHALLOW TRUE
|
||||
SOURCE_SUBDIR c EXCLUDE_FROM_ALL)
|
||||
FetchContent_MakeAvailable(nvtx3)
|
||||
target_link_libraries(mlx PUBLIC $<BUILD_INTERFACE:nvtx3-cpp>)
|
||||
|
||||
# Make cuda runtime APIs available in non-cuda files.
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
|
||||
|
||||
# Use cublasLt.
|
||||
target_link_libraries(mlx PRIVATE CUDA::cublasLt)
|
||||
|
||||
# Suppress nvcc warnings on MLX headers.
|
||||
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
||||
--diag_suppress=997>)
|
206
mlx/backend/cuda/allocator.cpp
Normal file
206
mlx/backend/cuda/allocator.cpp
Normal file
@@ -0,0 +1,206 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <fmt/format.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
CudaAllocator::CudaAllocator()
|
||||
: buffer_cache_(
|
||||
getpagesize(),
|
||||
[](CudaBuffer* buf) { return buf->size; },
|
||||
[this](CudaBuffer* buf) {
|
||||
cuda_free(buf->data);
|
||||
delete buf;
|
||||
}) {
|
||||
// TODO: Set memory limit for multi-device.
|
||||
size_t free, total;
|
||||
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
||||
memory_limit_ = total * 0.8;
|
||||
max_pool_size_ = memory_limit_;
|
||||
}
|
||||
|
||||
Buffer CudaAllocator::malloc(size_t size) {
|
||||
// Find available buffer from cache.
|
||||
std::unique_lock lock(mutex_);
|
||||
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
|
||||
if (!buf) {
|
||||
// If we have a lot of memory pressure or are over the maximum cache size,
|
||||
// try to reclaim memory from the cache.
|
||||
size_t mem_required = get_active_memory() + get_cache_memory() + size;
|
||||
if (mem_required >= memory_limit_) {
|
||||
buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
|
||||
}
|
||||
|
||||
lock.unlock();
|
||||
buf = new CudaBuffer{nullptr, size};
|
||||
cudaError_t err = cudaMallocManaged(&buf->data, size);
|
||||
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
||||
}
|
||||
lock.lock();
|
||||
}
|
||||
active_memory_ += size;
|
||||
peak_memory_ = std::max(active_memory_, peak_memory_);
|
||||
|
||||
// Maintain the cache below the requested limit.
|
||||
if (get_cache_memory() > max_pool_size_) {
|
||||
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
|
||||
}
|
||||
|
||||
return Buffer{buf};
|
||||
}
|
||||
|
||||
void CudaAllocator::free(Buffer buffer) {
|
||||
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
|
||||
if (!buf) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::unique_lock lock(mutex_);
|
||||
active_memory_ -= buf->size;
|
||||
if (get_cache_memory() < max_pool_size_) {
|
||||
buffer_cache_.recycle_to_cache(buf);
|
||||
} else {
|
||||
lock.unlock();
|
||||
cuda_free(buf->data);
|
||||
delete buf;
|
||||
}
|
||||
}
|
||||
|
||||
size_t CudaAllocator::size(Buffer buffer) const {
|
||||
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
|
||||
if (!buf) {
|
||||
return 0;
|
||||
}
|
||||
return buf->size;
|
||||
}
|
||||
|
||||
void CudaAllocator::register_this_thread() {
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
allowed_threads_.insert(std::this_thread::get_id());
|
||||
}
|
||||
|
||||
void CudaAllocator::cuda_free(void* buf) {
|
||||
// If cuda_free() is called from a unregistered thread, reschedule the call to
|
||||
// worker.
|
||||
{
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
|
||||
if (!worker_) {
|
||||
worker_.reset(new Worker);
|
||||
}
|
||||
worker_->add_task([this, buf]() { this->cuda_free(buf); });
|
||||
worker_->end_batch();
|
||||
worker_->commit();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
cudaFree(buf);
|
||||
}
|
||||
|
||||
size_t CudaAllocator::get_active_memory() const {
|
||||
return active_memory_;
|
||||
}
|
||||
|
||||
size_t CudaAllocator::get_peak_memory() const {
|
||||
return peak_memory_;
|
||||
}
|
||||
|
||||
void CudaAllocator::reset_peak_memory() {
|
||||
std::lock_guard lock(mutex_);
|
||||
peak_memory_ = 0;
|
||||
}
|
||||
|
||||
size_t CudaAllocator::get_memory_limit() {
|
||||
return memory_limit_;
|
||||
}
|
||||
|
||||
size_t CudaAllocator::set_memory_limit(size_t limit) {
|
||||
std::lock_guard lock(mutex_);
|
||||
std::swap(limit, memory_limit_);
|
||||
return limit;
|
||||
}
|
||||
|
||||
size_t CudaAllocator::get_cache_memory() const {
|
||||
return buffer_cache_.cache_size();
|
||||
}
|
||||
|
||||
size_t CudaAllocator::set_cache_limit(size_t limit) {
|
||||
std::lock_guard lk(mutex_);
|
||||
std::swap(limit, max_pool_size_);
|
||||
return limit;
|
||||
}
|
||||
|
||||
void CudaAllocator::clear_cache() {
|
||||
std::lock_guard lk(mutex_);
|
||||
buffer_cache_.clear();
|
||||
}
|
||||
|
||||
CudaAllocator& allocator() {
|
||||
// By creating the |allocator_| on heap, the destructor of CudaAllocator
|
||||
// will not be called on exit and buffers in the cache will be leaked. This
|
||||
// can save some time at program exit.
|
||||
static CudaAllocator* allocator_ = new CudaAllocator;
|
||||
return *allocator_;
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
namespace allocator {
|
||||
|
||||
Allocator& allocator() {
|
||||
return cu::allocator();
|
||||
}
|
||||
|
||||
void* Buffer::raw_ptr() {
|
||||
if (!ptr_) {
|
||||
return nullptr;
|
||||
}
|
||||
return static_cast<cu::CudaBuffer*>(ptr_)->data;
|
||||
}
|
||||
|
||||
} // namespace allocator
|
||||
|
||||
size_t get_active_memory() {
|
||||
return cu::allocator().get_active_memory();
|
||||
}
|
||||
size_t get_peak_memory() {
|
||||
return cu::allocator().get_peak_memory();
|
||||
}
|
||||
void reset_peak_memory() {
|
||||
return cu::allocator().reset_peak_memory();
|
||||
}
|
||||
size_t set_memory_limit(size_t limit) {
|
||||
return cu::allocator().set_memory_limit(limit);
|
||||
}
|
||||
size_t get_memory_limit() {
|
||||
return cu::allocator().get_memory_limit();
|
||||
}
|
||||
size_t get_cache_memory() {
|
||||
return cu::allocator().get_cache_memory();
|
||||
}
|
||||
size_t set_cache_limit(size_t limit) {
|
||||
return cu::allocator().set_cache_limit(limit);
|
||||
}
|
||||
void clear_cache() {
|
||||
cu::allocator().clear_cache();
|
||||
}
|
||||
|
||||
// Not supported in CUDA.
|
||||
size_t set_wired_limit(size_t) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
67
mlx/backend/cuda/allocator.h
Normal file
67
mlx/backend/cuda/allocator.h
Normal file
@@ -0,0 +1,67 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/buffer_cache.h"
|
||||
|
||||
#include <mutex>
|
||||
#include <set>
|
||||
#include <thread>
|
||||
#include <utility>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
class Worker;
|
||||
|
||||
using allocator::Buffer;
|
||||
|
||||
// Stores cuda-managed unified memory.
|
||||
struct CudaBuffer {
|
||||
void* data;
|
||||
size_t size;
|
||||
};
|
||||
|
||||
class CudaAllocator : public allocator::Allocator {
|
||||
public:
|
||||
Buffer malloc(size_t size) override;
|
||||
void free(Buffer buffer) override;
|
||||
size_t size(Buffer buffer) const override;
|
||||
|
||||
// Register current thread as safe to free buffers.
|
||||
// In cuda freeing a buffer implicitly synchronizes stream, and for threads
|
||||
// that may be waited by gpu stream (for example cpu stream threads), freeing
|
||||
// buffers there would result in dead lock.
|
||||
void register_this_thread();
|
||||
|
||||
// Call cudaFree in the safe thread.
|
||||
void cuda_free(void* buf);
|
||||
|
||||
size_t get_active_memory() const;
|
||||
size_t get_peak_memory() const;
|
||||
void reset_peak_memory();
|
||||
size_t get_memory_limit();
|
||||
size_t set_memory_limit(size_t limit);
|
||||
size_t get_cache_memory() const;
|
||||
size_t set_cache_limit(size_t limit);
|
||||
void clear_cache();
|
||||
|
||||
private:
|
||||
CudaAllocator();
|
||||
friend CudaAllocator& allocator();
|
||||
|
||||
std::mutex worker_mutex_;
|
||||
std::unique_ptr<Worker> worker_;
|
||||
std::set<std::thread::id> allowed_threads_;
|
||||
|
||||
std::mutex mutex_;
|
||||
size_t memory_limit_;
|
||||
size_t max_pool_size_;
|
||||
BufferCache<CudaBuffer> buffer_cache_;
|
||||
size_t active_memory_{0};
|
||||
size_t peak_memory_{0};
|
||||
};
|
||||
|
||||
CudaAllocator& allocator();
|
||||
|
||||
} // namespace mlx::core::cu
|
305
mlx/backend/cuda/binary.cu
Normal file
305
mlx/backend/cuda/binary.cu
Normal file
@@ -0,0 +1,305 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/kernels/binary_ops.cuh"
|
||||
#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
out[index] = Op{}(a[0], b[0]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
out[index] = Op{}(a[0], b[index]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
out[index] = Op{}(a[index], b[0]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
out[index] = Op{}(a[index], b[index]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
|
||||
__global__ void binary_g_nd(
|
||||
const In* a,
|
||||
const In* b,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
|
||||
index, shape.data(), a_strides.data(), b_strides.data());
|
||||
out[index] = Op{}(a[a_idx], b[b_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void binary_g(
|
||||
const In* a,
|
||||
const In* b,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides a_strides,
|
||||
const __grid_constant__ Strides b_strides,
|
||||
int ndim) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx] = elem_to_loc_4d(
|
||||
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
|
||||
out[index] = Op{}(a[a_idx], b[b_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out>
|
||||
constexpr bool supports_binary_op() {
|
||||
if (std::is_same_v<Op, Add> || std::is_same_v<Op, Divide> ||
|
||||
std::is_same_v<Op, Maximum> || std::is_same_v<Op, Minimum> ||
|
||||
std::is_same_v<Op, Multiply> || std::is_same_v<Op, Subtract> ||
|
||||
std::is_same_v<Op, Power> || std::is_same_v<Op, Remainder>) {
|
||||
return std::is_same_v<In, Out>;
|
||||
}
|
||||
if (std::is_same_v<Op, Equal> || std::is_same_v<Op, Greater> ||
|
||||
std::is_same_v<Op, GreaterEqual> || std::is_same_v<Op, Less> ||
|
||||
std::is_same_v<Op, LessEqual> || std::is_same_v<Op, NotEqual>) {
|
||||
return std::is_same_v<Out, bool>;
|
||||
}
|
||||
if (std::is_same_v<Op, LogicalAnd> || std::is_same_v<Op, LogicalOr>) {
|
||||
return std::is_same_v<Out, bool> && std::is_same_v<In, bool>;
|
||||
}
|
||||
if (std::is_same_v<Op, NaNEqual>) {
|
||||
return std::is_same_v<Out, bool> &&
|
||||
(is_floating_v<In> || std::is_same_v<In, complex64_t>);
|
||||
}
|
||||
if (std::is_same_v<Op, LogAddExp> || std::is_same_v<Op, ArcTan2>) {
|
||||
return std::is_same_v<In, Out> && is_floating_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, BitwiseAnd> || std::is_same_v<Op, BitwiseOr> ||
|
||||
std::is_same_v<Op, BitwiseXor>) {
|
||||
return std::is_same_v<In, Out> && std::is_integral_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, LeftShift> || std::is_same_v<Op, RightShift>) {
|
||||
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
|
||||
!std::is_same_v<In, bool>;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
std::string_view op,
|
||||
const Stream& s) {
|
||||
assert(inputs.size() > 1);
|
||||
const auto& a = inputs[0];
|
||||
const auto& b = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
|
||||
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
|
||||
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||
using InType = cuda_type_t<CTYPE_IN>;
|
||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
if (bopt == BinaryOpType::General) {
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
bool large = a.data_size() > UINT32_MAX ||
|
||||
b.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
auto kernel =
|
||||
&cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.data_size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(a_strides),
|
||||
const_param<NDIM>(b_strides));
|
||||
});
|
||||
} else {
|
||||
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.data_size(),
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
const_param(b_strides),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
|
||||
if (bopt == BinaryOpType::ScalarVector) {
|
||||
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
|
||||
} else if (bopt == BinaryOpType::VectorScalar) {
|
||||
kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
|
||||
} else if (bopt == BinaryOpType::VectorVector) {
|
||||
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
|
||||
}
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, LARGE);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.data_size());
|
||||
});
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Can not do binary op {} on inputs of {} with result of {}.",
|
||||
op,
|
||||
dtype_to_string(a.dtype()),
|
||||
dtype_to_string(out.dtype())));
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
std::string_view op,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, outputs[0], bopt);
|
||||
set_binary_op_output_data(a, b, outputs[1], bopt);
|
||||
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
std::string_view op,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
std::vector<array> outputs{out};
|
||||
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
|
||||
}
|
||||
|
||||
#define BINARY_GPU(func) \
|
||||
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
||||
nvtx3::scoped_range r(#func "::eval_gpu"); \
|
||||
auto& s = out.primitive().stream(); \
|
||||
binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
|
||||
}
|
||||
|
||||
#define BINARY_GPU_MULTI(func) \
|
||||
void func::eval_gpu( \
|
||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||
nvtx3::scoped_range r(#func "::eval_gpu"); \
|
||||
auto& s = outputs[0].primitive().stream(); \
|
||||
binary_op_gpu<cu::func>(inputs, outputs, get_primitive_string(this), s); \
|
||||
}
|
||||
|
||||
BINARY_GPU(Add)
|
||||
BINARY_GPU(ArcTan2)
|
||||
BINARY_GPU(Divide)
|
||||
BINARY_GPU(Remainder)
|
||||
BINARY_GPU(Equal)
|
||||
BINARY_GPU(Greater)
|
||||
BINARY_GPU(GreaterEqual)
|
||||
BINARY_GPU(Less)
|
||||
BINARY_GPU(LessEqual)
|
||||
BINARY_GPU(LogicalAnd)
|
||||
BINARY_GPU(LogicalOr)
|
||||
BINARY_GPU(LogAddExp)
|
||||
BINARY_GPU(Maximum)
|
||||
BINARY_GPU(Minimum)
|
||||
BINARY_GPU(Multiply)
|
||||
BINARY_GPU(NotEqual)
|
||||
BINARY_GPU(Power)
|
||||
BINARY_GPU(Subtract)
|
||||
|
||||
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
auto op = get_primitive_string(this);
|
||||
switch (op_) {
|
||||
case BitwiseBinary::And:
|
||||
binary_op_gpu<cu::BitwiseAnd>(inputs, out, op, s);
|
||||
break;
|
||||
case BitwiseBinary::Or:
|
||||
binary_op_gpu<cu::BitwiseOr>(inputs, out, op, s);
|
||||
break;
|
||||
case BitwiseBinary::Xor:
|
||||
binary_op_gpu<cu::BitwiseXor>(inputs, out, op, s);
|
||||
break;
|
||||
case BitwiseBinary::LeftShift:
|
||||
binary_op_gpu<cu::LeftShift>(inputs, out, op, s);
|
||||
break;
|
||||
case BitwiseBinary::RightShift:
|
||||
binary_op_gpu<cu::RightShift>(inputs, out, op, s);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
89
mlx/backend/cuda/copy.cu
Normal file
89
mlx/backend/cuda/copy.cu
Normal file
@@ -0,0 +1,89 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cuda/copy/copy.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in_,
|
||||
array& out,
|
||||
const Shape& shape,
|
||||
const Strides& strides_in,
|
||||
const Strides& strides_out,
|
||||
int64_t offset_in,
|
||||
int64_t offset_out,
|
||||
CopyType ctype,
|
||||
const Stream& s,
|
||||
const std::optional<array>& dynamic_offset_in,
|
||||
const std::optional<array>& dynamic_offset_out) {
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
const array& in = in_.data_shared_ptr() ? in_ : out;
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
|
||||
copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
|
||||
return;
|
||||
}
|
||||
|
||||
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
||||
auto [shape_collapsed, strides_vec] = collapse_contiguous_dims(
|
||||
shape, std::vector{strides_in, strides_out}, INT32_MAX);
|
||||
if (ctype == CopyType::General) {
|
||||
copy_general_input(
|
||||
encoder,
|
||||
ctype,
|
||||
in,
|
||||
out,
|
||||
offset_in,
|
||||
offset_out,
|
||||
shape_collapsed,
|
||||
strides_vec[0]);
|
||||
} else {
|
||||
if (dynamic_offset_in || dynamic_offset_out) {
|
||||
copy_general_dynamic(
|
||||
encoder,
|
||||
ctype,
|
||||
in,
|
||||
out,
|
||||
offset_in,
|
||||
offset_out,
|
||||
shape_collapsed,
|
||||
strides_vec[0],
|
||||
strides_vec[1],
|
||||
dynamic_offset_in ? *dynamic_offset_in : array(0, int64),
|
||||
dynamic_offset_out ? *dynamic_offset_out : array(0, int64));
|
||||
} else {
|
||||
copy_general(
|
||||
encoder,
|
||||
ctype,
|
||||
in,
|
||||
out,
|
||||
offset_in,
|
||||
offset_out,
|
||||
shape_collapsed,
|
||||
strides_vec[0],
|
||||
strides_vec[1]);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
void fill_gpu(const array& in, array& out, const Stream& s) {
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
71
mlx/backend/cuda/copy/copy.cuh
Normal file
71
mlx/backend/cuda/copy/copy.cuh
Normal file
@@ -0,0 +1,71 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/kernels/cast_op.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
|
||||
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
|
||||
using InType = cuda_type_t<CTYPE_IN>; \
|
||||
using OutType = cuda_type_t<CTYPE_OUT>; \
|
||||
if constexpr (cu::CastOp<InType, OutType>::is_castable) { \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
throw std::runtime_error(fmt::format( \
|
||||
"Can not copy data from dtype {} to {}.", \
|
||||
dtype_to_string(out.dtype()), \
|
||||
dtype_to_string(in.dtype()))); \
|
||||
} \
|
||||
}); \
|
||||
})
|
||||
|
||||
void copy_contiguous(
|
||||
cu::CommandEncoder& encoder,
|
||||
CopyType ctype,
|
||||
const array& in,
|
||||
array& out,
|
||||
int64_t offset_in,
|
||||
int64_t offset_out);
|
||||
|
||||
void copy_general(
|
||||
cu::CommandEncoder& encoder,
|
||||
CopyType ctype,
|
||||
const array& in,
|
||||
array& out,
|
||||
int64_t offset_in,
|
||||
int64_t offset_out,
|
||||
const Shape& shape,
|
||||
const Strides& strides_in,
|
||||
const Strides& strides_out);
|
||||
|
||||
void copy_general_dynamic(
|
||||
cu::CommandEncoder& encoder,
|
||||
CopyType ctype,
|
||||
const array& in,
|
||||
array& out,
|
||||
int64_t offset_in,
|
||||
int64_t offset_out,
|
||||
const Shape& shape,
|
||||
const Strides& strides_in,
|
||||
const Strides& strides_out,
|
||||
const array& dynamic_offset_in,
|
||||
const array& dynamic_offset_out);
|
||||
|
||||
void copy_general_input(
|
||||
cu::CommandEncoder& encoder,
|
||||
CopyType ctype,
|
||||
const array& in,
|
||||
array& out,
|
||||
int64_t offset_in,
|
||||
int64_t offset_out,
|
||||
const Shape& shape,
|
||||
const Strides& strides_in);
|
||||
|
||||
} // namespace mlx::core
|
56
mlx/backend/cuda/copy/copy_contiguous.cu
Normal file
56
mlx/backend/cuda/copy/copy_contiguous.cu
Normal file
@@ -0,0 +1,56 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/copy/copy.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename In, typename Out, typename IdxT>
|
||||
__global__ void copy_s(const In* in, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
out[index] = CastOp<In, Out>{}(in[0]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename In, typename Out, typename IdxT>
|
||||
__global__ void copy_v(const In* in, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
out[index] = CastOp<In, Out>{}(in[index]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void copy_contiguous(
|
||||
cu::CommandEncoder& encoder,
|
||||
CopyType ctype,
|
||||
const array& in,
|
||||
array& out,
|
||||
int64_t in_offset,
|
||||
int64_t out_offset) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
||||
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
auto kernel = cu::copy_s<InType, OutType, IdxT>;
|
||||
if (ctype == CopyType::Vector) {
|
||||
kernel = cu::copy_v<InType, OutType, IdxT>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in.data<InType>() + in_offset,
|
||||
out.data<OutType>() + out_offset,
|
||||
out.data_size());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
95
mlx/backend/cuda/copy/copy_general.cu
Normal file
95
mlx/backend/cuda/copy/copy_general.cu
Normal file
@@ -0,0 +1,95 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/copy/copy.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename In, typename Out, typename IdxT, int NDIM>
|
||||
__global__ void copy_gg_nd(
|
||||
const In* in,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
|
||||
index, shape.data(), strides_in.data(), strides_out.data());
|
||||
out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename In, typename Out, typename IdxT>
|
||||
__global__ void copy_gg(
|
||||
const In* in,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides strides_in,
|
||||
const __grid_constant__ Strides strides_out,
|
||||
int ndim) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [idx_in, idx_out] = elem_to_loc_4d(
|
||||
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
|
||||
out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void copy_general(
|
||||
cu::CommandEncoder& encoder,
|
||||
CopyType ctype,
|
||||
const array& in,
|
||||
array& out,
|
||||
int64_t offset_in,
|
||||
int64_t offset_out,
|
||||
const Shape& shape,
|
||||
const Strides& strides_in,
|
||||
const Strides& strides_out) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.data_size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(strides_in),
|
||||
const_param<NDIM>(strides_out));
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.data_size(),
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
const_param(strides_out),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
105
mlx/backend/cuda/copy/copy_general_dynamic.cu
Normal file
105
mlx/backend/cuda/copy/copy_general_dynamic.cu
Normal file
@@ -0,0 +1,105 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/copy/copy.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename In, typename Out, typename IdxT, int NDIM>
|
||||
__global__ void copy_gg_dynamic_nd(
|
||||
const In* in,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out,
|
||||
const int64_t* offset_in,
|
||||
const int64_t* offset_out) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
|
||||
index, shape.data(), strides_in.data(), strides_out.data());
|
||||
out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename In, typename Out, typename IdxT>
|
||||
__global__ void copy_gg_dynamic(
|
||||
const In* in,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides strides_in,
|
||||
const __grid_constant__ Strides strides_out,
|
||||
int ndim,
|
||||
const int64_t* offset_in,
|
||||
const int64_t* offset_out) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [idx_in, idx_out] = elem_to_loc_4d(
|
||||
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
|
||||
out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void copy_general_dynamic(
|
||||
cu::CommandEncoder& encoder,
|
||||
CopyType ctype,
|
||||
const array& in,
|
||||
array& out,
|
||||
int64_t offset_in,
|
||||
int64_t offset_out,
|
||||
const Shape& shape,
|
||||
const Strides& strides_in,
|
||||
const Strides& strides_out,
|
||||
const array& dynamic_offset_in,
|
||||
const array& dynamic_offset_out) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
auto kernel = cu::copy_gg_dynamic_nd<InType, OutType, IdxT, NDIM>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.data_size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(strides_in),
|
||||
const_param<NDIM>(strides_out),
|
||||
dynamic_offset_in.data<int64_t>(),
|
||||
dynamic_offset_out.data<int64_t>());
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.data_size(),
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
const_param(strides_out),
|
||||
ndim,
|
||||
dynamic_offset_in.data<int64_t>(),
|
||||
dynamic_offset_out.data<int64_t>());
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
88
mlx/backend/cuda/copy/copy_general_input.cu
Normal file
88
mlx/backend/cuda/copy/copy_general_input.cu
Normal file
@@ -0,0 +1,88 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/copy/copy.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename In, typename Out, typename IdxT, int NDIM>
|
||||
__global__ void copy_g_nd(
|
||||
const In* in,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
IdxT idx_in = elem_to_loc_nd<NDIM>(index, shape.data(), strides_in.data());
|
||||
out[index] = CastOp<In, Out>{}(in[idx_in]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename In, typename Out, typename IdxT>
|
||||
__global__ void copy_g(
|
||||
const In* in,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides strides_in,
|
||||
int ndim) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
IdxT idx_in = elem_to_loc_4d(index, shape.data(), strides_in.data(), ndim);
|
||||
out[index] = CastOp<In, Out>{}(in[idx_in]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void copy_general_input(
|
||||
cu::CommandEncoder& encoder,
|
||||
CopyType ctype,
|
||||
const array& in,
|
||||
array& out,
|
||||
int64_t offset_in,
|
||||
int64_t offset_out,
|
||||
const Shape& shape,
|
||||
const Strides& strides_in) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
auto kernel = cu::copy_g_nd<InType, OutType, IdxT, NDIM>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.data_size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(strides_in));
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto kernel = cu::copy_g<InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.data_size(),
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
11
mlx/backend/cuda/cuda.cpp
Normal file
11
mlx/backend/cuda/cuda.cpp
Normal file
@@ -0,0 +1,11 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/cuda.h"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
10
mlx/backend/cuda/cuda.h
Normal file
10
mlx/backend/cuda/cuda.h
Normal file
@@ -0,0 +1,10 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
/* Check if the CUDA backend is available. */
|
||||
bool is_available();
|
||||
|
||||
} // namespace mlx::core::cu
|
129
mlx/backend/cuda/device.cpp
Normal file
129
mlx/backend/cuda/device.cpp
Normal file
@@ -0,0 +1,129 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
DeviceStream::DeviceStream(Device& device) : device_(device), stream_(device) {}
|
||||
|
||||
void DeviceStream::synchronize() {
|
||||
cudaStreamSynchronize(stream_);
|
||||
}
|
||||
|
||||
cudaStream_t DeviceStream::schedule_cuda_stream() {
|
||||
// TODO: Return a stream that maximizes parallelism.
|
||||
return stream_;
|
||||
}
|
||||
|
||||
cudaStream_t DeviceStream::last_cuda_stream() {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
CommandEncoder& DeviceStream::get_encoder() {
|
||||
if (!encoder_) {
|
||||
encoder_ = std::make_unique<CommandEncoder>(*this);
|
||||
}
|
||||
return *encoder_;
|
||||
}
|
||||
|
||||
Device::Device(int device) : device_(device) {
|
||||
CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
|
||||
&compute_capability_major_, cudaDevAttrComputeCapabilityMajor, device_));
|
||||
CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
|
||||
&compute_capability_minor_, cudaDevAttrComputeCapabilityMinor, device_));
|
||||
// Validate the requirements of device.
|
||||
int attr = 0;
|
||||
CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
|
||||
&attr, cudaDevAttrConcurrentManagedAccess, device_));
|
||||
if (attr != 1) {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Device {} does not support synchronization in managed memory.",
|
||||
device_));
|
||||
}
|
||||
// The cublasLt handle is used by matmul.
|
||||
make_current();
|
||||
cublasLtCreate(<_);
|
||||
}
|
||||
|
||||
Device::~Device() {
|
||||
cublasLtDestroy(lt_);
|
||||
}
|
||||
|
||||
void Device::make_current() {
|
||||
// We need to set/get current CUDA device very frequently, cache it to reduce
|
||||
// actual calls of CUDA APIs. This function assumes single-thread in host.
|
||||
static int current = 0;
|
||||
if (current != device_) {
|
||||
CHECK_CUDA_ERROR(cudaSetDevice(device_));
|
||||
current = device_;
|
||||
}
|
||||
}
|
||||
|
||||
DeviceStream& Device::get_stream(Stream s) {
|
||||
auto it = streams_.find(s.index);
|
||||
if (it == streams_.end()) {
|
||||
it = streams_.try_emplace(s.index, *this).first;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
CommandEncoder::CommandEncoder(DeviceStream& s)
|
||||
: device_(s.device()), stream_(s) {}
|
||||
|
||||
void CommandEncoder::add_completed_handler(std::function<void()> task) {
|
||||
worker_.add_task(std::move(task));
|
||||
}
|
||||
|
||||
void CommandEncoder::end_encoding() {
|
||||
if (!temporaries_.empty()) {
|
||||
add_completed_handler([temporaries = std::move(temporaries_)]() {});
|
||||
}
|
||||
|
||||
// There is no kernel running, run completion handlers immediately.
|
||||
if (!has_gpu_work_) {
|
||||
worker_.consume_in_this_thread();
|
||||
return;
|
||||
}
|
||||
has_gpu_work_ = false;
|
||||
|
||||
// Put completion handlers in a batch.
|
||||
worker_.end_batch();
|
||||
|
||||
// Signaling kernel completion is expensive, delay until enough batches.
|
||||
// TODO: This number is arbitrarily picked, profile for a better stragety.
|
||||
if (worker_.uncommited_batches() > 8) {
|
||||
commit();
|
||||
}
|
||||
}
|
||||
|
||||
void CommandEncoder::commit() {
|
||||
worker_.commit(stream_.last_cuda_stream());
|
||||
}
|
||||
|
||||
Device& device(mlx::core::Device device) {
|
||||
static std::unordered_map<int, Device> devices;
|
||||
auto it = devices.find(device.index);
|
||||
if (it == devices.end()) {
|
||||
it = devices.try_emplace(device.index, device.index).first;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
DeviceStream& get_stream(Stream s) {
|
||||
return device(s.device).get_stream(s);
|
||||
}
|
||||
|
||||
CommandEncoder& get_command_encoder(Stream s) {
|
||||
return get_stream(s).get_encoder();
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
} // namespace mlx::core
|
145
mlx/backend/cuda/device.h
Normal file
145
mlx/backend/cuda/device.h
Normal file
@@ -0,0 +1,145 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
#include <cublasLt.h>
|
||||
#include <thrust/execution_policy.h>
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
class Device;
|
||||
class CommandEncoder;
|
||||
|
||||
class DeviceStream {
|
||||
public:
|
||||
explicit DeviceStream(Device& device);
|
||||
|
||||
DeviceStream(const DeviceStream&) = delete;
|
||||
DeviceStream& operator=(const DeviceStream&) = delete;
|
||||
|
||||
// Wait until kernels in the stream complete.
|
||||
void synchronize();
|
||||
|
||||
// Return a cuda stream for launching kernels.
|
||||
cudaStream_t schedule_cuda_stream();
|
||||
|
||||
// Return the last cuda stream used.
|
||||
cudaStream_t last_cuda_stream();
|
||||
|
||||
CommandEncoder& get_encoder();
|
||||
|
||||
Device& device() {
|
||||
return device_;
|
||||
}
|
||||
|
||||
private:
|
||||
Device& device_;
|
||||
CudaStream stream_;
|
||||
std::unique_ptr<CommandEncoder> encoder_;
|
||||
};
|
||||
|
||||
class Device {
|
||||
public:
|
||||
explicit Device(int device);
|
||||
~Device();
|
||||
|
||||
Device(const Device&) = delete;
|
||||
Device& operator=(const Device&) = delete;
|
||||
|
||||
// Make this device the current cuda device, required by some cuda calls.
|
||||
void make_current();
|
||||
|
||||
DeviceStream& get_stream(Stream s);
|
||||
|
||||
int cuda_device() const {
|
||||
return device_;
|
||||
}
|
||||
int compute_capability_major() const {
|
||||
return compute_capability_major_;
|
||||
}
|
||||
int compute_capability_minor() const {
|
||||
return compute_capability_minor_;
|
||||
}
|
||||
cublasLtHandle_t lt_handle() const {
|
||||
return lt_;
|
||||
}
|
||||
|
||||
private:
|
||||
int device_;
|
||||
int compute_capability_major_;
|
||||
int compute_capability_minor_;
|
||||
cublasLtHandle_t lt_;
|
||||
std::unordered_map<int, DeviceStream> streams_;
|
||||
};
|
||||
|
||||
class CommandEncoder {
|
||||
public:
|
||||
explicit CommandEncoder(DeviceStream& stream);
|
||||
|
||||
CommandEncoder(const CommandEncoder&) = delete;
|
||||
CommandEncoder& operator=(const CommandEncoder&) = delete;
|
||||
|
||||
void set_input_array(const array& arr) {}
|
||||
void set_output_array(const array& arr) {}
|
||||
|
||||
void add_temporary(const array& arr) {
|
||||
temporaries_.push_back(arr.data_shared_ptr());
|
||||
}
|
||||
|
||||
void add_completed_handler(std::function<void()> task);
|
||||
void end_encoding();
|
||||
void commit();
|
||||
|
||||
// Schedule a cuda stream for |fun| to launch kernels, and check error
|
||||
// afterwards.
|
||||
template <typename F>
|
||||
void launch_kernel(F&& fun) {
|
||||
launch_kernel(stream_.schedule_cuda_stream(), std::forward<F>(fun));
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void launch_kernel(cudaStream_t stream, F&& fun) {
|
||||
device_.make_current();
|
||||
fun(stream);
|
||||
check_cuda_error("kernel launch", cudaGetLastError());
|
||||
has_gpu_work_ = true;
|
||||
}
|
||||
|
||||
Device& device() {
|
||||
return device_;
|
||||
}
|
||||
|
||||
DeviceStream& stream() {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
bool has_gpu_work() const {
|
||||
return has_gpu_work_;
|
||||
}
|
||||
|
||||
private:
|
||||
Device& device_;
|
||||
DeviceStream& stream_;
|
||||
Worker worker_;
|
||||
bool has_gpu_work_{false};
|
||||
std::vector<std::shared_ptr<array::Data>> temporaries_;
|
||||
};
|
||||
|
||||
Device& device(mlx::core::Device device);
|
||||
DeviceStream& get_stream(Stream s);
|
||||
CommandEncoder& get_command_encoder(Stream s);
|
||||
|
||||
// Return an execution policy that does not sync for result.
|
||||
// Note that not all thrust APIs support async policy, confirm before using.
|
||||
inline auto thrust_policy(cudaStream_t stream) {
|
||||
// TODO: Connect thrust's custom allocator with mlx's allocator.
|
||||
return thrust::cuda::par_nosync.on(stream);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
68
mlx/backend/cuda/eval.cpp
Normal file
68
mlx/backend/cuda/eval.cpp
Normal file
@@ -0,0 +1,68 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/gpu/eval.h"
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/gpu/available.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core::gpu {
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
void new_stream(Stream s) {
|
||||
// Force initalization of cuda, so cuda runtime get destroyed at last.
|
||||
cudaFree(nullptr);
|
||||
// Ensure the static stream objects get created.
|
||||
cu::get_command_encoder(s);
|
||||
// The main thread is safe to free buffers.
|
||||
cu::allocator().register_this_thread();
|
||||
}
|
||||
|
||||
void eval(array& arr) {
|
||||
nvtx3::scoped_range r("gpu::eval");
|
||||
auto outputs = arr.outputs();
|
||||
{
|
||||
// If the array is a tracer hold a reference
|
||||
// to its inputs so they don't get donated
|
||||
std::vector<array> inputs;
|
||||
if (arr.is_tracer()) {
|
||||
inputs = arr.inputs();
|
||||
}
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
|
||||
if (encoder.has_gpu_work()) {
|
||||
// Keep used buffers alive until kernel finishes running.
|
||||
std::unordered_set<std::shared_ptr<array::Data>> buffers;
|
||||
for (auto& in : arr.inputs()) {
|
||||
buffers.insert(in.data_shared_ptr());
|
||||
}
|
||||
for (auto& s : arr.siblings()) {
|
||||
buffers.insert(s.data_shared_ptr());
|
||||
}
|
||||
// Remove the output if it was donated to by an input.
|
||||
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
|
||||
buffers.erase(it);
|
||||
}
|
||||
encoder.add_completed_handler([buffers = std::move(buffers)]() {});
|
||||
}
|
||||
encoder.end_encoding();
|
||||
}
|
||||
|
||||
void finalize(Stream s) {
|
||||
nvtx3::scoped_range r("gpu::finalize");
|
||||
cu::get_command_encoder(s).commit();
|
||||
}
|
||||
|
||||
void synchronize(Stream s) {
|
||||
nvtx3::scoped_range r("gpu::synchronize");
|
||||
cu::get_stream(s).synchronize();
|
||||
}
|
||||
|
||||
} // namespace mlx::core::gpu
|
269
mlx/backend/cuda/event.cu
Normal file
269
mlx/backend/cuda/event.cu
Normal file
@@ -0,0 +1,269 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/event.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/event.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CudaEvent implementations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Cuda event managed with RAII.
|
||||
class CudaEventHandle {
|
||||
public:
|
||||
CudaEventHandle() {
|
||||
CHECK_CUDA_ERROR(cudaEventCreateWithFlags(
|
||||
&event_, cudaEventDisableTiming | cudaEventBlockingSync));
|
||||
}
|
||||
|
||||
~CudaEventHandle() {
|
||||
CHECK_CUDA_ERROR(cudaEventDestroy(event_));
|
||||
}
|
||||
|
||||
CudaEventHandle(const CudaEventHandle&) = delete;
|
||||
CudaEventHandle& operator=(const CudaEventHandle&) = delete;
|
||||
|
||||
operator cudaEvent_t() const {
|
||||
return event_;
|
||||
}
|
||||
|
||||
private:
|
||||
cudaEvent_t event_;
|
||||
};
|
||||
|
||||
CudaEvent::CudaEvent() : event_(std::make_shared<CudaEventHandle>()) {}
|
||||
|
||||
void CudaEvent::wait() {
|
||||
nvtx3::scoped_range r("cu::CudaEvent::wait");
|
||||
if (!recorded_) {
|
||||
throw std::runtime_error("Should not wait on a CudaEvent before record.");
|
||||
}
|
||||
cudaEventSynchronize(*event_);
|
||||
}
|
||||
|
||||
void CudaEvent::wait(cudaStream_t stream) {
|
||||
if (!recorded_) {
|
||||
throw std::runtime_error("Should not wait on a CudaEvent before record.");
|
||||
}
|
||||
cudaStreamWaitEvent(stream, *event_);
|
||||
}
|
||||
|
||||
void CudaEvent::wait(Stream s) {
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
scheduler::enqueue(s, [*this]() mutable { wait(); });
|
||||
} else {
|
||||
wait(cu::get_stream(s).last_cuda_stream());
|
||||
}
|
||||
}
|
||||
|
||||
void CudaEvent::record(cudaStream_t stream) {
|
||||
cudaEventRecord(*event_, stream);
|
||||
recorded_ = true;
|
||||
}
|
||||
|
||||
void CudaEvent::record(Stream s) {
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
throw std::runtime_error("CudaEvent can not wait on cpu stream.");
|
||||
} else {
|
||||
record(cu::get_stream(s).last_cuda_stream());
|
||||
}
|
||||
}
|
||||
|
||||
bool CudaEvent::completed() const {
|
||||
return cudaEventQuery(*event_) == cudaSuccess;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// SharedEvent implementations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace {
|
||||
|
||||
__host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) {
|
||||
uint64_t current;
|
||||
while ((current = ac->load()) < value) {
|
||||
ac->wait(current);
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ void event_signal(SharedEvent::Atomic* ac, uint64_t value) {
|
||||
ac->store(value);
|
||||
ac->notify_all();
|
||||
}
|
||||
|
||||
__global__ void event_wait_kernel(SharedEvent::Atomic* ac, uint64_t value) {
|
||||
event_wait(ac, value);
|
||||
}
|
||||
|
||||
__global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
|
||||
event_signal(ac, value);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
SharedEvent::SharedEvent() {
|
||||
// Allocate cuda::atomic on managed memory.
|
||||
Atomic* ac;
|
||||
CHECK_CUDA_ERROR(cudaMallocManaged(&ac, sizeof(Atomic)));
|
||||
new (ac) Atomic(0);
|
||||
ac_ = std::shared_ptr<Atomic>(ac, [](Atomic* ptr) {
|
||||
ptr->~Atomic();
|
||||
allocator().cuda_free(ptr);
|
||||
});
|
||||
}
|
||||
|
||||
void SharedEvent::wait(uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::wait");
|
||||
event_wait(ac_.get(), value);
|
||||
}
|
||||
|
||||
void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
|
||||
event_wait_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
|
||||
}
|
||||
|
||||
void SharedEvent::wait(Stream s, uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::wait(s)");
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
|
||||
} else {
|
||||
auto& encoder = get_command_encoder(s);
|
||||
encoder.launch_kernel(
|
||||
encoder.stream().last_cuda_stream(),
|
||||
[this, value](cudaStream_t stream) { wait(stream, value); });
|
||||
encoder.add_completed_handler([ac = ac_]() {});
|
||||
encoder.end_encoding();
|
||||
}
|
||||
}
|
||||
|
||||
void SharedEvent::signal(uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::signal");
|
||||
event_signal(ac_.get(), value);
|
||||
}
|
||||
|
||||
void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
|
||||
event_signal_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
|
||||
}
|
||||
|
||||
void SharedEvent::signal(Stream s, uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
// Signal through a GPU stream so the atomic is updated in GPU - updating
|
||||
// the atomic in CPU sometimes does not get GPU notified.
|
||||
static CudaStream stream(device(mlx::core::Device::gpu));
|
||||
scheduler::enqueue(s, [*this, value]() mutable { signal(stream, value); });
|
||||
} else {
|
||||
auto& encoder = get_command_encoder(s);
|
||||
encoder.launch_kernel(
|
||||
encoder.stream().last_cuda_stream(),
|
||||
[this, value](cudaStream_t stream) { signal(stream, value); });
|
||||
encoder.add_completed_handler([ac = ac_]() {});
|
||||
encoder.end_encoding();
|
||||
}
|
||||
}
|
||||
|
||||
bool SharedEvent::is_signaled(uint64_t value) const {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
|
||||
return ac_->load() >= value;
|
||||
}
|
||||
|
||||
uint64_t SharedEvent::value() const {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::value");
|
||||
return ac_->load();
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Event implementations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace {
|
||||
|
||||
struct EventImpl {
|
||||
// CudaEvent is preferred when possible because it is fast, however we have
|
||||
// to fallback to SharedEvent in following cases:
|
||||
// 1. the event is used to wait/signal a cpu stream;
|
||||
// 2. signal value other than 1 has been specified.
|
||||
std::unique_ptr<cu::CudaEvent> cuda;
|
||||
std::unique_ptr<cu::SharedEvent> shared;
|
||||
|
||||
bool is_created() const {
|
||||
return cuda || shared;
|
||||
}
|
||||
|
||||
void ensure_created(Stream s, uint64_t signal_value) {
|
||||
if (is_created()) {
|
||||
return;
|
||||
}
|
||||
if (s.device == mlx::core::Device::cpu || signal_value > 1) {
|
||||
nvtx3::mark("Using slow SharedEvent");
|
||||
shared = std::make_unique<cu::SharedEvent>();
|
||||
} else {
|
||||
cuda = std::make_unique<cu::CudaEvent>();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
Event::Event(Stream s) : stream_(s) {
|
||||
event_ = std::shared_ptr<void>(
|
||||
new EventImpl(), [](void* ptr) { delete static_cast<EventImpl*>(ptr); });
|
||||
}
|
||||
|
||||
void Event::wait() {
|
||||
auto* event = static_cast<EventImpl*>(event_.get());
|
||||
assert(event->is_created());
|
||||
if (event->cuda) {
|
||||
assert(value() == 1);
|
||||
event->cuda->wait();
|
||||
} else {
|
||||
event->shared->wait(value());
|
||||
}
|
||||
}
|
||||
|
||||
void Event::wait(Stream s) {
|
||||
auto* event = static_cast<EventImpl*>(event_.get());
|
||||
assert(event->is_created());
|
||||
if (event->cuda) {
|
||||
assert(value() == 1);
|
||||
event->cuda->wait(s);
|
||||
} else {
|
||||
event->shared->wait(s, value());
|
||||
}
|
||||
}
|
||||
|
||||
void Event::signal(Stream s) {
|
||||
auto* event = static_cast<EventImpl*>(event_.get());
|
||||
event->ensure_created(s, value());
|
||||
if (event->cuda) {
|
||||
assert(value() == 1);
|
||||
event->cuda->record(s);
|
||||
} else {
|
||||
event->shared->signal(s, value());
|
||||
}
|
||||
}
|
||||
|
||||
bool Event::is_signaled() const {
|
||||
auto* event = static_cast<EventImpl*>(event_.get());
|
||||
if (!event->is_created()) {
|
||||
return false;
|
||||
}
|
||||
if (event->cuda) {
|
||||
assert(value() == 1);
|
||||
return event->cuda->recorded() && event->cuda->completed();
|
||||
} else {
|
||||
return event->shared->is_signaled(value());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
66
mlx/backend/cuda/event.h
Normal file
66
mlx/backend/cuda/event.h
Normal file
@@ -0,0 +1,66 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/stream.h"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda/atomic>
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
class CudaEventHandle;
|
||||
|
||||
// Wrapper of native cuda event. It can synchronize between GPU streams, or wait
|
||||
// on GPU stream in CPU stream, but can not wait on CPU stream.
|
||||
class CudaEvent {
|
||||
public:
|
||||
CudaEvent();
|
||||
|
||||
void wait();
|
||||
void wait(cudaStream_t stream);
|
||||
void wait(Stream s);
|
||||
void record(cudaStream_t stream);
|
||||
void record(Stream s);
|
||||
|
||||
// Return whether the recorded kernels have completed. Note that this method
|
||||
// returns true if record() has not been called.
|
||||
bool completed() const;
|
||||
|
||||
bool recorded() const {
|
||||
return recorded_;
|
||||
}
|
||||
|
||||
private:
|
||||
bool recorded_{false};
|
||||
std::shared_ptr<CudaEventHandle> event_;
|
||||
};
|
||||
|
||||
// Event that can synchronize between CPU and GPU. It is much slower than
|
||||
// CudaEvent so the latter should always be preferred when possible.
|
||||
class SharedEvent {
|
||||
public:
|
||||
using Atomic = cuda::atomic<uint64_t>;
|
||||
|
||||
SharedEvent();
|
||||
|
||||
void wait(uint64_t value);
|
||||
void wait(cudaStream_t stream, uint64_t value);
|
||||
void wait(Stream s, uint64_t value);
|
||||
void signal(uint64_t value);
|
||||
void signal(cudaStream_t stream, uint64_t value);
|
||||
void signal(Stream s, uint64_t value);
|
||||
bool is_signaled(uint64_t value) const;
|
||||
uint64_t value() const;
|
||||
|
||||
const std::shared_ptr<Atomic>& atomic() const {
|
||||
return ac_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<Atomic> ac_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
29
mlx/backend/cuda/fence.cpp
Normal file
29
mlx/backend/cuda/fence.cpp
Normal file
@@ -0,0 +1,29 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/fence.h"
|
||||
#include "mlx/backend/cuda/event.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
struct FenceImpl {
|
||||
uint32_t count;
|
||||
cu::SharedEvent event;
|
||||
};
|
||||
|
||||
Fence::Fence(Stream s) {
|
||||
fence_ = std::shared_ptr<void>(
|
||||
new FenceImpl{0}, [](void* ptr) { delete static_cast<FenceImpl*>(ptr); });
|
||||
}
|
||||
|
||||
void Fence::wait(Stream s, const array&) {
|
||||
auto* fence = static_cast<FenceImpl*>(fence_.get());
|
||||
fence->event.wait(fence->count);
|
||||
}
|
||||
|
||||
void Fence::update(Stream s, const array&) {
|
||||
auto* fence = static_cast<FenceImpl*>(fence_.get());
|
||||
fence->count++;
|
||||
fence->event.signal(s, fence->count);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
121
mlx/backend/cuda/iterators/general_iterator.cuh
Normal file
121
mlx/backend/cuda/iterators/general_iterator.cuh
Normal file
@@ -0,0 +1,121 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <thrust/iterator/iterator_adaptor.h>
|
||||
#include <cuda/std/utility>
|
||||
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// Iterating non-contiguous array.
|
||||
template <typename Iterator, typename IdxT = int64_t>
|
||||
class general_iterator
|
||||
: public thrust::
|
||||
iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator> {
|
||||
public:
|
||||
using super_t =
|
||||
thrust::iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator>;
|
||||
|
||||
using reference = typename super_t::reference;
|
||||
using difference_type = typename super_t::difference_type;
|
||||
|
||||
__host__ __device__ general_iterator(
|
||||
Iterator it,
|
||||
IdxT index,
|
||||
int ndim,
|
||||
Shape shape,
|
||||
Strides strides)
|
||||
: super_t(it),
|
||||
index_(index),
|
||||
ndim_(ndim),
|
||||
shape_(cuda::std::move(shape)),
|
||||
strides_(cuda::std::move(strides)) {}
|
||||
|
||||
__host__ __device__ IdxT index() const {
|
||||
return index_;
|
||||
}
|
||||
|
||||
__host__ __device__ const Shape& shape() const {
|
||||
return shape_;
|
||||
}
|
||||
|
||||
__host__ __device__ const Strides& strides() const {
|
||||
return strides_;
|
||||
}
|
||||
|
||||
private:
|
||||
friend class thrust::iterator_core_access;
|
||||
|
||||
__host__ __device__ bool equal(const general_iterator& other) const {
|
||||
return this->base() == other.base() && this->index() == other.index();
|
||||
}
|
||||
|
||||
__host__ __device__ void advance(difference_type n) {
|
||||
this->index_ += n;
|
||||
}
|
||||
|
||||
__host__ __device__ void increment() {
|
||||
this->index_ += 1;
|
||||
}
|
||||
|
||||
__host__ __device__ void decrement() {
|
||||
this->index_ -= 1;
|
||||
}
|
||||
|
||||
__host__ __device__ difference_type
|
||||
distance_to(const general_iterator& other) const {
|
||||
_CCCL_ASSERT(
|
||||
this->base() == other.base(),
|
||||
"Underlying iterator must point to same base iterator");
|
||||
return other.index() - this->index();
|
||||
}
|
||||
|
||||
// The dereference is device-only to avoid accidental running in host.
|
||||
__device__ typename super_t::reference dereference() const {
|
||||
IdxT offset = elem_to_loc(index_, shape_.data(), strides_.data(), ndim_);
|
||||
return *(this->base() + offset);
|
||||
}
|
||||
|
||||
IdxT index_;
|
||||
int ndim_;
|
||||
Shape shape_;
|
||||
Strides strides_;
|
||||
};
|
||||
|
||||
template <typename IdxT, typename Iterator>
|
||||
__host__ __device__ auto make_general_iterator(
|
||||
Iterator it,
|
||||
IdxT index,
|
||||
int ndim,
|
||||
Shape shape,
|
||||
Strides strides) {
|
||||
return general_iterator<Iterator, IdxT>(
|
||||
it, index, ndim, cuda::std::move(shape), cuda::std::move(strides));
|
||||
}
|
||||
|
||||
template <typename IdxT, typename Iterator>
|
||||
auto make_general_iterator(
|
||||
Iterator it,
|
||||
const std::vector<int32_t>& shape,
|
||||
const std::vector<int64_t>& strides) {
|
||||
return make_general_iterator<IdxT>(
|
||||
it, 0, shape.size(), const_param(shape), const_param(strides));
|
||||
}
|
||||
|
||||
template <typename IdxT, typename Iterator>
|
||||
auto make_general_iterators(
|
||||
Iterator it,
|
||||
IdxT size,
|
||||
const std::vector<int32_t>& shape,
|
||||
const std::vector<int64_t>& strides) {
|
||||
auto ndim = shape.size();
|
||||
auto shape_arg = const_param(shape);
|
||||
auto strides_arg = const_param(strides);
|
||||
return std::make_pair(
|
||||
make_general_iterator<IdxT>(it, 0, ndim, shape_arg, strides_arg),
|
||||
make_general_iterator<IdxT>(it, size, ndim, shape_arg, strides_arg));
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
26
mlx/backend/cuda/kernel_utils.cu
Normal file
26
mlx/backend/cuda/kernel_utils.cu
Normal file
@@ -0,0 +1,26 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2) {
|
||||
Dims dims = get_block_dims_common(dim0, dim1, dim2, pow2);
|
||||
return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
|
||||
}
|
||||
|
||||
dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides) {
|
||||
Dims dims = get_2d_grid_dims_common(shape, strides);
|
||||
return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
|
||||
}
|
||||
|
||||
dim3 get_2d_grid_dims(
|
||||
const Shape& shape,
|
||||
const Strides& strides,
|
||||
size_t divisor) {
|
||||
Dims dims = get_2d_grid_dims_common(shape, strides, divisor);
|
||||
return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
131
mlx/backend/cuda/kernel_utils.cuh
Normal file
131
mlx/backend/cuda/kernel_utils.cuh
Normal file
@@ -0,0 +1,131 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
// This file includes host-only utilies for writing CUDA kernels, the difference
|
||||
// from backend/cuda/kernels/utils.cuh is that the latter file only include
|
||||
// device-only code.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cuda/kernels/utils.cuh"
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <fmt/format.h>
|
||||
#include <cuda/cmath>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Convert a number between 1~3 to constexpr.
|
||||
#define MLX_SWITCH_1_2_3(N, NDIM, ...) \
|
||||
switch (N) { \
|
||||
case 1: { \
|
||||
constexpr int NDIM = 1; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case 2: { \
|
||||
constexpr int NDIM = 2; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case 3: { \
|
||||
constexpr int NDIM = 3; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
}
|
||||
|
||||
// Like MLX_SWITCH_ALL_TYPES but for booleans.
|
||||
#define MLX_SWITCH_BOOL(BOOL, BOOL_ALIAS, ...) \
|
||||
if (BOOL) { \
|
||||
constexpr bool BOOL_ALIAS = true; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr bool BOOL_ALIAS = false; \
|
||||
__VA_ARGS__; \
|
||||
}
|
||||
|
||||
// Maps CPU types to CUDA types.
|
||||
template <typename T>
|
||||
struct CTypeToCudaType {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CTypeToCudaType<float16_t> {
|
||||
using type = __half;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CTypeToCudaType<bfloat16_t> {
|
||||
using type = __nv_bfloat16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CTypeToCudaType<complex64_t> {
|
||||
using type = cuComplex;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using cuda_type_t = typename CTypeToCudaType<T>::type;
|
||||
|
||||
// Type traits for detecting floating numbers.
|
||||
template <typename T>
|
||||
inline constexpr bool is_floating_v =
|
||||
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
|
||||
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
|
||||
|
||||
// Utility to copy data from vector to array in host.
|
||||
template <int NDIM = MAX_NDIM, typename T = int32_t>
|
||||
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
|
||||
if (vec.size() > NDIM) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("ndim can not be larger than {}.", NDIM));
|
||||
}
|
||||
cuda::std::array<T, NDIM> result;
|
||||
std::copy_n(vec.begin(), vec.size(), result.begin());
|
||||
return result;
|
||||
}
|
||||
|
||||
// Compute the grid and block dimensions, check backend/common/utils.h for docs.
|
||||
dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
|
||||
dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides);
|
||||
dim3 get_2d_grid_dims(
|
||||
const Shape& shape,
|
||||
const Strides& strides,
|
||||
size_t divisor);
|
||||
|
||||
// Return a block size that achieves maximum potential occupancy for kernel.
|
||||
template <typename T>
|
||||
inline uint max_occupancy_block_dim(T kernel) {
|
||||
int _, block_dim;
|
||||
CHECK_CUDA_ERROR(cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
|
||||
return block_dim;
|
||||
}
|
||||
|
||||
// Get the num_blocks and block_dims that maximize occupancy for |kernel|,
|
||||
// assuming each thread handles |work_per_thread| elements of |arr|.
|
||||
template <typename T>
|
||||
inline std::tuple<dim3, uint> get_launch_args(
|
||||
T kernel,
|
||||
const array& arr,
|
||||
bool large,
|
||||
int work_per_thread = 1) {
|
||||
size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread);
|
||||
uint block_dim = max_occupancy_block_dim(kernel);
|
||||
if (block_dim > nthreads) {
|
||||
block_dim = nthreads;
|
||||
}
|
||||
dim3 num_blocks;
|
||||
if (large) {
|
||||
num_blocks = get_2d_grid_dims(arr.shape(), arr.strides(), work_per_thread);
|
||||
num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
|
||||
} else {
|
||||
num_blocks.x = cuda::ceil_div(nthreads, block_dim);
|
||||
}
|
||||
return std::make_tuple(num_blocks, block_dim);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
15
mlx/backend/cuda/kernels/arange.cuh
Normal file
15
mlx/backend/cuda/kernels/arange.cuh
Normal file
@@ -0,0 +1,15 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
template <typename T>
|
||||
struct Arange {
|
||||
const T start;
|
||||
const T step;
|
||||
|
||||
__device__ T operator()(uint32_t i) const {
|
||||
return start + i * step;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
278
mlx/backend/cuda/kernels/binary_ops.cuh
Normal file
278
mlx/backend/cuda/kernels/binary_ops.cuh
Normal file
@@ -0,0 +1,278 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <cuda/std/array>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
struct Add {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x + y;
|
||||
}
|
||||
};
|
||||
|
||||
struct FloorDivide {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
return x / y;
|
||||
} else {
|
||||
return trunc(x / y);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Divide {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x / y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Remainder {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
if constexpr (cuda::std::is_signed_v<T>) {
|
||||
auto r = x % y;
|
||||
if (r != 0 && (r < 0 != y < 0)) {
|
||||
r += y;
|
||||
}
|
||||
return r;
|
||||
} else {
|
||||
return x % y;
|
||||
}
|
||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return x % y;
|
||||
} else {
|
||||
T r = fmod(x, y);
|
||||
if (r != 0 && (r < 0 != y < 0)) {
|
||||
r = r + y;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Equal {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T x, T y) {
|
||||
return x == y;
|
||||
}
|
||||
};
|
||||
|
||||
struct NaNEqual {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T x, T y) {
|
||||
if constexpr (std::is_same_v<T, cuComplex>) {
|
||||
return x == y ||
|
||||
(isnan(cuCrealf(x)) && isnan(cuCrealf(y)) && isnan(cuCimagf(x)) &&
|
||||
isnan(cuCimagf(y))) ||
|
||||
(cuCrealf(x) == cuCrealf(y) && isnan(cuCimagf(x)) &&
|
||||
isnan(cuCimagf(y))) ||
|
||||
(isnan(cuCrealf(x)) && isnan(cuCrealf(y)) &&
|
||||
cuCimagf(x) == cuCimagf(y));
|
||||
} else {
|
||||
return x == y || (isnan(x) && isnan(y));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Greater {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T x, T y) {
|
||||
return x > y;
|
||||
}
|
||||
};
|
||||
|
||||
struct GreaterEqual {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T x, T y) {
|
||||
return x >= y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Less {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T x, T y) {
|
||||
return x < y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LessEqual {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T x, T y) {
|
||||
return x <= y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LogAddExp {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
if (isnan(x) || isnan(y)) {
|
||||
return cuda::std::numeric_limits<T>::quiet_NaN();
|
||||
}
|
||||
T maxval = max(x, y);
|
||||
T minval = min(x, y);
|
||||
return (minval == -cuda::std::numeric_limits<T>::infinity() ||
|
||||
maxval == cuda::std::numeric_limits<T>::infinity())
|
||||
? maxval
|
||||
: T(float(maxval) + log1p(expf(minval - maxval)));
|
||||
};
|
||||
};
|
||||
|
||||
struct Maximum {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
return max(x, y);
|
||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
|
||||
return x;
|
||||
}
|
||||
return x > y ? x : y;
|
||||
} else {
|
||||
if (isnan(x)) {
|
||||
return x;
|
||||
}
|
||||
return x > y ? x : y;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Minimum {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
return min(x, y);
|
||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
|
||||
return x;
|
||||
}
|
||||
return x < y ? x : y;
|
||||
} else {
|
||||
if (isnan(x)) {
|
||||
return x;
|
||||
}
|
||||
return x < y ? x : y;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Multiply {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x * y;
|
||||
}
|
||||
};
|
||||
|
||||
struct NotEqual {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T x, T y) {
|
||||
if constexpr (std::is_same_v<T, cuComplex>) {
|
||||
return cuCrealf(x) != cuCrealf(y) || cuCimagf(x) != cuCimagf(y);
|
||||
} else {
|
||||
return x != y;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Power {
|
||||
template <typename T>
|
||||
__device__ T operator()(T base, T exp) {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
T res = 1;
|
||||
while (exp) {
|
||||
if (exp & 1) {
|
||||
res *= base;
|
||||
}
|
||||
exp >>= 1;
|
||||
base *= base;
|
||||
}
|
||||
return res;
|
||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
auto x_theta = atan2f(base.y, base.x);
|
||||
auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y);
|
||||
auto mag = expf(exp.x * x_ln_r - exp.y * x_theta);
|
||||
auto phase = exp.y * x_ln_r + exp.x * x_theta;
|
||||
return make_cuFloatComplex(mag * cosf(phase), mag * sinf(phase));
|
||||
} else {
|
||||
return powf(base, exp);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Subtract {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x - y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LogicalAnd {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x && y;
|
||||
};
|
||||
};
|
||||
|
||||
struct LogicalOr {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x || y;
|
||||
};
|
||||
};
|
||||
|
||||
struct BitwiseAnd {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x & y;
|
||||
};
|
||||
};
|
||||
|
||||
struct BitwiseOr {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x | y;
|
||||
};
|
||||
};
|
||||
|
||||
struct BitwiseXor {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x ^ y;
|
||||
};
|
||||
};
|
||||
|
||||
struct LeftShift {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x << y;
|
||||
};
|
||||
};
|
||||
|
||||
struct RightShift {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x >> y;
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcTan2 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T y, T x) {
|
||||
return atan2f(y, x);
|
||||
}
|
||||
};
|
||||
|
||||
struct DivMod {
|
||||
template <typename T>
|
||||
__device__ cuda::std::array<T, 2> operator()(T x, T y) {
|
||||
return {FloorDivide{}(x, y), Remainder{}(x, y)};
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
59
mlx/backend/cuda/kernels/cast_op.cuh
Normal file
59
mlx/backend/cuda/kernels/cast_op.cuh
Normal file
@@ -0,0 +1,59 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <thrust/iterator/transform_iterator.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// An op that does static_cast, with custom conversions for some types.
|
||||
template <typename SrcT, typename DstT, typename = void>
|
||||
struct CastOp {
|
||||
static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, DstT>;
|
||||
|
||||
__device__ DstT operator()(SrcT x) {
|
||||
return static_cast<DstT>(x);
|
||||
}
|
||||
};
|
||||
|
||||
// Converting a complex number to real number discards the imaginary part.
|
||||
template <typename DstT>
|
||||
struct CastOp<
|
||||
cuComplex,
|
||||
DstT,
|
||||
cuda::std::enable_if_t<!cuda::std::is_same_v<cuComplex, DstT>>> {
|
||||
static constexpr bool is_castable = cuda::std::is_convertible_v<float, DstT>;
|
||||
|
||||
__device__ DstT operator()(cuComplex x) {
|
||||
static_assert(!cuda::std::is_same_v<cuComplex, DstT>);
|
||||
return static_cast<DstT>(cuCrealf(x));
|
||||
}
|
||||
};
|
||||
|
||||
// Allow converting a real number to complex number.
|
||||
template <typename SrcT>
|
||||
struct CastOp<
|
||||
SrcT,
|
||||
cuComplex,
|
||||
cuda::std::enable_if_t<!cuda::std::is_same_v<SrcT, cuComplex>>> {
|
||||
static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, float>;
|
||||
|
||||
__device__ cuComplex operator()(SrcT x) {
|
||||
static_assert(!cuda::std::is_same_v<SrcT, cuComplex>);
|
||||
return cuComplex{static_cast<float>(x), 0};
|
||||
}
|
||||
};
|
||||
|
||||
// Return an iterator that cast the value to DstT using CastOp.
|
||||
template <typename DstT, typename Iterator>
|
||||
__host__ __device__ auto make_cast_iterator(Iterator it) {
|
||||
using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type;
|
||||
if constexpr (std::is_same_v<SrcT, DstT>) {
|
||||
return it;
|
||||
} else {
|
||||
return thrust::make_transform_iterator(it, CastOp<SrcT, DstT>{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
240
mlx/backend/cuda/kernels/cucomplex_math.cuh
Normal file
240
mlx/backend/cuda/kernels/cucomplex_math.cuh
Normal file
@@ -0,0 +1,240 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
// Copyright © 2017-2024 The Simons Foundation, Inc.
|
||||
//
|
||||
// FINUFFT is licensed under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance with the
|
||||
// License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Forked from
|
||||
// https://github.com/flatironinstitute/finufft/blob/main/include/cufinufft/contrib/helper_math.h
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuComplex.h>
|
||||
|
||||
// This header provides some helper functions for cuComplex types.
|
||||
// It mainly wraps existing CUDA implementations to provide operator overloads
|
||||
// e.g. cuAdd, cuSub, cuMul, cuDiv, cuCreal, cuCimag, cuCabs, cuCarg, cuConj are
|
||||
// all provided by CUDA
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator+(const cuDoubleComplex& a, const cuDoubleComplex& b) {
|
||||
return cuCadd(a, b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator-(const cuDoubleComplex& a, const cuDoubleComplex& b) {
|
||||
return cuCsub(a, b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator*(const cuDoubleComplex& a, const cuDoubleComplex& b) {
|
||||
return cuCmul(a, b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator/(const cuDoubleComplex& a, const cuDoubleComplex& b) {
|
||||
return cuCdiv(a, b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator%(const cuDoubleComplex& a, const cuDoubleComplex& b) {
|
||||
double r = cuCreal(a) - (floorf(cuCreal(a) / cuCreal(b)) * cuCreal(b));
|
||||
double i = cuCimag(a) - (floorf(cuCimag(a) / cuCimag(b)) * cuCimag(b));
|
||||
return make_cuDoubleComplex(r, i);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator==(
|
||||
const cuDoubleComplex& a,
|
||||
const cuDoubleComplex& b) {
|
||||
return cuCreal(a) == cuCreal(b) && cuCimag(a) == cuCimag(b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator!=(
|
||||
const cuDoubleComplex& a,
|
||||
const cuDoubleComplex& b) {
|
||||
return !(a == b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator>(
|
||||
const cuDoubleComplex& a,
|
||||
const cuDoubleComplex& b) {
|
||||
double mag_a = sqrt(cuCreal(a) * cuCreal(a) + cuCimag(a) * cuCimag(a));
|
||||
double mag_b = sqrt(cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b));
|
||||
return mag_a > mag_b;
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator>=(
|
||||
const cuDoubleComplex& a,
|
||||
const cuDoubleComplex& b) {
|
||||
return a > b || a == b;
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator<(
|
||||
const cuDoubleComplex& a,
|
||||
const cuDoubleComplex& b) {
|
||||
return b > a;
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator<=(
|
||||
const cuDoubleComplex& a,
|
||||
const cuDoubleComplex& b) {
|
||||
return b > a || a == b;
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator+(const cuDoubleComplex& a, double b) {
|
||||
return make_cuDoubleComplex(cuCreal(a) + b, cuCimag(a));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator+(double a, const cuDoubleComplex& b) {
|
||||
return make_cuDoubleComplex(a + cuCreal(b), cuCimag(b));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator-(const cuDoubleComplex& a, double b) {
|
||||
return make_cuDoubleComplex(cuCreal(a) - b, cuCimag(a));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator-(double a, const cuDoubleComplex& b) {
|
||||
return make_cuDoubleComplex(a - cuCreal(b), -cuCimag(b));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator*(const cuDoubleComplex& a, double b) {
|
||||
return make_cuDoubleComplex(cuCreal(a) * b, cuCimag(a) * b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator*(double a, const cuDoubleComplex& b) {
|
||||
return make_cuDoubleComplex(a * cuCreal(b), a * cuCimag(b));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator/(const cuDoubleComplex& a, double b) {
|
||||
return make_cuDoubleComplex(cuCreal(a) / b, cuCimag(a) / b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator/(double a, const cuDoubleComplex& b) {
|
||||
double denom = cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b);
|
||||
return make_cuDoubleComplex(
|
||||
(a * cuCreal(b)) / denom, (-a * cuCimag(b)) / denom);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator+(const cuFloatComplex& a, const cuFloatComplex& b) {
|
||||
return cuCaddf(a, b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator-(const cuFloatComplex& a, const cuFloatComplex& b) {
|
||||
return cuCsubf(a, b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator*(const cuFloatComplex& a, const cuFloatComplex& b) {
|
||||
return cuCmulf(a, b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator/(const cuFloatComplex& a, const cuFloatComplex& b) {
|
||||
return cuCdivf(a, b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator%(const cuFloatComplex& a, const cuFloatComplex& b) {
|
||||
float r = cuCrealf(a) - (floorf(cuCrealf(a) / cuCrealf(b)) * cuCrealf(b));
|
||||
float i = cuCimagf(a) - (floorf(cuCimagf(a) / cuCimagf(b)) * cuCimagf(b));
|
||||
return make_cuFloatComplex(r, i);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator==(
|
||||
const cuFloatComplex& a,
|
||||
const cuFloatComplex& b) {
|
||||
return cuCrealf(a) == cuCrealf(b) && cuCimagf(a) == cuCimagf(b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator!=(
|
||||
const cuFloatComplex& a,
|
||||
const cuFloatComplex& b) {
|
||||
return !(a == b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator>(
|
||||
const cuFloatComplex& a,
|
||||
const cuFloatComplex& b) {
|
||||
float mag_a = sqrt(cuCrealf(a) * cuCrealf(a) + cuCimagf(a) * cuCimagf(a));
|
||||
float mag_b = sqrt(cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b));
|
||||
return mag_a > mag_b;
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator>=(
|
||||
const cuFloatComplex& a,
|
||||
const cuFloatComplex& b) {
|
||||
return a > b || a == b;
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator<(
|
||||
const cuFloatComplex& a,
|
||||
const cuFloatComplex& b) {
|
||||
return b > a;
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator<=(
|
||||
const cuFloatComplex& a,
|
||||
const cuFloatComplex& b) {
|
||||
return b > a || a == b;
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator+(const cuFloatComplex& a, float b) {
|
||||
return make_cuFloatComplex(cuCrealf(a) + b, cuCimagf(a));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator+(float a, const cuFloatComplex& b) {
|
||||
return make_cuFloatComplex(a + cuCrealf(b), cuCimagf(b));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator-(const cuFloatComplex& a, float b) {
|
||||
return make_cuFloatComplex(cuCrealf(a) - b, cuCimagf(a));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator-(float a, const cuFloatComplex& b) {
|
||||
return make_cuFloatComplex(a - cuCrealf(b), -cuCimagf(b));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator*(const cuFloatComplex& a, float b) {
|
||||
return make_cuFloatComplex(cuCrealf(a) * b, cuCimagf(a) * b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator*(float a, const cuFloatComplex& b) {
|
||||
return make_cuFloatComplex(a * cuCrealf(b), a * cuCimagf(b));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator/(const cuFloatComplex& a, float b) {
|
||||
return make_cuFloatComplex(cuCrealf(a) / b, cuCimagf(a) / b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator/(float a, const cuFloatComplex& b) {
|
||||
float denom = cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b);
|
||||
return make_cuFloatComplex(
|
||||
(a * cuCrealf(b)) / denom, (-a * cuCimagf(b)) / denom);
|
||||
}
|
194
mlx/backend/cuda/kernels/fp16_math.cuh
Normal file
194
mlx/backend/cuda/kernels/fp16_math.cuh
Normal file
@@ -0,0 +1,194 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda/std/limits>
|
||||
#include <cuda/std/type_traits>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Unary ops for half types.
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
|
||||
#define MLX_DEFINE_UNARY_OP(NAME, HALF_OP) \
|
||||
template <typename T> \
|
||||
__forceinline__ __device__ auto NAME(T x) { \
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) { \
|
||||
return HALF_OP(x); \
|
||||
} else { \
|
||||
return ::NAME(x); \
|
||||
} \
|
||||
}
|
||||
#else
|
||||
#define MLX_DEFINE_UNARY_OP(NAME, HALF_OP) \
|
||||
template <typename T> \
|
||||
__forceinline__ __device__ auto NAME(T x) { \
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) { \
|
||||
return HALF_OP(x); \
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) { \
|
||||
return HALF_OP(x); \
|
||||
} else { \
|
||||
return ::NAME(x); \
|
||||
} \
|
||||
}
|
||||
#endif
|
||||
|
||||
#define MLX_DEFINE_UNARY_OP_FALLBCK(NAME) \
|
||||
template <typename T> \
|
||||
__forceinline__ __device__ auto NAME(T x) { \
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) { \
|
||||
return ::NAME(__half2float(x)); \
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) { \
|
||||
return ::NAME(__bfloat162float(x)); \
|
||||
} else { \
|
||||
return ::NAME(x); \
|
||||
} \
|
||||
}
|
||||
|
||||
MLX_DEFINE_UNARY_OP(abs, __habs)
|
||||
MLX_DEFINE_UNARY_OP(ceil, hceil)
|
||||
MLX_DEFINE_UNARY_OP(cos, hcos)
|
||||
MLX_DEFINE_UNARY_OP(exp, hexp)
|
||||
MLX_DEFINE_UNARY_OP(floor, hfloor)
|
||||
MLX_DEFINE_UNARY_OP(isnan, __hisnan)
|
||||
MLX_DEFINE_UNARY_OP(log, hlog)
|
||||
MLX_DEFINE_UNARY_OP(log2, hlog2)
|
||||
MLX_DEFINE_UNARY_OP(log10, hlog10)
|
||||
MLX_DEFINE_UNARY_OP(rint, hrint)
|
||||
MLX_DEFINE_UNARY_OP(rsqrt, hrsqrt)
|
||||
MLX_DEFINE_UNARY_OP(sin, hsin)
|
||||
MLX_DEFINE_UNARY_OP(sqrt, hsqrt)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(acos)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(acosh)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(asin)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(asinh)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(atan)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(atanh)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(cosh)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(log1p)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(sinh)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(tan)
|
||||
#if __CUDA_ARCH__ >= 1280
|
||||
MLX_DEFINE_UNARY_OP(tanh, htanh)
|
||||
#else
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(tanh)
|
||||
#endif
|
||||
|
||||
#undef MLX_DEFINE_UNARY_OP
|
||||
#undef MLX_DEFINE_UNARY_OP_FALLBCK
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Binary ops for half types.
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
|
||||
#define MLX_DEFINE_BINARY_OP(NAME, HALF_OP) \
|
||||
template <typename T> \
|
||||
__forceinline__ __device__ auto NAME(T x, T y) { \
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) { \
|
||||
return HALF_OP(x, y); \
|
||||
} else { \
|
||||
return ::NAME(x, y); \
|
||||
} \
|
||||
}
|
||||
#else
|
||||
#define MLX_DEFINE_BINARY_OP(NAME, HALF_OP) \
|
||||
template <typename T> \
|
||||
__forceinline__ __device__ auto NAME(T x, T y) { \
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) { \
|
||||
return HALF_OP(x, y); \
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) { \
|
||||
return HALF_OP(x, y); \
|
||||
} else { \
|
||||
return ::NAME(x, y); \
|
||||
} \
|
||||
}
|
||||
#endif
|
||||
|
||||
MLX_DEFINE_BINARY_OP(max, __hmax)
|
||||
MLX_DEFINE_BINARY_OP(min, __hmin)
|
||||
|
||||
#undef MLX_DEFINE_BINARY_OP
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ T fmod(T x, T y) {
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) {
|
||||
return __float2half(::fmod(__half2float(x), __half2float(y)));
|
||||
#if CUDART_VERSION >= 12000 || __CUDA_ARCH__ >= 800
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
|
||||
return __float2bfloat16(::fmod(__bfloat162float(x), __bfloat162float(y)));
|
||||
#endif
|
||||
} else {
|
||||
return ::fmod(x, y);
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Additional C++ operator overrides between half types and native types.
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename U>
|
||||
constexpr bool is_integral_except =
|
||||
cuda::std::is_integral_v<T> && !cuda::std::is_same_v<T, U>;
|
||||
|
||||
template <typename T, typename U>
|
||||
constexpr bool is_arithmetic_except =
|
||||
cuda::std::is_arithmetic_v<T> && !cuda::std::is_same_v<T, U>;
|
||||
|
||||
#define MLX_DEFINE_HALF_OP(HALF, HALF2FLOAT, FLOAT2HALF, OP) \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_integral_except<T, HALF>>> \
|
||||
__forceinline__ __device__ HALF operator OP(HALF x, T y) { \
|
||||
return FLOAT2HALF(HALF2FLOAT(x) OP static_cast<float>(y)); \
|
||||
} \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_integral_except<T, HALF>>> \
|
||||
__forceinline__ __device__ HALF operator OP(T x, HALF y) { \
|
||||
return FLOAT2HALF(static_cast<float>(x) OP HALF2FLOAT(y)); \
|
||||
}
|
||||
|
||||
#define MLX_DEFINE_HALF_CMP(HALF, HALF2FLOAT, OP) \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_arithmetic_except<T, HALF>>> \
|
||||
__forceinline__ __device__ bool operator OP(HALF x, T y) { \
|
||||
return HALF2FLOAT(x) OP static_cast<float>(y); \
|
||||
} \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_arithmetic_except<T, HALF>>> \
|
||||
__forceinline__ __device__ bool operator OP(T x, HALF y) { \
|
||||
return static_cast<float>(y) OP HALF2FLOAT(x); \
|
||||
}
|
||||
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, +)
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, -)
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, *)
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, /)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, +)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, -)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, *)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, /)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, <)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, >)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, <=)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, >=)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, ==)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, !=)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <=)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >=)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, ==)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, !=)
|
||||
|
||||
#undef MLX_DEFINE_HALF_OP
|
||||
#undef MLX_DEFINE_HALF_CMP
|
||||
|
||||
} // namespace mlx::core::cu
|
349
mlx/backend/cuda/kernels/unary_ops.cuh
Normal file
349
mlx/backend/cuda/kernels/unary_ops.cuh
Normal file
@@ -0,0 +1,349 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/kernels/utils.cuh"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
struct Abs {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_unsigned_v<T>) {
|
||||
return x;
|
||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return {sqrt(cuCrealf(x) * cuCrealf(x) + cuCimagf(x) * cuCimagf(x)), 0};
|
||||
} else {
|
||||
return abs(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcCos {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return acos(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcCosh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return acosh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcSin {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return asin(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcSinh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return asinh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTan {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return atan(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTanh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return atanh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct BitwiseInvert {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return ~x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Ceil {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
return x;
|
||||
} else {
|
||||
return ceil(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Conjugate {
|
||||
__device__ cuComplex operator()(cuComplex x) {
|
||||
return {cuCrealf(x), -cuCimagf(x)};
|
||||
}
|
||||
};
|
||||
|
||||
struct Cos {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return {
|
||||
cos(cuCrealf(x)) * cosh(cuCimagf(x)),
|
||||
-sin(cuCrealf(x)) * sinh(cuCimagf(x))};
|
||||
} else {
|
||||
return cos(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Cosh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return {
|
||||
cosh(cuCrealf(x)) * cos(cuCimagf(x)),
|
||||
sinh(cuCrealf(x)) * sin(cuCimagf(x))};
|
||||
} else {
|
||||
return cosh(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Erf {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) {
|
||||
return erf(__half2float(x));
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
|
||||
return erf(__bfloat162float(x));
|
||||
} else {
|
||||
return erf(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct ErfInv {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) {
|
||||
return erfinv(__half2float(x));
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
|
||||
return erfinv(__bfloat162float(x));
|
||||
} else {
|
||||
return erfinv(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Exp {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
auto m = exp(cuCrealf(x));
|
||||
return {m * cos(cuCimagf(x)), m * sinh(cuCimagf(x))};
|
||||
} else {
|
||||
return exp(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Expm1 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) {
|
||||
return expm1(__half2float(x));
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
|
||||
return expm1(__bfloat162float(x));
|
||||
} else {
|
||||
return expm1(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Floor {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
return x;
|
||||
} else {
|
||||
return floor(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Imag {
|
||||
__device__ float operator()(cuComplex x) {
|
||||
return cuCimagf(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Log {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return log(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Log2 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return log2(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Log10 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return log10(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Log1p {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return log1p(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct LogicalNot {
|
||||
__device__ bool operator()(bool x) {
|
||||
return !x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Negative {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return 0 - x;
|
||||
} else {
|
||||
return -x;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Real {
|
||||
__device__ float operator()(cuComplex x) {
|
||||
return cuCrealf(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Round {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return {rint(cuCrealf(x)), rint(cuCimagf(x))};
|
||||
} else {
|
||||
return rint(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Rsqrt {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return rsqrt(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Sigmoid {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
T y = 1 / (1 + exp(-abs(x)));
|
||||
return (x < 0) ? 1 - y : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Sign {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_unsigned_v<T>) {
|
||||
return x != 0;
|
||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
if (cuCrealf(x) == 0 && cuCimagf(x) == 0) {
|
||||
return x;
|
||||
} else {
|
||||
return x / Abs()(x);
|
||||
}
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
|
||||
return static_cast<float>((x > T(0.f)) - (x < T(0.f)));
|
||||
} else {
|
||||
return (x > T(0)) - (x < T(0));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Sin {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return {
|
||||
sin(cuCrealf(x)) * cosh(cuCimagf(x)),
|
||||
cos(cuCrealf(x)) * sinh(cuCimagf(x))};
|
||||
} else {
|
||||
return sin(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Sinh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return {
|
||||
sinh(cuCrealf(x)) * cos(cuCimagf(x)),
|
||||
cosh(cuCrealf(x)) * sin(cuCimagf(x))};
|
||||
} else {
|
||||
return sinh(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Square {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return x * x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Sqrt {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return sqrt(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Tan {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
float tan_a = tan(cuCrealf(x));
|
||||
float tanh_b = tanh(cuCimagf(x));
|
||||
float t1 = tan_a * tanh_b;
|
||||
float denom = 1. + t1 * t1;
|
||||
return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
|
||||
} else {
|
||||
return tan(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Tanh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
float tanh_a = tanh(cuCrealf(x));
|
||||
float tan_b = tan(cuCimagf(x));
|
||||
float t1 = tanh_a * tan_b;
|
||||
float denom = 1. + t1 * t1;
|
||||
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
|
||||
} else {
|
||||
return tanh(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
104
mlx/backend/cuda/kernels/utils.cuh
Normal file
104
mlx/backend/cuda/kernels/utils.cuh
Normal file
@@ -0,0 +1,104 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
// This file must not include any host-only code, utilies that work under both
|
||||
// host and device can be put here.
|
||||
//
|
||||
// See more about the requirements at:
|
||||
// https://docs.nvidia.com/cuda/nvrtc/#language
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <cuda/std/array>
|
||||
#include <cuda/std/limits>
|
||||
#include <cuda/std/tuple>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CUDA kernel utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// To pass shape/strides to kernels via constant memory, their size must be
|
||||
// known at compile time.
|
||||
#define MAX_NDIM 8
|
||||
|
||||
using Shape = cuda::std::array<int32_t, MAX_NDIM>;
|
||||
using Strides = cuda::std::array<int64_t, MAX_NDIM>;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Indexing utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename IdxT = int64_t>
|
||||
inline __host__ __device__ IdxT
|
||||
elem_to_loc(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
|
||||
IdxT loc = 0;
|
||||
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
||||
loc += (elem % shape[i]) * IdxT(strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
// Optimize when the ndim is known at compile time.
|
||||
template <int NDIM, typename IdxT = int64_t>
|
||||
inline __host__ __device__ IdxT
|
||||
elem_to_loc_nd(IdxT elem, const int* shape, const int64_t* strides) {
|
||||
IdxT loc = 0;
|
||||
#pragma unroll
|
||||
for (int i = NDIM - 1; i >= 0; --i) {
|
||||
loc += (elem % shape[i]) * IdxT(strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <int NDIM, typename IdxT = int64_t>
|
||||
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
|
||||
IdxT elem,
|
||||
const int* shape,
|
||||
const int64_t* a_strides,
|
||||
const int64_t* b_strides) {
|
||||
IdxT a_loc = 0;
|
||||
IdxT b_loc = 0;
|
||||
#pragma unroll
|
||||
for (int i = NDIM - 1; i >= 0; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * a_strides[i];
|
||||
b_loc += dim_idx * b_strides[i];
|
||||
elem /= shape[i];
|
||||
}
|
||||
return cuda::std::make_tuple(a_loc, b_loc);
|
||||
}
|
||||
|
||||
// Optimized version when ndim is larger than 4.
|
||||
template <typename IdxT = int64_t>
|
||||
inline __host__ __device__ IdxT
|
||||
elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
|
||||
IdxT loc = elem_to_loc_nd<3>(elem, shape, strides);
|
||||
for (int i = ndim - 1; i >= 3; --i) {
|
||||
loc += (elem % shape[i]) * IdxT(strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <typename IdxT = int64_t>
|
||||
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
|
||||
IdxT elem,
|
||||
const int* shape,
|
||||
const int64_t* a_strides,
|
||||
const int64_t* b_strides,
|
||||
int ndim) {
|
||||
auto [a_loc, b_loc] = elem_to_loc_nd<3>(elem, shape, a_strides, b_strides);
|
||||
for (int i = ndim - 1; i >= 3; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * a_strides[i];
|
||||
b_loc += dim_idx * b_strides[i];
|
||||
elem /= shape[i];
|
||||
}
|
||||
return cuda::std::make_tuple(a_loc, b_loc);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
474
mlx/backend/cuda/matmul.cpp
Normal file
474
mlx/backend/cuda/matmul.cpp
Normal file
@@ -0,0 +1,474 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/matmul.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cublasLt.h>
|
||||
#include <fmt/format.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
#include <numeric>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
|
||||
|
||||
void check_cublas_error(const char* name, cublasStatus_t err) {
|
||||
if (err != CUBLAS_STATUS_SUCCESS) {
|
||||
// TODO: Use cublasGetStatusString when it is widely available.
|
||||
throw std::runtime_error(
|
||||
fmt::format("{} failed with code: {}.", name, static_cast<int>(err)));
|
||||
}
|
||||
}
|
||||
|
||||
class MatMul {
|
||||
public:
|
||||
MatMul(
|
||||
Device& device,
|
||||
Dtype dtype,
|
||||
bool a_transposed,
|
||||
uint64_t a_rows,
|
||||
uint64_t a_cols,
|
||||
int64_t lda,
|
||||
bool b_transposed,
|
||||
uint64_t b_rows,
|
||||
uint64_t b_cols,
|
||||
int64_t ldb,
|
||||
int32_t batch_count,
|
||||
int64_t a_batch_stride,
|
||||
int64_t b_batch_stride) {
|
||||
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
|
||||
|
||||
auto type = dtype_to_cuda_type(dtype);
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
|
||||
&matmul_desc_, dtype_to_compute_type(dtype), type));
|
||||
int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc_,
|
||||
CUBLASLT_MATMUL_DESC_POINTER_MODE,
|
||||
&pointer_mode,
|
||||
sizeof(int32_t)));
|
||||
cublasOperation_t op = CUBLAS_OP_N;
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc_,
|
||||
CUBLASLT_MATMUL_DESC_TRANSA,
|
||||
&op,
|
||||
sizeof(cublasOperation_t)));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc_,
|
||||
CUBLASLT_MATMUL_DESC_TRANSB,
|
||||
&op,
|
||||
sizeof(cublasOperation_t)));
|
||||
|
||||
a_desc_ = create_matrix_layout(
|
||||
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
|
||||
b_desc_ = create_matrix_layout(
|
||||
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
|
||||
out_desc_ = create_matrix_layout(
|
||||
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
|
||||
|
||||
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
|
||||
// for Hopper+:
|
||||
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
|
||||
uint64_t MiB = 1024 * 1024;
|
||||
uint64_t workspace_size =
|
||||
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
|
||||
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
|
||||
pref_,
|
||||
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||
&workspace_size,
|
||||
sizeof(uint64_t)));
|
||||
}
|
||||
|
||||
MatMul(
|
||||
Device& device,
|
||||
Dtype dtype,
|
||||
bool a_transposed,
|
||||
uint64_t a_rows,
|
||||
uint64_t a_cols,
|
||||
int64_t lda,
|
||||
bool b_transposed,
|
||||
uint64_t b_rows,
|
||||
uint64_t b_cols,
|
||||
int64_t ldb,
|
||||
bool c_transposed,
|
||||
int64_t ldc,
|
||||
int32_t batch_count,
|
||||
int64_t a_batch_stride,
|
||||
int64_t b_batch_stride,
|
||||
int64_t c_batch_stride)
|
||||
: MatMul(
|
||||
device,
|
||||
dtype,
|
||||
a_transposed,
|
||||
a_rows,
|
||||
a_cols,
|
||||
lda,
|
||||
b_transposed,
|
||||
b_rows,
|
||||
b_cols,
|
||||
ldb,
|
||||
batch_count,
|
||||
a_batch_stride,
|
||||
b_batch_stride) {
|
||||
auto type = dtype_to_cuda_type(dtype);
|
||||
c_desc_ = create_matrix_layout(
|
||||
type, a_rows, b_cols, c_transposed, ldc, batch_count, c_batch_stride);
|
||||
}
|
||||
|
||||
~MatMul() {
|
||||
cublasLtMatrixLayoutDestroy(a_desc_);
|
||||
cublasLtMatrixLayoutDestroy(b_desc_);
|
||||
cublasLtMatrixLayoutDestroy(c_desc_);
|
||||
cublasLtMatrixLayoutDestroy(out_desc_);
|
||||
cublasLtMatmulDescDestroy(matmul_desc_);
|
||||
}
|
||||
|
||||
void run(
|
||||
cu::CommandEncoder& encoder,
|
||||
void* out,
|
||||
void* a,
|
||||
void* b,
|
||||
void* c = nullptr,
|
||||
float alpha = 1,
|
||||
float beta = 0) {
|
||||
if (heuristic_.state != CUBLAS_STATUS_SUCCESS) {
|
||||
int ret = 0;
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(
|
||||
encoder.device().lt_handle(),
|
||||
matmul_desc_,
|
||||
a_desc_,
|
||||
b_desc_,
|
||||
out_desc_,
|
||||
out_desc_,
|
||||
pref_,
|
||||
1,
|
||||
&heuristic_,
|
||||
&ret));
|
||||
if (ret == 0) {
|
||||
throw std::runtime_error("Can not find algorithm for matmul.");
|
||||
}
|
||||
}
|
||||
|
||||
array workspace(
|
||||
allocator::malloc(heuristic_.workspaceSize),
|
||||
{static_cast<int>(heuristic_.workspaceSize)},
|
||||
int8);
|
||||
encoder.add_temporary(workspace);
|
||||
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmul(
|
||||
encoder.device().lt_handle(),
|
||||
matmul_desc_,
|
||||
&alpha,
|
||||
a,
|
||||
a_desc_,
|
||||
b,
|
||||
b_desc_,
|
||||
&beta,
|
||||
c ? c : out,
|
||||
c ? c_desc_ : out_desc_,
|
||||
out,
|
||||
out_desc_,
|
||||
&heuristic_.algo,
|
||||
workspace.data<void>(),
|
||||
workspace.nbytes(),
|
||||
stream));
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
|
||||
switch (dtype) {
|
||||
case uint8:
|
||||
case uint16:
|
||||
case int8:
|
||||
case int16:
|
||||
case int32:
|
||||
return CUBLAS_COMPUTE_32I;
|
||||
case float16:
|
||||
case bfloat16:
|
||||
return CUBLAS_COMPUTE_16F;
|
||||
case float32:
|
||||
return CUBLAS_COMPUTE_32F;
|
||||
case float64:
|
||||
case complex64:
|
||||
return CUBLAS_COMPUTE_64F;
|
||||
default:
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Unsupported dtype in MatMul: {}.", dtype_to_string(dtype)));
|
||||
}
|
||||
}
|
||||
|
||||
cudaDataType_t dtype_to_cuda_type(Dtype dtype) {
|
||||
switch (dtype) {
|
||||
case uint8:
|
||||
return CUDA_R_8U;
|
||||
case uint16:
|
||||
return CUDA_R_16U;
|
||||
case int8:
|
||||
return CUDA_R_8I;
|
||||
case int16:
|
||||
return CUDA_R_16I;
|
||||
case int32:
|
||||
return CUDA_R_32I;
|
||||
case float16:
|
||||
return CUDA_R_16F;
|
||||
case bfloat16:
|
||||
return CUDA_R_16BF;
|
||||
case float32:
|
||||
return CUDA_R_32F;
|
||||
case float64:
|
||||
return CUDA_R_64F;
|
||||
case complex64:
|
||||
return CUDA_C_32F;
|
||||
default:
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Unsupported dtype in MatMul: {}.", dtype_to_string(dtype)));
|
||||
}
|
||||
}
|
||||
|
||||
cublasLtMatrixLayout_t create_matrix_layout(
|
||||
cudaDataType_t type,
|
||||
uint64_t rows,
|
||||
uint64_t cols,
|
||||
bool transposed,
|
||||
int64_t ld,
|
||||
int32_t batch_count,
|
||||
int64_t batch_stride) {
|
||||
cublasLtMatrixLayout_t desc;
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
|
||||
cublasLtOrder_t order =
|
||||
transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
||||
desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
|
||||
if (batch_count > 1) {
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
||||
desc,
|
||||
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&batch_count,
|
||||
sizeof(int32_t)));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
||||
desc,
|
||||
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
|
||||
&batch_stride,
|
||||
sizeof(int64_t)));
|
||||
}
|
||||
return desc;
|
||||
}
|
||||
|
||||
cublasLtMatmulDesc_t matmul_desc_{nullptr};
|
||||
cublasLtMatmulPreference_t pref_{nullptr};
|
||||
cublasLtMatrixLayout_t a_desc_{nullptr};
|
||||
cublasLtMatrixLayout_t b_desc_{nullptr};
|
||||
cublasLtMatrixLayout_t c_desc_{nullptr};
|
||||
cublasLtMatrixLayout_t out_desc_{nullptr};
|
||||
cublasLtMatmulHeuristicResult_t heuristic_;
|
||||
};
|
||||
|
||||
} // namespace cu
|
||||
|
||||
namespace {
|
||||
|
||||
std::tuple<bool, int64_t, array>
|
||||
check_transpose(std::vector<array>& copies, const Stream& s, const array& arr) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
auto sty = arr.strides()[arr.ndim() - 1];
|
||||
if (sty == 1 && stx == arr.shape(-1)) {
|
||||
return std::make_tuple(false, stx, arr);
|
||||
} else if (stx == 1 && sty == arr.shape(-2)) {
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_gpu(arr, arr_copy, CopyType::General, s);
|
||||
copies.push_back(arr_copy);
|
||||
return std::make_tuple(false, arr.shape(-1), arr_copy);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Matmul::eval_gpu");
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
assert(inputs.size() == 2);
|
||||
auto& a_pre = inputs[0];
|
||||
auto& b_pre = inputs[1];
|
||||
// Return 0s if either input is empty.
|
||||
if (a_pre.size() == 0 || b_pre.size() == 0) {
|
||||
array zero(0, a_pre.dtype());
|
||||
encoder.add_temporary(zero);
|
||||
fill_gpu(zero, out, s);
|
||||
return;
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Init checks and prep
|
||||
|
||||
int M = a_pre.shape(-2);
|
||||
int N = b_pre.shape(-1);
|
||||
int K = a_pre.shape(-1);
|
||||
|
||||
// Keep a vector with copies to be cleared in the completed buffer to release
|
||||
// the arrays
|
||||
std::vector<array> copies;
|
||||
auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre);
|
||||
|
||||
for (auto& temp : copies) {
|
||||
encoder.add_temporary(temp);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Check and collapse batch dimensions
|
||||
|
||||
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
|
||||
|
||||
auto batch_count = out.size() / (M * N);
|
||||
|
||||
// Collapse batches into M if needed
|
||||
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
|
||||
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
|
||||
b_batch_strides.back() == 0) {
|
||||
M *= batch_shape.back();
|
||||
batch_count = 1;
|
||||
|
||||
a_batch_strides = {0};
|
||||
b_batch_strides = {0};
|
||||
batch_shape = {1};
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Invoke cublasLt
|
||||
|
||||
cu::MatMul matmul(
|
||||
encoder.device(),
|
||||
a.dtype(),
|
||||
a_transposed,
|
||||
M,
|
||||
K,
|
||||
lda,
|
||||
b_transposed,
|
||||
K,
|
||||
N,
|
||||
ldb,
|
||||
batch_shape.back(),
|
||||
a_batch_strides.back(),
|
||||
b_batch_strides.back());
|
||||
|
||||
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
||||
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
||||
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
|
||||
matmul.run(
|
||||
encoder,
|
||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
||||
a.data<int8_t>() + a.itemsize() * a_it.loc,
|
||||
b.data<int8_t>() + b.itemsize() * b_it.loc);
|
||||
a_it.step();
|
||||
b_it.step();
|
||||
}
|
||||
}
|
||||
|
||||
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("AddMM::eval_gpu");
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
assert(inputs.size() == 3);
|
||||
auto& a_pre = inputs[0];
|
||||
auto& b_pre = inputs[1];
|
||||
auto& c_pre = inputs[2];
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Init checks and prep
|
||||
|
||||
int M = a_pre.shape(-2);
|
||||
int N = b_pre.shape(-1);
|
||||
int K = a_pre.shape(-1);
|
||||
|
||||
// Keep a vector with copies to be cleared in the completed buffer to release
|
||||
// the arrays
|
||||
std::vector<array> copies;
|
||||
auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre);
|
||||
auto [c_transposed, ldc, c] = check_transpose(copies, s, c_pre);
|
||||
|
||||
for (auto& temp : copies) {
|
||||
encoder.add_temporary(temp);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Check and collapse batch dimensions
|
||||
|
||||
auto [batch_shape, a_batch_strides, b_batch_strides, c_batch_strides] =
|
||||
collapse_batches(a, b, c);
|
||||
|
||||
auto batch_count = out.size() / (M * N);
|
||||
|
||||
// Collapse batches into M if needed
|
||||
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
|
||||
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
|
||||
c_batch_strides.back() == M * c.strides()[c.ndim() - 2] &&
|
||||
b_batch_strides.back() == 0) {
|
||||
M *= batch_shape.back();
|
||||
batch_count = 1;
|
||||
|
||||
a_batch_strides = {0};
|
||||
b_batch_strides = {0};
|
||||
c_batch_strides = {0};
|
||||
batch_shape = {1};
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Invoke cublasLt
|
||||
|
||||
cu::MatMul matmul(
|
||||
encoder.device(),
|
||||
a.dtype(),
|
||||
a_transposed,
|
||||
M,
|
||||
K,
|
||||
lda,
|
||||
b_transposed,
|
||||
K,
|
||||
N,
|
||||
ldb,
|
||||
c_transposed,
|
||||
ldc,
|
||||
batch_shape.back(),
|
||||
a_batch_strides.back(),
|
||||
b_batch_strides.back(),
|
||||
c_batch_strides.back());
|
||||
|
||||
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
||||
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
||||
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
|
||||
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
|
||||
matmul.run(
|
||||
encoder,
|
||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
||||
a.data<int8_t>() + a.itemsize() * a_it.loc,
|
||||
b.data<int8_t>() + b.itemsize() * b_it.loc,
|
||||
c.data<int8_t>() + c.itemsize() * c_it.loc,
|
||||
alpha_,
|
||||
beta_);
|
||||
a_it.step();
|
||||
b_it.step();
|
||||
c_it.step();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
11
mlx/backend/cuda/no_cuda.cpp
Normal file
11
mlx/backend/cuda/no_cuda.cpp
Normal file
@@ -0,0 +1,11 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/cuda.h"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
bool is_available() {
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
125
mlx/backend/cuda/primitives.cu
Normal file
125
mlx/backend/cuda/primitives.cu
Normal file
@@ -0,0 +1,125 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/kernels/arange.cuh"
|
||||
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/transform.h>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Arange::eval_gpu");
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&, this](cudaStream_t stream) {
|
||||
MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(out.dtype(), "Arange", CTYPE, {
|
||||
using OutType = cuda_type_t<CTYPE>;
|
||||
CTYPE step =
|
||||
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
|
||||
thrust::transform(
|
||||
cu::thrust_policy(stream),
|
||||
thrust::counting_iterator<uint32_t>(0),
|
||||
thrust::counting_iterator<uint32_t>(out.data_size()),
|
||||
thrust::device_pointer_cast(out.data<OutType>()),
|
||||
cu::Arange<OutType>{
|
||||
static_cast<OutType>(start_), static_cast<OutType>(step)});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
bool fast::ScaledDotProductAttention::use_fallback(
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
bool has_mask,
|
||||
bool has_arr_mask,
|
||||
bool do_causal,
|
||||
Stream s) {
|
||||
return true;
|
||||
}
|
||||
|
||||
#define NO_GPU_MULTI(func) \
|
||||
void func::eval_gpu( \
|
||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||
throw std::runtime_error(#func " has no CUDA implementation."); \
|
||||
}
|
||||
|
||||
#define NO_GPU_USE_FALLBACK(func) \
|
||||
bool func::use_fallback(Stream s) { \
|
||||
return true; \
|
||||
} \
|
||||
NO_GPU_MULTI(func)
|
||||
|
||||
#define NO_GPU(func) \
|
||||
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
||||
throw std::runtime_error(#func " has no CUDA implementation."); \
|
||||
}
|
||||
|
||||
NO_GPU(ArgPartition)
|
||||
NO_GPU(ArgReduce)
|
||||
NO_GPU(BlockMaskedMM)
|
||||
NO_GPU_MULTI(Compiled)
|
||||
NO_GPU(Convolution)
|
||||
NO_GPU_MULTI(DivMod)
|
||||
NO_GPU(DynamicSlice)
|
||||
NO_GPU(DynamicSliceUpdate)
|
||||
NO_GPU(FFT)
|
||||
NO_GPU(Gather)
|
||||
NO_GPU(GatherAxis)
|
||||
NO_GPU(GatherMM)
|
||||
NO_GPU(GatherQMM)
|
||||
NO_GPU(Hadamard)
|
||||
NO_GPU(Load)
|
||||
NO_GPU(LogSumExp)
|
||||
NO_GPU_MULTI(LUF)
|
||||
NO_GPU(Partition)
|
||||
NO_GPU_MULTI(QRF)
|
||||
NO_GPU(QuantizedMatmul)
|
||||
NO_GPU(Reduce)
|
||||
NO_GPU(Scan)
|
||||
NO_GPU(Scatter)
|
||||
NO_GPU(ScatterAxis)
|
||||
NO_GPU(Select)
|
||||
NO_GPU(SliceUpdate)
|
||||
NO_GPU(Softmax)
|
||||
NO_GPU_MULTI(SVD)
|
||||
NO_GPU(Inverse)
|
||||
NO_GPU(Cholesky)
|
||||
NO_GPU_MULTI(Eig)
|
||||
NO_GPU_MULTI(Eigh)
|
||||
|
||||
namespace fast {
|
||||
NO_GPU_USE_FALLBACK(LayerNorm)
|
||||
NO_GPU_MULTI(LayerNormVJP)
|
||||
NO_GPU_USE_FALLBACK(RMSNorm)
|
||||
NO_GPU_MULTI(RMSNormVJP)
|
||||
NO_GPU_USE_FALLBACK(RoPE)
|
||||
NO_GPU(ScaledDotProductAttention)
|
||||
NO_GPU_MULTI(AffineQuantize)
|
||||
NO_GPU_MULTI(CustomKernel)
|
||||
} // namespace fast
|
||||
|
||||
namespace distributed {
|
||||
NO_GPU_MULTI(AllReduce)
|
||||
NO_GPU_MULTI(AllGather)
|
||||
NO_GPU_MULTI(Send)
|
||||
NO_GPU_MULTI(Recv)
|
||||
} // namespace distributed
|
||||
|
||||
} // namespace mlx::core
|
181
mlx/backend/cuda/random.cu
Normal file
181
mlx/backend/cuda/random.cu
Normal file
@@ -0,0 +1,181 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
__constant__ constexpr uint32_t rotations[2][4] = {
|
||||
{13, 15, 26, 6},
|
||||
{17, 29, 16, 24}};
|
||||
|
||||
union rbits {
|
||||
uint2 val;
|
||||
uint8_t bytes[2][4];
|
||||
};
|
||||
|
||||
__device__ rbits threefry2x32_hash(uint2 key, uint2 count) {
|
||||
uint32_t ks[] = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA};
|
||||
|
||||
rbits v;
|
||||
v.val.x = count.x + ks[0];
|
||||
v.val.y = count.y + ks[1];
|
||||
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
for (auto r : rotations[i % 2]) {
|
||||
v.val.x += v.val.y;
|
||||
v.val.y = (v.val.y << r) | (v.val.y >> (32 - r));
|
||||
v.val.y ^= v.val.x;
|
||||
}
|
||||
v.val.x += ks[(i + 1) % 3];
|
||||
v.val.y += ks[(i + 2) % 3] + i + 1;
|
||||
}
|
||||
|
||||
return v;
|
||||
}
|
||||
|
||||
__global__ void rbitsc(
|
||||
const uint32_t* keys,
|
||||
uint8_t* out,
|
||||
dim3 grid_dims,
|
||||
bool odd,
|
||||
uint32_t bytes_per_key) {
|
||||
uint2 index{
|
||||
blockIdx.x * blockDim.x + threadIdx.x,
|
||||
blockIdx.y * blockDim.y + threadIdx.y};
|
||||
if (index.x >= grid_dims.x || index.y >= grid_dims.y) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto kidx = 2 * index.x;
|
||||
auto key = uint2{keys[kidx], keys[kidx + 1]};
|
||||
auto half_size = grid_dims.y - odd;
|
||||
out += index.x * bytes_per_key;
|
||||
bool drop_last = odd && (index.y == half_size);
|
||||
auto bits = threefry2x32_hash(
|
||||
key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y});
|
||||
size_t idx = size_t(index.y) << 2;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
out[idx + i] = bits.bytes[0][i];
|
||||
}
|
||||
if (!drop_last) {
|
||||
idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2;
|
||||
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
|
||||
int edge_bytes = (bytes_per_key % 4);
|
||||
for (int i = 0; i < edge_bytes; ++i) {
|
||||
out[idx + i] = bits.bytes[1][i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
out[idx + i] = bits.bytes[1][i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void rbits(
|
||||
const uint32_t* keys,
|
||||
uint8_t* out,
|
||||
dim3 grid_dims,
|
||||
bool odd,
|
||||
uint32_t bytes_per_key,
|
||||
int32_t ndim,
|
||||
const __grid_constant__ Shape key_shape,
|
||||
const __grid_constant__ Strides key_strides) {
|
||||
uint2 index{
|
||||
blockIdx.x * blockDim.x + threadIdx.x,
|
||||
blockIdx.y * blockDim.y + threadIdx.y};
|
||||
if (index.x >= grid_dims.x || index.y >= grid_dims.y) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto kidx = 2 * index.x;
|
||||
auto k1_elem = elem_to_loc(kidx, key_shape.data(), key_strides.data(), ndim);
|
||||
auto k2_elem =
|
||||
elem_to_loc(kidx + 1, key_shape.data(), key_strides.data(), ndim);
|
||||
auto key = uint2{keys[k1_elem], keys[k2_elem]};
|
||||
auto half_size = grid_dims.y - odd;
|
||||
out += size_t(index.x) * bytes_per_key;
|
||||
bool drop_last = odd && (index.y == half_size);
|
||||
auto bits = threefry2x32_hash(
|
||||
key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y});
|
||||
size_t idx = size_t(index.y) << 2;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
out[idx + i] = bits.bytes[0][i];
|
||||
}
|
||||
if (!drop_last) {
|
||||
idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2;
|
||||
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
|
||||
int edge_bytes = (bytes_per_key % 4);
|
||||
for (int i = 0; i < edge_bytes; ++i) {
|
||||
out[idx + i] = bits.bytes[1][i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
out[idx + i] = bits.bytes[1][i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("RandomBits::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
// keys has shape (N1, ..., NK, 2)
|
||||
// out has shape (N1, ..., NK, M1, M2, ...)
|
||||
auto& keys = inputs[0];
|
||||
uint32_t num_keys = keys.size() / 2;
|
||||
|
||||
uint32_t elems_per_key = out.size() / num_keys;
|
||||
uint32_t bytes_per_key = out.itemsize() * elems_per_key;
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t out_per_key = (bytes_per_key + 4 - 1) / 4;
|
||||
uint32_t half_size = out_per_key / 2;
|
||||
bool odd = out_per_key % 2;
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(keys);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dim3 grid_dims{num_keys, half_size + odd};
|
||||
dim3 block_dims = get_block_dims(grid_dims.x, grid_dims.y, 1);
|
||||
dim3 num_blocks{
|
||||
cuda::ceil_div(grid_dims.x, block_dims.x),
|
||||
cuda::ceil_div(grid_dims.y, block_dims.y)};
|
||||
if (keys.flags().row_contiguous) {
|
||||
cu::rbitsc<<<num_blocks, block_dims, 0, stream>>>(
|
||||
keys.data<uint32_t>(),
|
||||
out.data<uint8_t>(),
|
||||
grid_dims,
|
||||
odd,
|
||||
bytes_per_key);
|
||||
} else {
|
||||
cu::rbits<<<num_blocks, block_dims, 0, stream>>>(
|
||||
keys.data<uint32_t>(),
|
||||
out.data<uint8_t>(),
|
||||
grid_dims,
|
||||
odd,
|
||||
bytes_per_key,
|
||||
keys.ndim(),
|
||||
const_param(keys.shape()),
|
||||
const_param(keys.strides()));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
41
mlx/backend/cuda/slicing.cpp
Normal file
41
mlx/backend/cuda/slicing.cpp
Normal file
@@ -0,0 +1,41 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/slicing.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/gpu/slicing.h"
|
||||
|
||||
#include <numeric>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void concatenate_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
int axis,
|
||||
const Stream& s) {
|
||||
std::vector<int> sizes;
|
||||
sizes.push_back(0);
|
||||
for (auto& p : inputs) {
|
||||
sizes.push_back(p.shape(axis));
|
||||
}
|
||||
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto strides = out.strides();
|
||||
auto flags = out.flags();
|
||||
flags.row_contiguous = false;
|
||||
flags.col_contiguous = false;
|
||||
flags.contiguous = false;
|
||||
// TODO: Handle concurrent outputs:
|
||||
// https://github.com/ml-explore/mlx/pull/2145#discussion_r2070753816
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
|
||||
size_t data_offset = strides[axis] * sizes[i];
|
||||
out_slice.copy_shared_buffer(
|
||||
out, strides, flags, out_slice.size(), data_offset);
|
||||
copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, s);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
180
mlx/backend/cuda/sort.cu
Normal file
180
mlx/backend/cuda/sort.cu
Normal file
@@ -0,0 +1,180 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/transform.h>
|
||||
#include <cub/device/device_segmented_sort.cuh>
|
||||
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
struct ModOp {
|
||||
T divisor;
|
||||
__device__ T operator()(T x) {
|
||||
return x % divisor;
|
||||
}
|
||||
};
|
||||
|
||||
// We can not use any op in eval, make an utility.
|
||||
array swapaxes_in_eval(const array& in, int axis1, int axis2) {
|
||||
std::vector<int> axes(in.ndim());
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
std::swap(axes[axis1], axes[axis2]);
|
||||
// TODO: Share the code with Transpose::eval.
|
||||
Shape shape(axes.size());
|
||||
Strides strides(in.ndim());
|
||||
for (size_t ax = 0; ax < axes.size(); ++ax) {
|
||||
shape[ax] = in.shape()[axes[ax]];
|
||||
strides[ax] = in.strides()[axes[ax]];
|
||||
}
|
||||
auto flags = in.flags();
|
||||
if (flags.contiguous) {
|
||||
auto [_, row_contiguous, col_contiguous] = check_contiguity(shape, strides);
|
||||
flags.row_contiguous = row_contiguous;
|
||||
flags.col_contiguous = col_contiguous;
|
||||
}
|
||||
array out(shape, in.dtype(), nullptr, {});
|
||||
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||
return out;
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
void segmented_sort_pairs(cu::CommandEncoder& encoder, Args&&... args) {
|
||||
// Allocate temporary storage.
|
||||
size_t size;
|
||||
CHECK_CUDA_ERROR(
|
||||
cub::DeviceSegmentedSort::StableSortPairs(nullptr, size, args...));
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
encoder.add_temporary(temp);
|
||||
// Run op.
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
|
||||
temp.data<void>(), size, args...));
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
|
||||
// Allocate temporary storage.
|
||||
size_t size;
|
||||
CHECK_CUDA_ERROR(
|
||||
cub::DeviceSegmentedSort::StableSortKeys(nullptr, size, args...));
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
encoder.add_temporary(temp);
|
||||
// Run op.
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
|
||||
temp.data<void>(), size, args...));
|
||||
}
|
||||
|
||||
void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
array out = out_;
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
if (axis < 0) {
|
||||
axis += in.ndim();
|
||||
}
|
||||
int nsort = in.shape(axis);
|
||||
int nsegments = in.data_size() / nsort;
|
||||
int last_dim = in.ndim() - 1;
|
||||
|
||||
// If we are not sorting the innermost dimension of a contiguous array,
|
||||
// transpose and make a copy.
|
||||
bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1;
|
||||
if (!is_segmented_sort) {
|
||||
array trans = swapaxes_in_eval(in, axis, last_dim);
|
||||
in = array(trans.shape(), trans.dtype(), nullptr, {});
|
||||
copy_gpu(trans, in, CopyType::General, s);
|
||||
encoder.add_temporary(in);
|
||||
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
|
||||
encoder.add_temporary(out);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
|
||||
using Type = cuda_type_t<CTYPE>;
|
||||
auto offsets = thrust::make_transform_iterator(
|
||||
thrust::make_counting_iterator(0),
|
||||
[nsort] __device__(int i) { return i * nsort; });
|
||||
if (argsort) {
|
||||
// Indices in the sorted dimension.
|
||||
array indices(
|
||||
allocator::malloc(out.nbytes()), in.shape(), out.dtype());
|
||||
encoder.add_temporary(indices);
|
||||
thrust::transform(
|
||||
cu::thrust_policy(stream),
|
||||
thrust::counting_iterator<uint32_t>(0),
|
||||
thrust::counting_iterator<uint32_t>(indices.data_size()),
|
||||
thrust::device_pointer_cast(indices.data<uint32_t>()),
|
||||
ModOp<uint32_t>{static_cast<uint32_t>(nsort)});
|
||||
|
||||
// In argsort though we don't need the result of sorted values, the
|
||||
// API requires us to provide an array to store it.
|
||||
array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype());
|
||||
encoder.add_temporary(discard);
|
||||
|
||||
segmented_sort_pairs(
|
||||
encoder,
|
||||
in.data<Type>(),
|
||||
discard.data<Type>(),
|
||||
indices.data<uint32_t>(),
|
||||
out.data<uint32_t>(),
|
||||
in.data_size(),
|
||||
nsegments,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
stream);
|
||||
} else {
|
||||
segmented_sort(
|
||||
encoder,
|
||||
in.data<Type>(),
|
||||
out.data<Type>(),
|
||||
in.data_size(),
|
||||
nsegments,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
stream);
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"CUDA backend does not support sorting complex numbers");
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
if (!is_segmented_sort) {
|
||||
// Swap the sorted axis back.
|
||||
// TODO: Do in-place transpose instead of using a temporary out array.
|
||||
copy_gpu(swapaxes_in_eval(out, axis, last_dim), out_, CopyType::General, s);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("ArgSort::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
gpu_sort(stream(), inputs[0], out, axis_, true);
|
||||
}
|
||||
|
||||
void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Sort::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
gpu_sort(stream(), inputs[0], out, axis_, false);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
196
mlx/backend/cuda/unary.cu
Normal file
196
mlx/backend/cuda/unary.cu
Normal file
@@ -0,0 +1,196 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/iterators/general_iterator.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/kernels/unary_ops.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/transform.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
template <typename Op, typename In, typename Out>
|
||||
constexpr bool supports_unary_op() {
|
||||
if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> ||
|
||||
std::is_same_v<Op, Sign>) {
|
||||
return std::is_same_v<In, Out>;
|
||||
}
|
||||
if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcCosh> ||
|
||||
std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> ||
|
||||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
|
||||
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
|
||||
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Log1p> ||
|
||||
std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
||||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Sigmoid> ||
|
||||
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
|
||||
return std::is_same_v<In, Out> && is_floating_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, BitwiseInvert>) {
|
||||
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
|
||||
!std::is_same_v<In, bool>;
|
||||
}
|
||||
if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor> ||
|
||||
std::is_same_v<Op, Square>) {
|
||||
return std::is_same_v<In, Out> && !std::is_same_v<In, complex64_t>;
|
||||
}
|
||||
if (std::is_same_v<Op, Conjugate>) {
|
||||
return std::is_same_v<In, Out> && std::is_same_v<In, complex64_t>;
|
||||
}
|
||||
if (std::is_same_v<Op, Cos> || std::is_same_v<Op, Cosh> ||
|
||||
std::is_same_v<Op, Exp> || std::is_same_v<Op, Round> ||
|
||||
std::is_same_v<Op, Sin> || std::is_same_v<Op, Sinh> ||
|
||||
std::is_same_v<Op, Tan> || std::is_same_v<Op, Tanh>) {
|
||||
return std::is_same_v<In, Out> &&
|
||||
(is_floating_v<In> || std::is_same_v<In, complex64_t>);
|
||||
}
|
||||
if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
|
||||
return std::is_same_v<In, complex64_t> && std::is_same_v<Out, float>;
|
||||
}
|
||||
if (std::is_same_v<Op, LogicalNot>) {
|
||||
return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
template <typename Op>
|
||||
void unary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string& op,
|
||||
const Stream& s) {
|
||||
auto& in = inputs[0];
|
||||
if (in.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, {
|
||||
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
|
||||
if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||
using InType = cuda_type_t<CTYPE_IN>;
|
||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||
auto policy = cu::thrust_policy(stream);
|
||||
auto in_ptr = thrust::device_pointer_cast(in.data<InType>());
|
||||
auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
|
||||
if (in.flags().contiguous) {
|
||||
thrust::transform(
|
||||
policy, in_ptr, in_ptr + in.data_size(), out_ptr, Op());
|
||||
} else {
|
||||
auto [shape, strides] = collapse_contiguous_dims(in);
|
||||
auto [in_begin, in_end] = cu::make_general_iterators<int64_t>(
|
||||
in_ptr, in.data_size(), shape, strides);
|
||||
thrust::transform(policy, in_begin, in_end, out_ptr, Op());
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Can not do unary op {} on input of {} with output of {}.",
|
||||
op,
|
||||
dtype_to_string(in.dtype()),
|
||||
dtype_to_string(out.dtype())));
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void unary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string& op,
|
||||
const Stream& s) {
|
||||
set_unary_output_data(inputs[0], out);
|
||||
unary_op_gpu_inplace<Op>(inputs, out, op, s);
|
||||
}
|
||||
|
||||
#define UNARY_GPU(func) \
|
||||
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
||||
nvtx3::scoped_range r(#func "::eval_gpu"); \
|
||||
auto& s = out.primitive().stream(); \
|
||||
unary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
|
||||
}
|
||||
|
||||
UNARY_GPU(Abs)
|
||||
UNARY_GPU(ArcCos)
|
||||
UNARY_GPU(ArcCosh)
|
||||
UNARY_GPU(ArcSin)
|
||||
UNARY_GPU(ArcSinh)
|
||||
UNARY_GPU(ArcTan)
|
||||
UNARY_GPU(ArcTanh)
|
||||
UNARY_GPU(BitwiseInvert)
|
||||
UNARY_GPU(Ceil)
|
||||
UNARY_GPU(Conjugate)
|
||||
UNARY_GPU(Cos)
|
||||
UNARY_GPU(Cosh)
|
||||
UNARY_GPU(Erf)
|
||||
UNARY_GPU(ErfInv)
|
||||
UNARY_GPU(Exp)
|
||||
UNARY_GPU(Expm1)
|
||||
UNARY_GPU(Floor)
|
||||
UNARY_GPU(Imag)
|
||||
UNARY_GPU(Log1p)
|
||||
UNARY_GPU(LogicalNot)
|
||||
UNARY_GPU(Negative)
|
||||
UNARY_GPU(Real)
|
||||
UNARY_GPU(Sigmoid)
|
||||
UNARY_GPU(Sign)
|
||||
UNARY_GPU(Sin)
|
||||
UNARY_GPU(Sinh)
|
||||
UNARY_GPU(Square)
|
||||
UNARY_GPU(Tan)
|
||||
UNARY_GPU(Tanh)
|
||||
|
||||
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Log::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
auto op = get_primitive_string(this);
|
||||
switch (base_) {
|
||||
case Base::e:
|
||||
unary_op_gpu<cu::Log>(inputs, out, op, s);
|
||||
break;
|
||||
case Base::two:
|
||||
unary_op_gpu<cu::Log2>(inputs, out, op, s);
|
||||
break;
|
||||
case Base::ten:
|
||||
unary_op_gpu<cu::Log10>(inputs, out, op, s);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Round::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
auto& s = out.primitive().stream();
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_op_gpu<cu::Round>(inputs, out, get_primitive_string(this), s);
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Sort::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
if (recip_) {
|
||||
unary_op_gpu<cu::Rsqrt>(inputs, out, "Rsqrt", s);
|
||||
} else {
|
||||
unary_op_gpu<cu::Sqrt>(inputs, out, "Sqrt", s);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
26
mlx/backend/cuda/utils.cpp
Normal file
26
mlx/backend/cuda/utils.cpp
Normal file
@@ -0,0 +1,26 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
CudaStream::CudaStream(cu::Device& device) {
|
||||
device.make_current();
|
||||
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
|
||||
}
|
||||
|
||||
CudaStream::~CudaStream() {
|
||||
CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
|
||||
}
|
||||
|
||||
void check_cuda_error(const char* name, cudaError_t err) {
|
||||
if (err != cudaSuccess) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("{} failed: {}", name, cudaGetErrorString(err)));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
38
mlx/backend/cuda/utils.h
Normal file
38
mlx/backend/cuda/utils.h
Normal file
@@ -0,0 +1,38 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
// This file include utilies that are used by C++ code (i.e. .cpp files).
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
class Device;
|
||||
}
|
||||
|
||||
// Cuda stream managed with RAII.
|
||||
class CudaStream {
|
||||
public:
|
||||
explicit CudaStream(cu::Device& device);
|
||||
~CudaStream();
|
||||
|
||||
CudaStream(const CudaStream&) = delete;
|
||||
CudaStream& operator=(const CudaStream&) = delete;
|
||||
|
||||
operator cudaStream_t() const {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
private:
|
||||
cudaStream_t stream_;
|
||||
};
|
||||
|
||||
// Throw exception if the cuda API does not succeed.
|
||||
void check_cuda_error(const char* name, cudaError_t err);
|
||||
|
||||
// The macro version that prints the command that failed.
|
||||
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
|
||||
|
||||
} // namespace mlx::core
|
90
mlx/backend/cuda/worker.cpp
Normal file
90
mlx/backend/cuda/worker.cpp
Normal file
@@ -0,0 +1,90 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
Worker::Worker()
|
||||
: signal_stream_(device(mlx::core::Device::gpu)),
|
||||
worker_(&Worker::thread_fn, this) {}
|
||||
|
||||
Worker::~Worker() {
|
||||
{
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
stop_ = true;
|
||||
}
|
||||
worker_event_.signal(batch_ + 1);
|
||||
worker_.join();
|
||||
}
|
||||
|
||||
void Worker::add_task(std::function<void()> task) {
|
||||
pending_tasks_.push_back(std::move(task));
|
||||
}
|
||||
|
||||
void Worker::consume_in_this_thread() {
|
||||
for (auto& task : pending_tasks_) {
|
||||
task();
|
||||
}
|
||||
pending_tasks_.clear();
|
||||
}
|
||||
|
||||
void Worker::end_batch() {
|
||||
batch_++;
|
||||
{
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
worker_tasks_[batch_] = std::move(pending_tasks_);
|
||||
}
|
||||
uncommited_batches_++;
|
||||
}
|
||||
|
||||
void Worker::commit() {
|
||||
if (uncommited_batches_ == 0) {
|
||||
return;
|
||||
}
|
||||
uncommited_batches_ = 0;
|
||||
worker_event_.signal(batch_);
|
||||
}
|
||||
|
||||
void Worker::commit(cudaStream_t stream) {
|
||||
if (uncommited_batches_ == 0) {
|
||||
return;
|
||||
}
|
||||
uncommited_batches_ = 0;
|
||||
// Signal the |worker_event_| in |signal_stream_| after the kernels in
|
||||
// |stream_| finish running.
|
||||
signal_event_.record(stream);
|
||||
signal_event_.wait(signal_stream_);
|
||||
worker_event_.signal(signal_stream_, batch_);
|
||||
}
|
||||
|
||||
void Worker::thread_fn() {
|
||||
// The worker thread is safe to free buffers.
|
||||
allocator().register_this_thread();
|
||||
|
||||
while (!stop_) {
|
||||
uint64_t batch = worker_event_.value();
|
||||
Tasks tasks;
|
||||
{
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
// Move tasks in signaled batches.
|
||||
auto end = worker_tasks_.upper_bound(batch);
|
||||
for (auto it = worker_tasks_.begin(); it != end; ++it) {
|
||||
if (tasks.empty()) {
|
||||
tasks = std::move(it->second);
|
||||
} else {
|
||||
std::move(
|
||||
it->second.begin(), it->second.end(), std::back_inserter(tasks));
|
||||
}
|
||||
}
|
||||
worker_tasks_.erase(worker_tasks_.begin(), end);
|
||||
}
|
||||
for (auto& task : tasks) {
|
||||
task();
|
||||
}
|
||||
worker_event_.wait(batch + 1);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
68
mlx/backend/cuda/worker.h
Normal file
68
mlx/backend/cuda/worker.h
Normal file
@@ -0,0 +1,68 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/event.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// Run tasks in worker thread, synchronized with cuda stream.
|
||||
class Worker {
|
||||
public:
|
||||
Worker();
|
||||
~Worker();
|
||||
|
||||
Worker(const Worker&) = delete;
|
||||
Worker& operator=(const Worker&) = delete;
|
||||
|
||||
// Add a pending |task| that will run when consumed or commited.
|
||||
void add_task(std::function<void()> task);
|
||||
|
||||
// Run pending tasks immediately in current thread.
|
||||
void consume_in_this_thread();
|
||||
|
||||
// Put pending tasks in a batch.
|
||||
void end_batch();
|
||||
|
||||
// Inform worker thread to run current batches now.
|
||||
void commit();
|
||||
|
||||
// Inform worker thread to run current batches after kernels in |stream|
|
||||
// finish running.
|
||||
void commit(cudaStream_t stream);
|
||||
|
||||
// Return how many batches have been added but not committed yet.
|
||||
size_t uncommited_batches() const {
|
||||
return uncommited_batches_;
|
||||
}
|
||||
|
||||
private:
|
||||
void thread_fn();
|
||||
|
||||
uint64_t batch_{0};
|
||||
size_t uncommited_batches_{0};
|
||||
|
||||
// Cuda stream and event for signaling kernel completion.
|
||||
CudaStream signal_stream_;
|
||||
CudaEvent signal_event_;
|
||||
|
||||
// Worker thread.
|
||||
SharedEvent worker_event_;
|
||||
std::thread worker_;
|
||||
std::mutex worker_mutex_;
|
||||
bool stop_{false};
|
||||
|
||||
// Tasks are put in |pending_tasks_| first, and then moved to
|
||||
// |worker_tasks_| when end_batch() is called.
|
||||
using Tasks = std::vector<std::function<void()>>;
|
||||
Tasks pending_tasks_;
|
||||
std::map<uint64_t, Tasks> worker_tasks_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
5
mlx/backend/gpu/CMakeLists.txt
Normal file
5
mlx/backend/gpu/CMakeLists.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp)
|
9
mlx/backend/gpu/available.h
Normal file
9
mlx/backend/gpu/available.h
Normal file
@@ -0,0 +1,9 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core::gpu {
|
||||
|
||||
bool is_available();
|
||||
|
||||
} // namespace mlx::core::gpu
|
49
mlx/backend/gpu/copy.cpp
Normal file
49
mlx/backend/gpu/copy.cpp
Normal file
@@ -0,0 +1,49 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
|
||||
bool donated = set_copy_output_data(in, out, ctype);
|
||||
if (donated && in.dtype() == out.dtype()) {
|
||||
// If the output has the same type as the input then there is nothing to
|
||||
// copy, just use the buffer.
|
||||
return;
|
||||
}
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
copy_gpu_inplace(in, out, ctype, s);
|
||||
}
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype) {
|
||||
copy_gpu(in, out, ctype, out.primitive().stream());
|
||||
}
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
assert(in.shape() == out.shape());
|
||||
return copy_gpu_inplace(
|
||||
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
|
||||
}
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Strides& i_strides,
|
||||
int64_t i_offset,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
assert(in.shape() == out.shape());
|
||||
return copy_gpu_inplace(
|
||||
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -5,6 +5,8 @@
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Generic copy inplace
|
@@ -8,14 +8,11 @@
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
namespace mlx::core::metal {
|
||||
namespace mlx::core::gpu {
|
||||
|
||||
void new_stream(Stream stream);
|
||||
|
||||
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
|
||||
|
||||
void eval(array& arr);
|
||||
void finalize(Stream s);
|
||||
void synchronize(Stream s);
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
} // namespace mlx::core::gpu
|
225
mlx/backend/gpu/primitives.cpp
Normal file
225
mlx/backend/gpu/primitives.cpp
Normal file
@@ -0,0 +1,225 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/gpu/slicing.h"
|
||||
|
||||
#if defined(MLX_USE_CUDA)
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#endif
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#if defined(MLX_USE_CUDA)
|
||||
#define MLX_PROFILER_RANGE(message) nvtx3::scoped_range r(message)
|
||||
#else
|
||||
#define MLX_PROFILER_RANGE(message)
|
||||
#endif
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
void reshape(const array& in, array& out, Stream s) {
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
if (copy_necessary) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
copy_gpu_inplace(
|
||||
in,
|
||||
out,
|
||||
in.shape(),
|
||||
in.strides(),
|
||||
make_contiguous_strides(in.shape()),
|
||||
0,
|
||||
0,
|
||||
CopyType::General,
|
||||
s);
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("AsStrided::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("AsType::eval_gpu");
|
||||
CopyType ctype =
|
||||
inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy_gpu(inputs[0], out, ctype);
|
||||
}
|
||||
|
||||
void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Broadcast::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void BroadcastAxes::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("BroadcastAxes::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Concatenate::eval_gpu");
|
||||
concatenate_gpu(inputs, out, axis_, stream());
|
||||
}
|
||||
|
||||
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Contiguous::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
constexpr size_t extra_bytes = 16384;
|
||||
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
|
||||
(in.flags().row_contiguous ||
|
||||
(allow_col_major_ && in.flags().col_contiguous))) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
copy_gpu(in, out, CopyType::General);
|
||||
}
|
||||
}
|
||||
|
||||
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Copy::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void CustomTransforms::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
MLX_PROFILER_RANGE("CustomTransforms::eval_gpu");
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
void Depends::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
MLX_PROFILER_RANGE("Depends::eval_gpu");
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("ExpandDims::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Full::eval_gpu");
|
||||
auto in = inputs[0];
|
||||
CopyType ctype;
|
||||
if (in.data_size() == 1) {
|
||||
ctype = CopyType::Scalar;
|
||||
} else if (in.flags().contiguous) {
|
||||
ctype = CopyType::Vector;
|
||||
} else {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
copy_gpu(in, out, ctype);
|
||||
}
|
||||
|
||||
void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Flatten::eval_gpu");
|
||||
reshape(inputs[0], out, stream());
|
||||
}
|
||||
|
||||
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("NumberOfElements::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Inputs must be base input array and scalar val array
|
||||
assert(inputs.size() == 2);
|
||||
auto& in = inputs[0];
|
||||
auto& val = inputs[1];
|
||||
|
||||
// Padding value must be a scalar
|
||||
assert(val.size() == 1);
|
||||
|
||||
// Padding value, input and output must be of the same type
|
||||
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
|
||||
|
||||
pad_gpu(in, val, out, axes_, low_pad_size_, stream());
|
||||
}
|
||||
|
||||
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Reshape::eval_gpu");
|
||||
reshape(inputs[0], out, stream());
|
||||
}
|
||||
|
||||
void Split::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
MLX_PROFILER_RANGE("Split::eval_gpu");
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Slice::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
auto& in = inputs[0];
|
||||
slice_gpu(in, out, start_indices_, strides_, stream());
|
||||
}
|
||||
|
||||
void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Squeeze::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("StopGradient::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Transpose::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Unflatten::eval_gpu");
|
||||
reshape(inputs[0], out, stream());
|
||||
}
|
||||
|
||||
void View::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("View::eval_gpu");
|
||||
auto& in = inputs[0];
|
||||
auto ibytes = size_of(in.dtype());
|
||||
auto obytes = size_of(out.dtype());
|
||||
// Conditions for buffer copying (disjunction):
|
||||
// - type size is the same
|
||||
// - type size is smaller and the last axis is contiguous
|
||||
// - the entire array is row contiguous
|
||||
if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
|
||||
in.flags().row_contiguous) {
|
||||
auto strides = in.strides();
|
||||
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
|
||||
strides[i] *= ibytes;
|
||||
strides[i] /= obytes;
|
||||
}
|
||||
out.copy_shared_buffer(
|
||||
in, strides, in.flags(), in.data_size() * ibytes / obytes);
|
||||
} else {
|
||||
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
|
||||
tmp.set_data(allocator::malloc(tmp.nbytes()));
|
||||
copy_gpu_inplace(in, tmp, CopyType::General, stream());
|
||||
|
||||
auto flags = out.flags();
|
||||
flags.contiguous = true;
|
||||
flags.row_contiguous = true;
|
||||
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
|
||||
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
|
||||
out.copy_shared_buffer(tmp, out.strides(), flags, out.size());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
44
mlx/backend/gpu/slicing.cpp
Normal file
44
mlx/backend/gpu/slicing.cpp
Normal file
@@ -0,0 +1,44 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/slicing.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/gpu/slicing.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void slice_gpu(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Shape& start_indices,
|
||||
const Shape& strides,
|
||||
const Stream& s) {
|
||||
slice(in, out, start_indices, strides);
|
||||
}
|
||||
|
||||
void pad_gpu(
|
||||
const array& in,
|
||||
const array& val,
|
||||
array& out,
|
||||
const std::vector<int>& axes,
|
||||
const Shape& low_pad_size,
|
||||
const Stream& s) {
|
||||
// Fill output with val
|
||||
fill_gpu(val, out, s);
|
||||
|
||||
// Find offset for start of input values
|
||||
size_t data_offset = 0;
|
||||
for (int i = 0; i < axes.size(); i++) {
|
||||
auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i];
|
||||
data_offset += out.strides()[ax] * low_pad_size[i];
|
||||
}
|
||||
|
||||
// Extract slice from output where input will be pasted
|
||||
array out_slice(in.shape(), out.dtype(), nullptr, {});
|
||||
out_slice.copy_shared_buffer(
|
||||
out, out.strides(), out.flags(), out_slice.size(), data_offset);
|
||||
|
||||
// Copy input values into the slice
|
||||
copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -93,6 +93,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||
|
@@ -1,7 +1,6 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include "mlx/backend/metal/allocator.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/backend/metal/resident.h"
|
||||
#include "mlx/memory.h"
|
||||
|
||||
@@ -31,141 +30,18 @@ void* Buffer::raw_ptr() {
|
||||
|
||||
namespace metal {
|
||||
|
||||
namespace {
|
||||
|
||||
BufferCache::BufferCache(ResidencySet& residency_set)
|
||||
: head_(nullptr),
|
||||
tail_(nullptr),
|
||||
pool_size_(0),
|
||||
residency_set_(residency_set) {}
|
||||
|
||||
BufferCache::~BufferCache() {
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
clear();
|
||||
}
|
||||
|
||||
int BufferCache::clear() {
|
||||
int n_release = 0;
|
||||
for (auto& [size, holder] : buffer_pool_) {
|
||||
if (holder->buf) {
|
||||
if (!holder->buf->heap()) {
|
||||
residency_set_.erase(holder->buf);
|
||||
}
|
||||
holder->buf->release();
|
||||
n_release++;
|
||||
}
|
||||
delete holder;
|
||||
}
|
||||
buffer_pool_.clear();
|
||||
pool_size_ = 0;
|
||||
head_ = nullptr;
|
||||
tail_ = nullptr;
|
||||
return n_release;
|
||||
}
|
||||
|
||||
MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
|
||||
// Find the closest buffer in pool
|
||||
MTL::Buffer* pbuf = nullptr;
|
||||
|
||||
auto it = buffer_pool_.lower_bound(size);
|
||||
|
||||
// Make sure we use most of the available memory
|
||||
while (!pbuf && it != buffer_pool_.end() &&
|
||||
it->first < std::min(2 * size, size + 2 * vm_page_size)) {
|
||||
// Collect from the cache
|
||||
pbuf = it->second->buf;
|
||||
|
||||
// Remove from cache
|
||||
remove_from_list(it->second);
|
||||
delete it->second;
|
||||
it = buffer_pool_.erase(it);
|
||||
}
|
||||
|
||||
if (pbuf) {
|
||||
pool_size_ -= pbuf->length();
|
||||
}
|
||||
|
||||
return pbuf;
|
||||
}
|
||||
|
||||
void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
|
||||
// Add to cache
|
||||
if (buf) {
|
||||
BufferHolder* bh = new BufferHolder(buf);
|
||||
add_at_head(bh);
|
||||
pool_size_ += buf->length();
|
||||
buffer_pool_.insert({buf->length(), bh});
|
||||
}
|
||||
}
|
||||
|
||||
int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
|
||||
if (min_bytes_to_free >= 0.9 * pool_size_) {
|
||||
return clear();
|
||||
} else {
|
||||
int n_release = 0;
|
||||
size_t total_bytes_freed = 0;
|
||||
|
||||
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
|
||||
if (tail_->buf) {
|
||||
total_bytes_freed += tail_->buf->length();
|
||||
if (!tail_->buf->heap()) {
|
||||
residency_set_.erase(tail_->buf);
|
||||
}
|
||||
tail_->buf->release();
|
||||
tail_->buf = nullptr;
|
||||
n_release++;
|
||||
}
|
||||
remove_from_list(tail_);
|
||||
}
|
||||
pool_size_ -= total_bytes_freed;
|
||||
return n_release;
|
||||
}
|
||||
}
|
||||
|
||||
void BufferCache::add_at_head(BufferCache::BufferHolder* to_add) {
|
||||
if (!to_add)
|
||||
return;
|
||||
|
||||
if (!head_) {
|
||||
head_ = to_add;
|
||||
tail_ = to_add;
|
||||
} else {
|
||||
head_->prev = to_add;
|
||||
to_add->next = head_;
|
||||
head_ = to_add;
|
||||
}
|
||||
}
|
||||
|
||||
void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
|
||||
if (!to_remove) {
|
||||
return;
|
||||
}
|
||||
|
||||
// If in the middle
|
||||
if (to_remove->prev && to_remove->next) {
|
||||
to_remove->prev->next = to_remove->next;
|
||||
to_remove->next->prev = to_remove->prev;
|
||||
} else if (to_remove->prev && to_remove == tail_) { // If tail
|
||||
tail_ = to_remove->prev;
|
||||
tail_->next = nullptr;
|
||||
} else if (to_remove == head_ && to_remove->next) { // If head
|
||||
head_ = to_remove->next;
|
||||
head_->prev = nullptr;
|
||||
} else if (to_remove == head_ && to_remove == tail_) { // If only element
|
||||
head_ = nullptr;
|
||||
tail_ = nullptr;
|
||||
}
|
||||
|
||||
to_remove->prev = nullptr;
|
||||
to_remove->next = nullptr;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
MetalAllocator::MetalAllocator()
|
||||
: device_(device(mlx::core::Device::gpu).mtl_device()),
|
||||
residency_set_(device_),
|
||||
buffer_cache_(residency_set_) {
|
||||
buffer_cache_(
|
||||
vm_page_size,
|
||||
[](MTL::Buffer* buf) { return buf->length(); },
|
||||
[this](MTL::Buffer* buf) {
|
||||
if (!buf->heap()) {
|
||||
residency_set_.erase(buf);
|
||||
}
|
||||
buf->release();
|
||||
}) {
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
auto memsize = std::get<size_t>(device_info().at("memory_size"));
|
||||
auto max_rec_size =
|
||||
@@ -194,6 +70,7 @@ MetalAllocator::~MetalAllocator() {
|
||||
if (heap_) {
|
||||
heap_->release();
|
||||
}
|
||||
buffer_cache_.clear();
|
||||
}
|
||||
|
||||
size_t MetalAllocator::set_cache_limit(size_t limit) {
|
||||
|
@@ -7,6 +7,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/buffer_cache.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/resident.h"
|
||||
|
||||
@@ -14,43 +15,6 @@ namespace mlx::core::metal {
|
||||
|
||||
using allocator::Buffer;
|
||||
|
||||
namespace {
|
||||
|
||||
class BufferCache {
|
||||
public:
|
||||
BufferCache(ResidencySet& residency_set);
|
||||
~BufferCache();
|
||||
|
||||
MTL::Buffer* reuse_from_cache(size_t size);
|
||||
void recycle_to_cache(MTL::Buffer* buf);
|
||||
int release_cached_buffers(size_t min_bytes_to_free);
|
||||
size_t cache_size() {
|
||||
return pool_size_;
|
||||
}
|
||||
int clear();
|
||||
|
||||
private:
|
||||
struct BufferHolder {
|
||||
public:
|
||||
BufferHolder(MTL::Buffer* buf_) : buf(buf_), prev(nullptr), next(nullptr) {}
|
||||
|
||||
BufferHolder* prev;
|
||||
BufferHolder* next;
|
||||
MTL::Buffer* buf;
|
||||
};
|
||||
|
||||
void add_at_head(BufferHolder* to_add);
|
||||
void remove_from_list(BufferHolder* to_remove);
|
||||
|
||||
std::multimap<size_t, BufferHolder*> buffer_pool_;
|
||||
BufferHolder* head_;
|
||||
BufferHolder* tail_;
|
||||
size_t pool_size_;
|
||||
ResidencySet& residency_set_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
class MetalAllocator : public allocator::Allocator {
|
||||
/** Allocator for Metal GPUs. */
|
||||
public:
|
||||
@@ -90,7 +54,7 @@ class MetalAllocator : public allocator::Allocator {
|
||||
friend MetalAllocator& allocator();
|
||||
|
||||
// Caching allocator
|
||||
BufferCache buffer_cache_;
|
||||
BufferCache<MTL::Buffer> buffer_cache_;
|
||||
|
||||
ResidencySet residency_set_;
|
||||
|
||||
|
@@ -31,13 +31,13 @@ std::string get_kernel_name(
|
||||
kname = "ss";
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
kname = (large ? "sv2" : "sv");
|
||||
kname = "sv";
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
kname = (large ? "vs2" : "vs");
|
||||
kname = "vs";
|
||||
break;
|
||||
case BinaryOpType::VectorVector:
|
||||
kname = (large ? "vv2" : "vv");
|
||||
kname = "vv";
|
||||
break;
|
||||
case BinaryOpType::General:
|
||||
kname = "g";
|
||||
@@ -51,6 +51,13 @@ std::string get_kernel_name(
|
||||
}
|
||||
break;
|
||||
}
|
||||
if (bopt != BinaryOpType::General && bopt != BinaryOpType::ScalarScalar) {
|
||||
if (large) {
|
||||
kname += "2";
|
||||
} else if (work_per_thread > 1) {
|
||||
kname += "n";
|
||||
}
|
||||
}
|
||||
concatenate(kname, "_", op, type_to_name(a));
|
||||
return kname;
|
||||
}
|
||||
@@ -90,7 +97,7 @@ void binary_op_gpu_inplace(
|
||||
work_per_thread = large ? 4 : 2;
|
||||
} else {
|
||||
large = out.data_size() > UINT32_MAX;
|
||||
work_per_thread = 1;
|
||||
work_per_thread = get_work_per_thread(a.dtype(), out.data_size());
|
||||
}
|
||||
std::string kernel_name =
|
||||
get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);
|
||||
@@ -137,13 +144,20 @@ void binary_op_gpu_inplace(
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
} else {
|
||||
// Launch a 1D or 2D grid of threads
|
||||
size_t nthreads = out.data_size();
|
||||
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size grid_dims;
|
||||
if (large) {
|
||||
compute_encoder.set_bytes<int64_t>(out.data_size(), arg_idx++);
|
||||
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
|
||||
} else {
|
||||
compute_encoder.set_bytes<int>(out.data_size(), arg_idx++);
|
||||
grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
}
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
@@ -11,8 +11,6 @@
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
using namespace fmt::literals;
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
inline void build_kernel(
|
||||
@@ -21,21 +19,12 @@ inline void build_kernel(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::unordered_set<uintptr_t>& constant_ids,
|
||||
const std::function<bool(size_t)>& is_constant,
|
||||
bool contiguous,
|
||||
int ndim,
|
||||
bool dynamic_dims,
|
||||
bool use_big_index = false,
|
||||
int work_per_thread = 1) {
|
||||
// All outputs should have the exact same shape and will be row contiguous
|
||||
auto output_shape = outputs[0].shape();
|
||||
auto output_strides = outputs[0].strides();
|
||||
|
||||
// Constants are scalars that are captured by value and cannot change
|
||||
auto is_constant = [&constant_ids](const array& x) {
|
||||
return constant_ids.find(x.id()) != constant_ids.end();
|
||||
};
|
||||
|
||||
NodeNamer namer;
|
||||
bool add_indices = false;
|
||||
int cnt = 0;
|
||||
@@ -45,14 +34,15 @@ inline void build_kernel(
|
||||
"[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);
|
||||
|
||||
// Add the input arguments
|
||||
for (auto& x : inputs) {
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
// Skip constants from the input list
|
||||
if (is_constant(x)) {
|
||||
if (is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto& x = inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
// Scalars and contiguous need no strides
|
||||
if (!is_scalar(x) && !contiguous) {
|
||||
add_indices = true;
|
||||
@@ -64,6 +54,7 @@ inline void build_kernel(
|
||||
cnt++);
|
||||
}
|
||||
|
||||
std::string idx_type = use_big_index ? "int64_t" : "uint";
|
||||
if (add_indices) {
|
||||
os += fmt::format(
|
||||
" constant const int64_t* in_strides [[buffer({0})]],\n", cnt++);
|
||||
@@ -79,10 +70,11 @@ inline void build_kernel(
|
||||
}
|
||||
// Add output strides and shape to extract the indices.
|
||||
if (!contiguous) {
|
||||
os += fmt::format(
|
||||
" constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
|
||||
os += fmt::format(
|
||||
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
|
||||
} else {
|
||||
os += fmt::format(
|
||||
" constant const {0}& size [[buffer({1})]],\n", idx_type, cnt++);
|
||||
}
|
||||
if (dynamic_dims) {
|
||||
os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
|
||||
@@ -92,13 +84,14 @@ inline void build_kernel(
|
||||
os += " uint3 pos [[thread_position_in_grid]],\n";
|
||||
os += " uint3 grid [[threads_per_grid]]) {\n";
|
||||
|
||||
std::string idx_type = use_big_index ? "int64_t" : "uint";
|
||||
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
|
||||
if (contiguous && use_big_index) {
|
||||
// This is only used for contiguous kernels which don't have
|
||||
// a third grid dimension
|
||||
os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n";
|
||||
os += " int64_t index = N_ * (pos.x + grid.x * int64_t(pos.y));\n";
|
||||
} else if (contiguous) {
|
||||
os += " uint index = N_ * pos.x;\n";
|
||||
} else if (work_per_thread > 1) {
|
||||
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
|
||||
os += fmt::format(
|
||||
" int xshape = output_shape[{0}];\n",
|
||||
dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
|
||||
@@ -110,6 +103,9 @@ inline void build_kernel(
|
||||
" {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
|
||||
idx_type);
|
||||
}
|
||||
if (work_per_thread > 1 && contiguous) {
|
||||
os += " for (int i = 0; i < N_ && index < size; ++i) {\n";
|
||||
}
|
||||
|
||||
// Read constant / contiguous inputs in tmps
|
||||
std::vector<array> nc_inputs;
|
||||
@@ -117,7 +113,7 @@ inline void build_kernel(
|
||||
auto& x = inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
if (is_constant(x)) {
|
||||
if (is_constant(i)) {
|
||||
auto type_str = get_type_string(x.dtype());
|
||||
std::ostringstream ss;
|
||||
print_constant(ss, x);
|
||||
@@ -193,7 +189,7 @@ inline void build_kernel(
|
||||
}
|
||||
|
||||
// Open per-thread loop
|
||||
if (work_per_thread > 1) {
|
||||
if (work_per_thread > 1 && !contiguous) {
|
||||
os +=
|
||||
" for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
|
||||
}
|
||||
@@ -263,15 +259,11 @@ inline void build_kernel(
|
||||
void Compiled::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
// Make the name for the kernel library
|
||||
if (kernel_lib_.empty()) {
|
||||
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
|
||||
}
|
||||
|
||||
// Get the kernel if someone else built it already
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto lib = d.get_library(kernel_lib_, [&]() {
|
||||
int work_per_thread = get_work_per_thread(outputs_[0].dtype());
|
||||
std::string kernel = metal::utils();
|
||||
concatenate(
|
||||
kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops());
|
||||
@@ -281,21 +273,38 @@ void Compiled::eval_gpu(
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
is_constant_,
|
||||
/* contiguous = */ true,
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ false);
|
||||
/* dynamic_dims = */ false,
|
||||
/* use_big_index = */ false,
|
||||
/* work_per_thread = */ 1);
|
||||
if (work_per_thread > 1) {
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_lib_ + "_contiguous_n",
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
is_constant_,
|
||||
/* contiguous = */ true,
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ false,
|
||||
/* use_big_index = */ false,
|
||||
/* work_per_thread = */ work_per_thread);
|
||||
}
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_lib_ + "_contiguous_large",
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
is_constant_,
|
||||
/* contiguous = */ true,
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ false,
|
||||
/* use_big_index = */ true);
|
||||
/* use_big_index = */ true,
|
||||
/* work_per_thread = */ work_per_thread);
|
||||
for (int i = 1; i < 8; i++) {
|
||||
build_kernel(
|
||||
kernel,
|
||||
@@ -303,7 +312,7 @@ void Compiled::eval_gpu(
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
is_constant_,
|
||||
/* contiguous = */ false,
|
||||
/* ndim = */ i,
|
||||
/* dynamic_dims = */ false,
|
||||
@@ -316,7 +325,7 @@ void Compiled::eval_gpu(
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
is_constant_,
|
||||
/* contiguous = */ false,
|
||||
/* ndim = */ i,
|
||||
/* dynamic_dims = */ false,
|
||||
@@ -330,7 +339,7 @@ void Compiled::eval_gpu(
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
is_constant_,
|
||||
/* contiguous = */ false,
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ true,
|
||||
@@ -342,7 +351,7 @@ void Compiled::eval_gpu(
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
is_constant_,
|
||||
/* contiguous = */ false,
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ true,
|
||||
@@ -351,81 +360,32 @@ void Compiled::eval_gpu(
|
||||
return kernel;
|
||||
});
|
||||
|
||||
// Figure out which kernel we are using
|
||||
auto& output_shape = outputs[0].shape();
|
||||
auto contiguous = compiled_check_contiguity(inputs, output_shape);
|
||||
|
||||
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
||||
// handle all broadcasting.
|
||||
std::vector<Strides> initial_strides;
|
||||
initial_strides.push_back(outputs[0].strides());
|
||||
Shape shape;
|
||||
std::vector<Strides> strides;
|
||||
if (!contiguous) {
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
// Skip constants.
|
||||
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
|
||||
continue;
|
||||
}
|
||||
auto& x = inputs[i];
|
||||
auto [contiguous, shape, strides] =
|
||||
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
|
||||
|
||||
// Skip scalar inputs.
|
||||
if (is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Broadcast the inputs to the output shape.
|
||||
Strides xstrides;
|
||||
int j = 0;
|
||||
for (; j < output_shape.size() - x.ndim(); j++) {
|
||||
if (output_shape[j] == 1) {
|
||||
xstrides.push_back(outputs[0].strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < x.ndim(); i++, j++) {
|
||||
if (x.shape(i) == 1) {
|
||||
if (output_shape[j] == 1) {
|
||||
xstrides.push_back(outputs[0].strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
} else {
|
||||
xstrides.push_back(x.strides()[i]);
|
||||
}
|
||||
}
|
||||
initial_strides.push_back(std::move(xstrides));
|
||||
}
|
||||
std::tie(shape, strides) =
|
||||
collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX);
|
||||
}
|
||||
|
||||
bool large;
|
||||
if (contiguous) {
|
||||
size_t max_size = 0;
|
||||
for (auto& in : inputs) {
|
||||
max_size = std::max(max_size, in.data_size());
|
||||
}
|
||||
large = (max_size > UINT32_MAX);
|
||||
} else {
|
||||
size_t max_size = 0;
|
||||
for (auto& o : outputs) {
|
||||
max_size = std::max(max_size, o.size());
|
||||
}
|
||||
large = (max_size > UINT32_MAX);
|
||||
}
|
||||
// Whether to use large index.
|
||||
bool large = compiled_use_large_index(inputs, outputs, contiguous);
|
||||
|
||||
// Get the kernel from the lib
|
||||
int ndim = shape.size();
|
||||
bool dynamic = ndim >= 8;
|
||||
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
|
||||
int work_per_thread = 1;
|
||||
if (!contiguous) {
|
||||
if (dynamic) {
|
||||
kernel_name += "dynamic";
|
||||
} else {
|
||||
kernel_name += std::to_string(shape.size());
|
||||
}
|
||||
work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
|
||||
} else {
|
||||
work_per_thread =
|
||||
get_work_per_thread(outputs[0].dtype(), outputs[0].data_size());
|
||||
if (work_per_thread > 1 && !large) {
|
||||
kernel_name += "_n";
|
||||
}
|
||||
}
|
||||
if (large) {
|
||||
kernel_name += "_large";
|
||||
@@ -439,7 +399,7 @@ void Compiled::eval_gpu(
|
||||
int stride_idx = 1; // idx 0 is the output strides
|
||||
Strides in_strides;
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
|
||||
if (is_constant_(i)) {
|
||||
continue;
|
||||
}
|
||||
auto& x = inputs[i];
|
||||
@@ -456,8 +416,7 @@ void Compiled::eval_gpu(
|
||||
compute_encoder.set_vector_bytes(in_strides, cnt++);
|
||||
}
|
||||
|
||||
compiled_allocate_outputs(
|
||||
inputs, outputs, inputs_, constant_ids_, contiguous);
|
||||
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
|
||||
|
||||
// Put the outputs in
|
||||
for (auto& x : outputs) {
|
||||
@@ -466,8 +425,14 @@ void Compiled::eval_gpu(
|
||||
|
||||
// Put the output shape and strides in
|
||||
if (!contiguous) {
|
||||
compute_encoder.set_vector_bytes(strides[0], cnt++);
|
||||
compute_encoder.set_vector_bytes(shape, cnt++);
|
||||
} else {
|
||||
auto size = outputs[0].data_size();
|
||||
if (large) {
|
||||
compute_encoder.set_bytes<int64_t>(size, cnt++);
|
||||
} else {
|
||||
compute_encoder.set_bytes<int>(size, cnt++);
|
||||
}
|
||||
}
|
||||
|
||||
// Put the number of dims in if it is dynamic
|
||||
@@ -477,19 +442,18 @@ void Compiled::eval_gpu(
|
||||
|
||||
// Launch the kernel
|
||||
if (contiguous) {
|
||||
size_t nthreads = outputs[0].data_size();
|
||||
size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread);
|
||||
MTL::Size group_dims(
|
||||
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
|
||||
|
||||
MTL::Size grid_dims = large
|
||||
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
|
||||
? get_2d_grid_dims(
|
||||
outputs[0].shape(), outputs[0].strides(), work_per_thread)
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
} else {
|
||||
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||
size_t rest = outputs[0].size() / (dim0 * dim1);
|
||||
int work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
|
||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
int pow2;
|
||||
|
@@ -1,11 +1,10 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
@@ -178,83 +177,6 @@ void explicit_gemm_conv_group_ND_gpu(
|
||||
/*copies = */ copies);
|
||||
}
|
||||
|
||||
void conv_1D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
int groups,
|
||||
bool flip) {
|
||||
// Make conv params
|
||||
MLXConvParams<1> conv_params{
|
||||
/* const int N = */ static_cast<int>(in.shape(0)),
|
||||
/* const int C = */ static_cast<int>(in.shape(2)),
|
||||
/* const int O = */ static_cast<int>(wt.shape(0)),
|
||||
/* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
|
||||
/* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
|
||||
/* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
|
||||
/* const int str[NDIM] = */ {wt_strides[0]},
|
||||
/* const int pad[NDIM] = */ {padding[0]},
|
||||
/* const int kdil[NDIM] = */ {wt_dilation[0]},
|
||||
/* const int idil[NDIM] = */ {in_dilation[0]},
|
||||
/* const size_t in_strides[NDIM + 2] = */
|
||||
{in.strides()[0], in.strides()[1], in.strides()[2]},
|
||||
/* const size_t wt_strides[NDIM + 2] = */
|
||||
{wt.strides()[0], wt.strides()[1], wt.strides()[2]},
|
||||
/* const size_t out_strides[NDIM + 2] = */
|
||||
{out.strides()[0], out.strides()[1], out.strides()[2]},
|
||||
/* const int groups = */ groups,
|
||||
/* const bool flip = */ flip};
|
||||
|
||||
// Direct to explicit gemm conv
|
||||
if (groups > 1) {
|
||||
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
|
||||
} else {
|
||||
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
}
|
||||
|
||||
void slow_conv_2D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const MLXConvParams<2>& conv_params) {
|
||||
int bm = 16, bn = 8;
|
||||
int tm = 4, tn = 4;
|
||||
|
||||
std::ostringstream kname;
|
||||
kname << "naive_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn" << bn
|
||||
<< "_tm" << tm << "_tn" << tn;
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
size_t n_pixels = conv_params.oS[0] * conv_params.oS[1];
|
||||
|
||||
size_t grid_dim_x = (n_pixels + (tm * bm) - 1) / (tm * bm);
|
||||
size_t grid_dim_y = (conv_params.O + (tn * bn) - 1) / (tn * bn);
|
||||
size_t grid_dim_z = conv_params.N;
|
||||
|
||||
MTL::Size group_dims = MTL::Size(bm, bn, 1);
|
||||
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
|
||||
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_input_array(wt, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
compute_encoder.set_bytes(conv_params, 3);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void implicit_gemm_conv_2D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
@@ -469,6 +391,7 @@ void implicit_gemm_conv_2D_general_gpu(
|
||||
// Get channel iteration info
|
||||
int channel_k_iters = ((conv_params.C + bk - 1) / bk);
|
||||
int gemm_k_iters = channel_k_iters;
|
||||
bool align_C = conv_params.C % bk == 0;
|
||||
|
||||
// Fix host side helper params
|
||||
int sign = (conv_params.flip ? -1 : 1);
|
||||
@@ -497,14 +420,33 @@ void implicit_gemm_conv_2D_general_gpu(
|
||||
/* const int swizzle_log = */ swizzle_log};
|
||||
|
||||
// Determine kernel
|
||||
std::ostringstream kname;
|
||||
kname << "implicit_gemm_conv_2d_general_" << type_to_name(out) << "_bm" << bm
|
||||
<< "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
concatenate(
|
||||
kname,
|
||||
"implicit_gemm_conv_2d_general_",
|
||||
type_to_name(out),
|
||||
"_bm",
|
||||
bm,
|
||||
"_bn",
|
||||
bn,
|
||||
"_bk",
|
||||
bk,
|
||||
"_wm",
|
||||
wm,
|
||||
"_wn",
|
||||
wn);
|
||||
std::string hash_name;
|
||||
hash_name.reserve(64);
|
||||
concatenate(hash_name, kname, "_alC_", align_C);
|
||||
metal::MTLFCList func_consts = {
|
||||
{&align_C, MTL::DataType::DataTypeBool, 200},
|
||||
};
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel =
|
||||
get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
|
||||
auto kernel = get_steel_conv_general_kernel(
|
||||
d, kname, hash_name, func_consts, out, bm, bn, bk, wm, wn);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Deduce grid launch dimensions
|
||||
@@ -755,7 +697,7 @@ void depthwise_conv_2D_gpu(
|
||||
std::string hash_name = kname.str();
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
|
||||
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
@@ -771,6 +713,143 @@ void depthwise_conv_2D_gpu(
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void dispatch_conv_2D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const MLXConvParams<2>& conv_params,
|
||||
std::vector<array>& copies) {
|
||||
bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1;
|
||||
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
|
||||
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
|
||||
|
||||
if (is_idil_one && conv_params.groups > 1) {
|
||||
const int C_per_group = conv_params.C / conv_params.groups;
|
||||
const int O_per_group = conv_params.O / conv_params.groups;
|
||||
|
||||
if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
|
||||
conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
|
||||
conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
|
||||
conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
|
||||
conv_params.wt_strides[1] == conv_params.wS[1] &&
|
||||
conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
|
||||
return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
|
||||
if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
|
||||
(O_per_group <= 16 || O_per_group % 16 == 0)) {
|
||||
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||
} else {
|
||||
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
}
|
||||
|
||||
// Direct to winograd conv
|
||||
bool inp_large =
|
||||
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 4096;
|
||||
bool channels_large = (conv_params.C + conv_params.O) >= 256;
|
||||
bool out_large =
|
||||
(conv_params.N * conv_params.oS[0] * conv_params.oS[1]) >= 256;
|
||||
if (!conv_params.flip && is_stride_one && is_kdil_one && is_idil_one &&
|
||||
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
|
||||
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
|
||||
channels_large) {
|
||||
return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
|
||||
}
|
||||
|
||||
// Direct to implicit gemm conv
|
||||
if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) &&
|
||||
(conv_params.O <= 16 || conv_params.O % 16 == 0)) {
|
||||
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
|
||||
else if ((conv_params.C % 16 == 0 && conv_params.O % 16 == 0) || out_large) {
|
||||
return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
|
||||
// Direct to explicit gemm conv
|
||||
else {
|
||||
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
}
|
||||
|
||||
void conv_1D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
int groups,
|
||||
bool flip,
|
||||
std::vector<array>& copies) {
|
||||
bool is_idil_one = in_dilation[0] == 1;
|
||||
int C = in.shape(2);
|
||||
int O = wt.shape(0);
|
||||
const int C_per_group = in.shape(2) / groups;
|
||||
const int O_per_group = wt.shape(0) / groups;
|
||||
|
||||
// Direct to implicit gemm conv
|
||||
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
|
||||
(O_per_group <= 16 || O_per_group % 16 == 0)) {
|
||||
MLXConvParams<2> conv_params{
|
||||
/* const int N = */ static_cast<int>(in.shape(0)),
|
||||
/* const int C = */ C,
|
||||
/* const int O = */ O,
|
||||
/* const int iS[NDIM] = */ {static_cast<int>(in.shape(1)), 1},
|
||||
/* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1)), 1},
|
||||
/* const int oS[NDIM] = */ {static_cast<int>(out.shape(1)), 1},
|
||||
/* const int str[NDIM] = */ {wt_strides[0], 1},
|
||||
/* const int pad[NDIM] = */ {padding[0], 0},
|
||||
/* const int kdil[NDIM] = */ {wt_dilation[0], 1},
|
||||
/* const int idil[NDIM] = */ {in_dilation[0], 1},
|
||||
/* const size_t in_strides[NDIM + 2] = */
|
||||
{in.strides()[0], in.strides()[1], 0, in.strides()[2]},
|
||||
/* const size_t wt_strides[NDIM + 2] = */
|
||||
{wt.strides()[0], wt.strides()[1], 0, wt.strides()[2]},
|
||||
/* const size_t out_strides[NDIM + 2] = */
|
||||
{out.strides()[0], out.strides()[1], 0, out.strides()[2]},
|
||||
/* const int groups = */ groups,
|
||||
/* const bool flip = */ flip};
|
||||
|
||||
dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
|
||||
return;
|
||||
}
|
||||
|
||||
// Make conv params
|
||||
MLXConvParams<1> conv_params{
|
||||
/* const int N = */ static_cast<int>(in.shape(0)),
|
||||
/* const int C = */ static_cast<int>(in.shape(2)),
|
||||
/* const int O = */ static_cast<int>(wt.shape(0)),
|
||||
/* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
|
||||
/* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
|
||||
/* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
|
||||
/* const int str[NDIM] = */ {wt_strides[0]},
|
||||
/* const int pad[NDIM] = */ {padding[0]},
|
||||
/* const int kdil[NDIM] = */ {wt_dilation[0]},
|
||||
/* const int idil[NDIM] = */ {in_dilation[0]},
|
||||
/* const size_t in_strides[NDIM + 2] = */
|
||||
{in.strides()[0], in.strides()[1], in.strides()[2]},
|
||||
/* const size_t wt_strides[NDIM + 2] = */
|
||||
{wt.strides()[0], wt.strides()[1], wt.strides()[2]},
|
||||
/* const size_t out_strides[NDIM + 2] = */
|
||||
{out.strides()[0], out.strides()[1], out.strides()[2]},
|
||||
/* const int groups = */ groups,
|
||||
/* const bool flip = */ flip};
|
||||
|
||||
// Direct to explicit gemm conv
|
||||
if (groups > 1) {
|
||||
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
|
||||
} else {
|
||||
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
}
|
||||
|
||||
void conv_2D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
@@ -808,57 +887,7 @@ void conv_2D_gpu(
|
||||
/* const int groups = */ groups,
|
||||
/* const bool flip = */ flip,
|
||||
};
|
||||
|
||||
bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1;
|
||||
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
|
||||
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
|
||||
|
||||
if (is_idil_one && groups > 1) {
|
||||
const int C_per_group = conv_params.C / groups;
|
||||
const int O_per_group = conv_params.O / groups;
|
||||
|
||||
if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
|
||||
conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
|
||||
conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
|
||||
conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
|
||||
conv_params.wt_strides[1] == conv_params.wS[1] &&
|
||||
conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
|
||||
return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
|
||||
if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
|
||||
(O_per_group <= 16 || O_per_group % 16 == 0)) {
|
||||
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||
} else {
|
||||
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
}
|
||||
|
||||
// Direct to winograd conv
|
||||
bool inp_large =
|
||||
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
|
||||
bool channels_large = (conv_params.C + conv_params.O) >= 256;
|
||||
if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
|
||||
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
|
||||
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
|
||||
channels_large) {
|
||||
return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
|
||||
}
|
||||
|
||||
// Direct to implicit gemm conv
|
||||
if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) &&
|
||||
(conv_params.O <= 16 || conv_params.O % 16 == 0)) {
|
||||
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
|
||||
else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
|
||||
return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
|
||||
// Direct to explicit gemm conv
|
||||
else {
|
||||
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
|
||||
}
|
||||
|
||||
void conv_3D_gpu(
|
||||
@@ -952,7 +981,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_,
|
||||
padding_lo_,
|
||||
kernel_strides_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
@@ -967,7 +996,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_,
|
||||
padding_lo_,
|
||||
kernel_strides_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
@@ -983,12 +1012,13 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_,
|
||||
padding_lo_,
|
||||
kernel_strides_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
groups_,
|
||||
flip_);
|
||||
flip_,
|
||||
copies);
|
||||
}
|
||||
// Throw error
|
||||
else {
|
||||
|
@@ -1,35 +1,15 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
constexpr int MAX_COPY_SPECIALIZED_DIMS = 3;
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
|
||||
bool donated = set_copy_output_data(in, out, ctype);
|
||||
if (donated && in.dtype() == out.dtype()) {
|
||||
// If the output has the same type as the input then there is nothing to
|
||||
// copy, just use the buffer.
|
||||
return;
|
||||
}
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
copy_gpu_inplace(in, out, ctype, s);
|
||||
}
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype) {
|
||||
copy_gpu(in, out, ctype, out.primitive().stream());
|
||||
}
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
@@ -75,10 +55,10 @@ void copy_gpu_inplace(
|
||||
std::string kernel_name;
|
||||
switch (ctype) {
|
||||
case CopyType::Scalar:
|
||||
kernel_name = (large ? "s2" : "s");
|
||||
kernel_name = large ? "s2" : "s";
|
||||
break;
|
||||
case CopyType::Vector:
|
||||
kernel_name = (large ? "v2" : "v");
|
||||
kernel_name = large ? "v2" : "v";
|
||||
break;
|
||||
case CopyType::General:
|
||||
kernel_name = "g";
|
||||
@@ -104,6 +84,11 @@ void copy_gpu_inplace(
|
||||
"[Copy::eval_gpu] Dynamic output offset requires GeneralGeneral copy");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
work_per_thread = get_work_per_thread(out.dtype(), out.data_size());
|
||||
if (work_per_thread > 1) {
|
||||
kernel_name += "n";
|
||||
}
|
||||
}
|
||||
concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
|
||||
auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out)
|
||||
@@ -165,48 +150,33 @@ void copy_gpu_inplace(
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
} else {
|
||||
size_t nthreads = out.data_size();
|
||||
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size grid_dims;
|
||||
if (large) {
|
||||
compute_encoder.set_bytes<int64_t>(out.data_size(), 2);
|
||||
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
|
||||
} else {
|
||||
compute_encoder.set_bytes<int>(out.data_size(), 2);
|
||||
grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
}
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
assert(in.shape() == out.shape());
|
||||
return copy_gpu_inplace(
|
||||
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
|
||||
}
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Strides& i_strides,
|
||||
int64_t i_offset,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
assert(in.shape() == out.shape());
|
||||
return copy_gpu_inplace(
|
||||
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
|
||||
}
|
||||
|
||||
void fill_gpu(const array& val, array& out, const Stream& s) {
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
bool large = out.data_size() > UINT32_MAX;
|
||||
int work_per_thread = get_work_per_thread(out.dtype(), out.data_size());
|
||||
auto& d = metal::device(s.device);
|
||||
std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +
|
||||
type_to_name(val) + type_to_name(out);
|
||||
std::string kernel_name = large ? "s2" : (work_per_thread > 1 ? "sn" : "s");
|
||||
concatenate(kernel_name, "_copy", type_to_name(val), type_to_name(out));
|
||||
auto kernel = get_copy_kernel(d, kernel_name, val, out);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
@@ -215,13 +185,19 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
size_t nthreads = out.data_size();
|
||||
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size grid_dims;
|
||||
if (large) {
|
||||
compute_encoder.set_bytes<int64_t>(out.data_size(), 2);
|
||||
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
|
||||
} else {
|
||||
compute_encoder.set_bytes<int>(out.data_size(), 2);
|
||||
grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
}
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
|
@@ -1,12 +1,326 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include <iostream>
|
||||
#include <regex>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::fast {
|
||||
|
||||
struct CustomKernelCache {
|
||||
std::unordered_map<std::string, std::string> libraries;
|
||||
};
|
||||
|
||||
static CustomKernelCache& cache() {
|
||||
static CustomKernelCache cache_;
|
||||
return cache_;
|
||||
};
|
||||
|
||||
std::string write_signature(
|
||||
std::string func_name,
|
||||
const std::string& header,
|
||||
const std::string& source,
|
||||
const std::vector<std::string>& input_names,
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::vector<Dtype>& output_dtypes,
|
||||
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
|
||||
const std::vector<std::string>& attributes,
|
||||
const std::vector<CustomKernelShapeInfo>& shape_infos,
|
||||
bool atomic_outputs) {
|
||||
std::string kernel_source;
|
||||
kernel_source.reserve(header.size() + source.size() + 16384);
|
||||
kernel_source += header;
|
||||
// Auto-generate a function signature based on `template_args`
|
||||
// and the dtype/shape of the arrays passed as `inputs`.
|
||||
if (!template_args.empty()) {
|
||||
kernel_source += "template <";
|
||||
int i = 0;
|
||||
for (const auto& [name, arg] : template_args) {
|
||||
std::string param_type;
|
||||
if (std::holds_alternative<int>(arg)) {
|
||||
param_type = "int";
|
||||
} else if (std::holds_alternative<bool>(arg)) {
|
||||
param_type = "bool";
|
||||
} else if (std::holds_alternative<Dtype>(arg)) {
|
||||
param_type = "typename";
|
||||
}
|
||||
if (i > 0) {
|
||||
kernel_source += ", ";
|
||||
}
|
||||
kernel_source += param_type;
|
||||
kernel_source += " ";
|
||||
kernel_source += name;
|
||||
i++;
|
||||
}
|
||||
kernel_source += ">\n";
|
||||
}
|
||||
kernel_source += "[[kernel]] void ";
|
||||
kernel_source += func_name;
|
||||
kernel_source += "(\n";
|
||||
|
||||
int index = 0;
|
||||
constexpr int max_constant_array_size = 8;
|
||||
// Add inputs
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
const auto& name = input_names[i];
|
||||
const auto& arr = inputs[i];
|
||||
auto dtype = get_type_string(arr.dtype());
|
||||
std::string location =
|
||||
arr.size() < max_constant_array_size ? "constant" : "device";
|
||||
std::string ref = arr.ndim() == 0 ? "&" : "*";
|
||||
kernel_source += " const ";
|
||||
kernel_source += location;
|
||||
kernel_source += " ";
|
||||
kernel_source += dtype;
|
||||
kernel_source += ref;
|
||||
kernel_source += " ";
|
||||
kernel_source += name;
|
||||
kernel_source += " [[buffer(";
|
||||
kernel_source += std::to_string(index);
|
||||
kernel_source += ")]],\n";
|
||||
index++;
|
||||
// Add input shape, strides and ndim if present in the source
|
||||
if (arr.ndim() > 0) {
|
||||
if (shape_infos[i].shape) {
|
||||
kernel_source +=
|
||||
(" const constant int* " + name + "_shape [[buffer(" +
|
||||
std::to_string(index) + ")]],\n");
|
||||
index++;
|
||||
}
|
||||
if (shape_infos[i].strides) {
|
||||
kernel_source +=
|
||||
(" const constant int64_t* " + name + "_strides [[buffer(" +
|
||||
std::to_string(index) + ")]],\n");
|
||||
index++;
|
||||
}
|
||||
if (shape_infos[i].ndim) {
|
||||
kernel_source +=
|
||||
(" const constant int& " + name + "_ndim [[buffer(" +
|
||||
std::to_string(index) + ")]],\n");
|
||||
index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Add outputs
|
||||
for (int i = 0; i < output_names.size(); ++i) {
|
||||
const auto& name = output_names[i];
|
||||
const auto& dtype = output_dtypes[i];
|
||||
kernel_source += " device ";
|
||||
auto type_string = get_type_string(dtype);
|
||||
if (atomic_outputs) {
|
||||
kernel_source += "atomic<";
|
||||
}
|
||||
kernel_source += type_string;
|
||||
if (atomic_outputs) {
|
||||
kernel_source += ">";
|
||||
}
|
||||
kernel_source += "* ";
|
||||
kernel_source += name;
|
||||
kernel_source += " [[buffer(";
|
||||
kernel_source += std::to_string(index);
|
||||
kernel_source += ")]]";
|
||||
if (index < inputs.size() + output_names.size() - 1 ||
|
||||
attributes.size() > 0) {
|
||||
kernel_source += ",\n";
|
||||
} else {
|
||||
kernel_source += ") {\n";
|
||||
}
|
||||
index++;
|
||||
}
|
||||
|
||||
index = 0;
|
||||
for (const auto& attr : attributes) {
|
||||
kernel_source += attr;
|
||||
if (index < attributes.size() - 1) {
|
||||
kernel_source += ",\n";
|
||||
} else {
|
||||
kernel_source += ") {\n";
|
||||
}
|
||||
index++;
|
||||
}
|
||||
kernel_source += source;
|
||||
kernel_source += "\n}\n";
|
||||
return kernel_source;
|
||||
}
|
||||
|
||||
std::string write_template(
|
||||
const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
|
||||
std::ostringstream template_def;
|
||||
template_def << "<";
|
||||
int i = 0;
|
||||
for (const auto& [name, arg] : template_args) {
|
||||
if (i > 0) {
|
||||
template_def << ", ";
|
||||
}
|
||||
if (std::holds_alternative<int>(arg)) {
|
||||
template_def << std::get<int>(arg);
|
||||
} else if (std::holds_alternative<bool>(arg)) {
|
||||
template_def << std::get<bool>(arg);
|
||||
} else if (std::holds_alternative<Dtype>(arg)) {
|
||||
template_def << get_type_string(std::get<Dtype>(arg));
|
||||
}
|
||||
i++;
|
||||
}
|
||||
template_def << ">";
|
||||
return template_def.str();
|
||||
}
|
||||
|
||||
MetalKernelFunction metal_kernel(
|
||||
const std::string& name,
|
||||
const std::vector<std::string>& input_names,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::string& source,
|
||||
const std::string& header /* = "" */,
|
||||
bool ensure_row_contiguous /* = true */,
|
||||
bool atomic_outputs /* = false */) {
|
||||
if (output_names.empty()) {
|
||||
throw std::invalid_argument(
|
||||
"[metal_kernel] Must specify at least one output.");
|
||||
}
|
||||
std::vector<CustomKernelShapeInfo> shape_infos;
|
||||
for (auto& n : input_names) {
|
||||
CustomKernelShapeInfo shape_info;
|
||||
shape_info.shape = source.find(n + "_shape") != std::string::npos;
|
||||
shape_info.strides = source.find(n + "_strides") != std::string::npos;
|
||||
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
|
||||
shape_infos.push_back(shape_info);
|
||||
}
|
||||
const std::vector<std::pair<std::string, std::string>> metal_attributes = {
|
||||
{"dispatch_quadgroups_per_threadgroup", "uint"},
|
||||
{"dispatch_simdgroups_per_threadgroup", "uint"},
|
||||
{"dispatch_threads_per_threadgroup", "uint3"},
|
||||
{"grid_origin", "uint3"},
|
||||
{"grid_size", "uint3"},
|
||||
{"quadgroup_index_in_threadgroup", "uint"},
|
||||
{"quadgroups_per_threadgroup", "uint"},
|
||||
{"simdgroup_index_in_threadgroup", "uint"},
|
||||
{"simdgroups_per_threadgroup", "uint"},
|
||||
{"thread_execution_width", "uint"},
|
||||
{"thread_index_in_quadgroup", "uint"},
|
||||
{"thread_index_in_simdgroup", "uint"},
|
||||
{"thread_index_in_threadgroup", "uint"},
|
||||
{"thread_position_in_grid", "uint3"},
|
||||
{"thread_position_in_threadgroup", "uint3"},
|
||||
{"threadgroup_position_in_grid", "uint3"},
|
||||
{"threadgroups_per_grid", "uint3"},
|
||||
{"threads_per_grid", "uint3"},
|
||||
{"threads_per_simdgroup", "uint"},
|
||||
{"threads_per_threadgroup", "uint3"},
|
||||
};
|
||||
|
||||
std::vector<std::string> attributes;
|
||||
for (const auto& [attr, dtype] : metal_attributes) {
|
||||
if (source.find(attr) != std::string::npos) {
|
||||
attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]");
|
||||
}
|
||||
}
|
||||
|
||||
return [=,
|
||||
shape_infos = std::move(shape_infos),
|
||||
attributes = std::move(attributes)](
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<Shape>& output_shapes,
|
||||
const std::vector<Dtype>& output_dtypes,
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
const std::vector<std::pair<std::string, TemplateArg>>&
|
||||
template_args = {},
|
||||
std::optional<float> init_value = std::nullopt,
|
||||
bool verbose = false,
|
||||
StreamOrDevice s_ = {}) {
|
||||
if (inputs.size() != input_names.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal_kernel] Expected `inputs` to have size "
|
||||
<< input_names.size() << " but got size " << inputs.size() << "."
|
||||
<< std::endl;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (output_shapes.size() != output_names.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal_kernel] Expected `output_shapes` to have size "
|
||||
<< output_names.size() << " but got size " << output_shapes.size()
|
||||
<< "." << std::endl;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (output_dtypes.size() != output_names.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal_kernel] Expected `output_dtypes` to have size "
|
||||
<< output_names.size() << " but got size " << output_dtypes.size()
|
||||
<< "." << std::endl;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto s = to_stream(s_);
|
||||
if (s.device != Device::gpu) {
|
||||
throw std::invalid_argument("[metal_kernel] Only supports the GPU.");
|
||||
}
|
||||
|
||||
std::string kernel_name = "custom_kernel_" + name;
|
||||
std::string template_def = "";
|
||||
if (!template_args.empty()) {
|
||||
std::regex disallowed_chars("\\<|\\>|(, )");
|
||||
template_def = write_template(template_args);
|
||||
auto template_hash =
|
||||
std::regex_replace(template_def, disallowed_chars, "_");
|
||||
template_hash.pop_back();
|
||||
kernel_name += "_";
|
||||
kernel_name += template_hash;
|
||||
}
|
||||
|
||||
std::string kernel_source = write_signature(
|
||||
kernel_name,
|
||||
header,
|
||||
source,
|
||||
input_names,
|
||||
inputs,
|
||||
output_names,
|
||||
output_dtypes,
|
||||
template_args,
|
||||
attributes,
|
||||
shape_infos,
|
||||
atomic_outputs);
|
||||
|
||||
if (!template_args.empty()) {
|
||||
template_def = kernel_name + template_def;
|
||||
kernel_source += "\ntemplate [[host_name(\"";
|
||||
kernel_source += kernel_name;
|
||||
kernel_source += "\")]] [[kernel]] decltype(";
|
||||
kernel_source += template_def;
|
||||
kernel_source += ") ";
|
||||
kernel_source += template_def;
|
||||
kernel_source += ";\n";
|
||||
}
|
||||
|
||||
if (verbose) {
|
||||
std::cout << "Generated source code for `" << name << "`:" << std::endl
|
||||
<< "```" << std::endl
|
||||
<< kernel_source << std::endl
|
||||
<< "```" << std::endl;
|
||||
}
|
||||
|
||||
return array::make_arrays(
|
||||
std::move(output_shapes),
|
||||
std::move(output_dtypes),
|
||||
std::make_shared<CustomKernel>(
|
||||
s,
|
||||
std::move(kernel_name),
|
||||
std::move(kernel_source),
|
||||
grid,
|
||||
threadgroup,
|
||||
shape_infos,
|
||||
ensure_row_contiguous,
|
||||
init_value),
|
||||
std::move(inputs));
|
||||
};
|
||||
}
|
||||
|
||||
void CustomKernel::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
@@ -39,9 +353,23 @@ void CustomKernel::eval_gpu(
|
||||
}
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
const auto& lib_name = name_;
|
||||
auto lib =
|
||||
d.get_library(lib_name, [this] { return metal::utils() + source_; });
|
||||
|
||||
{
|
||||
// Clear kernels from the device library cache if needed
|
||||
auto& kernel_cache = cache();
|
||||
if (auto it = kernel_cache.libraries.find(name_);
|
||||
it != kernel_cache.libraries.end()) {
|
||||
if (it->second != source_) {
|
||||
auto& d = metal::device(s.device);
|
||||
d.clear_library(name_);
|
||||
it->second = source_;
|
||||
}
|
||||
} else {
|
||||
kernel_cache.libraries.emplace(name_, source_);
|
||||
}
|
||||
}
|
||||
|
||||
auto lib = d.get_library(name_, [this] { return metal::utils() + source_; });
|
||||
auto kernel = d.get_kernel(name_, lib);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
@@ -73,6 +401,16 @@ void CustomKernel::eval_gpu(
|
||||
}
|
||||
|
||||
const auto [tx, ty, tz] = threadgroup_;
|
||||
auto tg_size = tx * ty * tz;
|
||||
auto max_tg_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (tg_size > max_tg_size) {
|
||||
std::ostringstream msg;
|
||||
msg << "Thread group size (" << tg_size << ") is greater than "
|
||||
<< " the maximum allowed threads per threadgroup (" << max_tg_size
|
||||
<< ").";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
const auto [gx, gy, gz] = grid_;
|
||||
MTL::Size group_dims =
|
||||
MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
|
||||
|
@@ -1,20 +1,20 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <filesystem>
|
||||
#include <sstream>
|
||||
|
||||
#include <sys/sysctl.h>
|
||||
|
||||
#define NS_PRIVATE_IMPLEMENTATION
|
||||
#define CA_PRIVATE_IMPLEMENTATION
|
||||
#define MTL_PRIVATE_IMPLEMENTATION
|
||||
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
namespace {
|
||||
@@ -66,8 +66,8 @@ MTL::Library* try_load_bundle(
|
||||
if (bundle != nullptr) {
|
||||
std::string resource_path =
|
||||
std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" +
|
||||
lib_name + ".metallib" auto [lib, error] =
|
||||
load_library_from_path(device, resource_path.c_str());
|
||||
lib_name + ".metallib";
|
||||
auto [lib, error] = load_library_from_path(device, resource_path.c_str());
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
@@ -79,12 +79,18 @@ MTL::Library* try_load_bundle(
|
||||
// Firstly, search for the metallib in the same path as this binary
|
||||
std::pair<MTL::Library*, NS::Error*> load_colocated_library(
|
||||
MTL::Device* device,
|
||||
const std::string& lib_name) {
|
||||
std::string lib_path = get_colocated_mtllib_path(lib_name);
|
||||
if (lib_path.size() != 0) {
|
||||
return load_library_from_path(device, lib_path.c_str());
|
||||
const std::string& relative_path) {
|
||||
std::string binary_dir = get_binary_directory();
|
||||
if (binary_dir.size() == 0) {
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
return {nullptr, nullptr};
|
||||
|
||||
auto path = fs::path(binary_dir) / relative_path;
|
||||
if (!path.has_extension()) {
|
||||
path.replace_extension(".metallib");
|
||||
}
|
||||
|
||||
return load_library_from_path(device, path.c_str());
|
||||
}
|
||||
|
||||
std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
|
||||
@@ -99,7 +105,7 @@ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
|
||||
auto bundles = NS::Bundle::allBundles();
|
||||
for (int i = 0, c = (int)bundles->count(); i < c; i++) {
|
||||
auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
|
||||
library = try_load_bundle(device, bundle->resourceURL());
|
||||
library = try_load_bundle(device, bundle->resourceURL(), lib_name);
|
||||
if (library != nullptr) {
|
||||
return {library, nullptr};
|
||||
}
|
||||
@@ -109,33 +115,34 @@ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
|
||||
}
|
||||
|
||||
MTL::Library* load_default_library(MTL::Device* device) {
|
||||
NS::Error *error1, *error2, *error3;
|
||||
NS::Error* error[4];
|
||||
MTL::Library* lib;
|
||||
// First try the colocated mlx.metallib
|
||||
std::tie(lib, error1) = load_colocated_library(device, "mlx");
|
||||
std::tie(lib, error[0]) = load_colocated_library(device, "mlx");
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
|
||||
std::tie(lib, error[1]) = load_colocated_library(device, "Resources/mlx");
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
|
||||
// Then try default.metallib in a SwiftPM bundle if we have one
|
||||
std::tie(lib, error2) = load_swiftpm_library(device, "default");
|
||||
std::tie(lib, error[2]) = load_swiftpm_library(device, "default");
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
|
||||
// Finally try default_mtllib_path
|
||||
std::tie(lib, error3) = load_library_from_path(device, default_mtllib_path);
|
||||
std::tie(lib, error[3]) = load_library_from_path(device, default_mtllib_path);
|
||||
if (!lib) {
|
||||
std::ostringstream msg;
|
||||
msg << "Failed to load the default metallib. ";
|
||||
if (error1 != nullptr) {
|
||||
msg << error1->localizedDescription()->utf8String() << " ";
|
||||
}
|
||||
if (error2 != nullptr) {
|
||||
msg << error2->localizedDescription()->utf8String() << " ";
|
||||
}
|
||||
if (error3 != nullptr) {
|
||||
msg << error3->localizedDescription()->utf8String() << " ";
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if (error[i] != nullptr) {
|
||||
msg << error[i]->localizedDescription()->utf8String() << " ";
|
||||
}
|
||||
}
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
@@ -156,6 +163,7 @@ MTL::Library* load_library(
|
||||
<< error->localizedDescription()->utf8String();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
return lib;
|
||||
}
|
||||
|
||||
// We have been given a path so try to load from lib_path / lib_name.metallib
|
||||
@@ -168,6 +176,7 @@ MTL::Library* load_library(
|
||||
<< "> with error " << error->localizedDescription()->utf8String();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
return lib;
|
||||
}
|
||||
|
||||
// Try to load the colocated library
|
||||
@@ -188,8 +197,8 @@ MTL::Library* load_library(
|
||||
|
||||
std::ostringstream msg;
|
||||
msg << "Failed to load the metallib " << lib_name << ".metallib. "
|
||||
<< "We attempted to load it from <" << get_colocated_mtllib_path(lib_name)
|
||||
<< ">";
|
||||
<< "We attempted to load it from <" << get_binary_directory() << "/"
|
||||
<< lib_name << ".metallib" << ">";
|
||||
#ifdef SWIFTPM_BUNDLE
|
||||
msg << " and from the Swift PM bundle.";
|
||||
#endif
|
||||
@@ -286,7 +295,7 @@ void CommandEncoder::barrier() {
|
||||
Device::Device() {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
device_ = load_device();
|
||||
library_map_ = {{"mlx", load_default_library(device_)}};
|
||||
default_library_ = load_default_library(device_);
|
||||
arch_ = std::string(device_->architecture()->name()->utf8String());
|
||||
auto arch = arch_.back();
|
||||
switch (arch) {
|
||||
@@ -317,11 +326,11 @@ Device::Device() {
|
||||
|
||||
Device::~Device() {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
for (auto& k : kernel_map_) {
|
||||
k.second->release();
|
||||
}
|
||||
for (auto& l : library_map_) {
|
||||
l.second->release();
|
||||
for (auto& [l, kernel_map] : library_kernels_) {
|
||||
l->release();
|
||||
for (auto& [_, k] : kernel_map) {
|
||||
k->release();
|
||||
}
|
||||
}
|
||||
stream_map_.clear();
|
||||
device_->release();
|
||||
@@ -465,13 +474,24 @@ CommandEncoder& Device::get_command_encoder(int index) {
|
||||
return *stream.encoder;
|
||||
}
|
||||
|
||||
void Device::register_library(
|
||||
const std::string& lib_name,
|
||||
const std::string& lib_path) {
|
||||
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
|
||||
auto new_lib = load_library(device_, lib_name, lib_path.c_str());
|
||||
library_map_.insert({lib_name, new_lib});
|
||||
MTL::Library* Device::get_library(
|
||||
const std::string& name,
|
||||
const std::string& path /* = "" */) {
|
||||
{
|
||||
std::shared_lock rlock(library_mtx_);
|
||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_lock wlock(library_mtx_);
|
||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
auto new_lib = load_library(device_, name, path.c_str());
|
||||
library_map_.insert({name, new_lib});
|
||||
return new_lib;
|
||||
}
|
||||
|
||||
MTL::Library* Device::build_library_(const std::string& source_string) {
|
||||
@@ -640,6 +660,19 @@ MTL::Library* Device::get_library(
|
||||
return mtl_lib;
|
||||
}
|
||||
|
||||
void Device::clear_library(const std::string& name) {
|
||||
std::unique_lock wlock(library_mtx_);
|
||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||
auto kernel_map_it = library_kernels_.find(it->second);
|
||||
for (auto& [_, kernel] : kernel_map_it->second) {
|
||||
kernel->release();
|
||||
}
|
||||
library_kernels_.erase(kernel_map_it);
|
||||
it->second->release();
|
||||
library_map_.erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
MTL::LinkedFunctions* Device::get_linked_functions_(
|
||||
const std::vector<MTL::Function*>& funcs) {
|
||||
if (funcs.empty()) {
|
||||
@@ -670,6 +703,7 @@ MTL::ComputePipelineState* Device::get_kernel_(
|
||||
std::unique_lock wlock(kernel_mtx_);
|
||||
|
||||
// Try loading again to avoid loading twice
|
||||
auto& kernel_map_ = library_kernels_[mtl_lib];
|
||||
if (auto it = kernel_map_.find(hash_name); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
@@ -704,6 +738,7 @@ MTL::ComputePipelineState* Device::get_kernel(
|
||||
std::shared_lock lock(kernel_mtx_);
|
||||
|
||||
// Look for cached kernel
|
||||
auto& kernel_map_ = library_kernels_[mtl_lib];
|
||||
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
@@ -713,23 +748,11 @@ MTL::ComputePipelineState* Device::get_kernel(
|
||||
|
||||
MTL::ComputePipelineState* Device::get_kernel(
|
||||
const std::string& base_name,
|
||||
const std::string& lib_name /* = "mlx" */,
|
||||
const std::string& hash_name /* = "" */,
|
||||
const MTLFCList& func_consts /* = {} */,
|
||||
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
|
||||
const auto& kname = hash_name.size() == 0 ? base_name : hash_name;
|
||||
{
|
||||
// Multiple readers allowed
|
||||
std::shared_lock lock(kernel_mtx_);
|
||||
|
||||
// Look for cached kernel
|
||||
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
// Search for cached metal lib
|
||||
MTL::Library* mtl_lib = get_library_(lib_name);
|
||||
return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
|
||||
return get_kernel(
|
||||
base_name, default_library_, hash_name, func_consts, linked_functions);
|
||||
}
|
||||
|
||||
void Device::set_residency_set(const MTL::ResidencySet* residency_set) {
|
||||
@@ -760,42 +783,4 @@ std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool() {
|
||||
NS::AutoreleasePool::alloc()->init(), dtor);
|
||||
}
|
||||
|
||||
void new_stream(Stream stream) {
|
||||
if (stream.device == mlx::core::Device::gpu) {
|
||||
device(stream.device).new_queue(stream.index);
|
||||
}
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
|
||||
device_info() {
|
||||
auto init_device_info = []()
|
||||
-> std::unordered_map<std::string, std::variant<std::string, size_t>> {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
auto raw_device = device(default_device()).mtl_device();
|
||||
auto name = std::string(raw_device->name()->utf8String());
|
||||
auto arch = std::string(raw_device->architecture()->name()->utf8String());
|
||||
|
||||
size_t memsize = 0;
|
||||
size_t length = sizeof(memsize);
|
||||
sysctlbyname("hw.memsize", &memsize, &length, NULL, 0);
|
||||
|
||||
size_t rsrc_limit = 0;
|
||||
sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0);
|
||||
if (rsrc_limit == 0) {
|
||||
rsrc_limit = 499000;
|
||||
}
|
||||
|
||||
return {
|
||||
{"device_name", name},
|
||||
{"architecture", arch},
|
||||
{"max_buffer_length", raw_device->maxBufferLength()},
|
||||
{"max_recommended_working_set_size",
|
||||
raw_device->recommendedMaxWorkingSetSize()},
|
||||
{"memory_size", memsize},
|
||||
{"resource_limit", rsrc_limit}};
|
||||
};
|
||||
static auto device_info_ = init_device_info();
|
||||
return device_info_;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@@ -21,18 +21,14 @@ namespace mlx::core::metal {
|
||||
|
||||
// Note, this function must be left inline in a header so that it is not
|
||||
// dynamically linked.
|
||||
inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
|
||||
inline std::string get_binary_directory() {
|
||||
Dl_info info;
|
||||
std::string mtllib_path;
|
||||
std::string lib_ext = lib_name + ".metallib";
|
||||
|
||||
int success = dladdr((void*)get_colocated_mtllib_path, &info);
|
||||
std::string directory;
|
||||
int success = dladdr((void*)get_binary_directory, &info);
|
||||
if (success) {
|
||||
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
|
||||
mtllib_path = mtllib.c_str();
|
||||
directory = fs::path(info.dli_fname).remove_filename().c_str();
|
||||
}
|
||||
|
||||
return mtllib_path;
|
||||
return directory;
|
||||
}
|
||||
|
||||
using MTLFCList =
|
||||
@@ -99,6 +95,10 @@ struct CommandEncoder {
|
||||
return enc_->setBytes(&v, sizeof(T), idx);
|
||||
}
|
||||
|
||||
void set_threadgroup_memory_length(size_t length, int idx) {
|
||||
enc_->setThreadgroupMemoryLength(length, idx);
|
||||
}
|
||||
|
||||
ConcurrentContext start_concurrent() {
|
||||
return ConcurrentContext(*this);
|
||||
}
|
||||
@@ -187,14 +187,16 @@ class Device {
|
||||
CommandEncoder& get_command_encoder(int index);
|
||||
void end_encoding(int index);
|
||||
|
||||
void register_library(
|
||||
const std::string& lib_name,
|
||||
const std::string& lib_path = "");
|
||||
MTL::Library* get_library(
|
||||
const std::string& name,
|
||||
const std::string& path = "");
|
||||
|
||||
MTL::Library* get_library(
|
||||
const std::string& name,
|
||||
const std::function<std::string(void)>& builder);
|
||||
|
||||
void clear_library(const std::string& name);
|
||||
|
||||
MTL::ComputePipelineState* get_kernel(
|
||||
const std::string& base_name,
|
||||
MTL::Library* mtl_lib,
|
||||
@@ -204,7 +206,6 @@ class Device {
|
||||
|
||||
MTL::ComputePipelineState* get_kernel(
|
||||
const std::string& base_name,
|
||||
const std::string& lib_name = "mlx",
|
||||
const std::string& hash_name = "",
|
||||
const MTLFCList& func_consts = {},
|
||||
const std::vector<MTL::Function*>& linked_functions = {});
|
||||
@@ -258,10 +259,13 @@ class Device {
|
||||
std::unordered_map<int32_t, DeviceStream> stream_map_;
|
||||
|
||||
std::shared_mutex kernel_mtx_;
|
||||
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
|
||||
|
||||
std::shared_mutex library_mtx_;
|
||||
std::unordered_map<std::string, MTL::Library*> library_map_;
|
||||
MTL::Library* default_library_;
|
||||
std::unordered_map<
|
||||
MTL::Library*,
|
||||
std::unordered_map<std::string, MTL::ComputePipelineState*>>
|
||||
library_kernels_;
|
||||
const MTL::ResidencySet* residency_set_{nullptr};
|
||||
std::string arch_;
|
||||
int max_ops_per_buffer_;
|
||||
@@ -270,4 +274,6 @@ class Device {
|
||||
|
||||
Device& device(mlx::core::Device);
|
||||
|
||||
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@@ -4,7 +4,7 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user