mirror of https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00

Compare commits: fft ... c35f4d089a (58 commits)

Commits:
c35f4d089a, 8590c0941e, 095163b8d1, 99c33d011d, 62fecf3e13, 7c4eb5d03e,
bae9a6b404, 004c1d8ef2, 7ebb2e0193, 9ce77798b1, f8bad60609, 5866b3857b,
1ca616844b, 2e8cf0b450, 24f89173d1, c6a20b427a, a5ac9244c4, c763fe1be0,
52dc8c8cd5, aede70e81d, 85a8beb5e4, 0bb89e9e5f, 5685ceb3c7, 0408ba0a76,
cbad6c3093, 1b021f6984, 95b7551d65, db5a7c6192, 6ef2f67e7f, f76ee1ffd2,
54a71f270a, 55b4062dd8, 79071bfba4, 7774b87cbd, 35c87741cf, 4cbe605214,
ab8883dd55, eebe73001a, 0359bf02c9, 237f9e58a8, 8576e6fe36, 0654543dcc,
48ef3e74e2, 7d4b378952, 7ff5c41e06, 602f43e3d1, a2cadb8218, c1eb9d05d9,
cf6c939e86, 130df35e1b, 0751263dec, eca2f3eb97, 3aa9cf3f9e, 8f3d208dce,
caaa3f1f8c, 659a51919f, 6661387066, a7fae8a176
@@ -212,6 +212,29 @@ jobs:
             METAL_DEBUG_ERROR_MODE=0 \
             python -m xmlrunner discover -v python/tests -o test-results/gpu_jit

+  cuda_build_and_test:
+    machine:
+      image: linux-cuda-12:default
+    resource_class: gpu.nvidia.small.gen2
+    steps:
+      - checkout
+      - run:
+          name: Install Python package
+          command: |
+            sudo apt-get update
+            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
+            sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
+            python -m venv env
+            source env/bin/activate
+            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
+            CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
+            pip install -e ".[dev]"
+      - run:
+          name: Run Python tests
+          command: |
+            source env/bin/activate
+            LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
+
   build_release:
     parameters:
       python_version:
@@ -348,6 +371,7 @@ workflows:
             parameters:
               macosx_deployment_target: ["13.5", "14.0"]
       - linux_build_and_test
+      - cuda_build_and_test
       - build_documentation

   build_pypi_release:
@@ -455,6 +479,8 @@ workflows:
               macosx_deployment_target: ["13.5", "14.0"]
       - linux_build_and_test:
          requires: [ hold ]
+      - cuda_build_and_test:
+         requires: [ hold ]
   nightly_build:
     when:
       and:
@@ -231,6 +231,9 @@ target_include_directories(
   mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
              $<INSTALL_INTERFACE:include>)

+# Do not add mlx_EXPORTS define for shared library.
+set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
+
 FetchContent_Declare(
   fmt
   GIT_REPOSITORY https://github.com/fmtlib/fmt.git
@@ -1,5 +1,6 @@
 // Copyright © 2023 Apple Inc.

+#include <cstring>
 #include <iostream>
 #include <sstream>

benchmarks/python/conv_unaligned_bench.py (new file, 107 lines)
@@ -0,0 +1,107 @@
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 10
N_iter_bench = 100
N_iter_func = 5


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_2D


def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        torch.mps.synchronize()
        return ys

    return pt_conv_2D


def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
    b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")

    torch.mps.synchronize()

    f_mx = make_mx_conv_2D(strides, padding, groups)
    f_pt = make_pt_conv_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    dtype = "float32"
    shapes = (
        (4, 32, 32, 21, 3, 3, 128),
        (4, 32, 32, 21, 3, 3, 37),
        (4, 32, 32, 370, 3, 3, 370),
        (4, 32, 32, 370, 7, 7, 128),
        (2, 320, 640, 21, 7, 7, 21),
    )
    for N, H, W, C, kh, kw, O in shapes:
        time_mlx, time_torch = bench_shape(
            N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
        )
        diff = time_torch / time_mlx - 1.0

        print(
            f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
        )
        if time_mlx >= 2.0 * time_torch:
            print("ATTENTION ^^^^^^^")
@@ -1,5 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.

+from functools import partial
+
 import mlx.core as mx
 import mlx.nn as nn
 from time_utils import time_fn

@@ -18,51 +20,63 @@ def layer_norm(x, w, b, eps):
     return y


-def time_layer_norm():
+def time_layer_norm(N, dt):
+    L = 1024
     f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
     f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
     g1 = mx.grad(f1, argnums=(0, 1, 2))
     g2 = mx.grad(f2, argnums=(0, 1, 2))

-    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
-    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
-    b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
-    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
+    x = mx.random.uniform(shape=(8, L, N)).astype(dt)
+    w = mx.random.uniform(shape=(N,)).astype(dt)
+    b = mx.random.uniform(shape=(N,)).astype(dt)
+    y = mx.random.uniform(shape=(8, L, N)).astype(dt)
     mx.eval(x, w, b, y)

-    def layer_norm_loop(g, x, w, b):
+    def layer_norm_loop(f, x, w, b):
+        for _ in range(32):
+            x = f(x, w, b)
+        return x
+
+    time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b)
+    time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b)
+
+    def layer_norm_grad_loop(g, x, w, b):
         gx, gw, gb = x, w, b
         for _ in range(32):
             gx, gw, gb = g(gx, gw, gb, y)
         return gx, gw, gb

-    time_fn(layer_norm_loop, g1, x, w, b)
-    time_fn(layer_norm_loop, g2, x, w, b)
-    time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
-    time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
+    time_fn(layer_norm_grad_loop, g1, x, w, b)
+    time_fn(layer_norm_grad_loop, g2, x, w, b)
+    time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b)
+    time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b)

     f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
     f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
     g1 = mx.grad(f1, argnums=(0,))
     g2 = mx.grad(f2, argnums=(0,))

-    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
-    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
-    b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
-    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
+    x = mx.random.uniform(shape=(8, L, N)).astype(dt)
+    w = mx.random.uniform(shape=(N,)).astype(dt)
+    b = mx.random.uniform(shape=(N,)).astype(dt)
+    y = mx.random.uniform(shape=(8, L, N)).astype(dt)
     mx.eval(x, w, b, y)

-    def layer_norm_loop(g, x):
+    def layer_norm_grad_x_loop(g, x):
         gx = x
         for _ in range(32):
             gx = g(gx, y)
         return gx

-    time_fn(layer_norm_loop, g1, x)
-    time_fn(layer_norm_loop, g2, x)
-    time_fn(layer_norm_loop, mx.compile(g1), x)
-    time_fn(layer_norm_loop, mx.compile(g2), x)
+    time_fn(layer_norm_grad_x_loop, g1, x)
+    time_fn(layer_norm_grad_x_loop, g2, x)
+    time_fn(layer_norm_grad_x_loop, mx.compile(g1), x)
+    time_fn(layer_norm_grad_x_loop, mx.compile(g2), x)


 if __name__ == "__main__":
-    time_layer_norm()
+    for dt in [mx.float32, mx.float16, mx.bfloat16]:
+        for n in [1024, 2048, 4096, 8192, 8192 + 1024]:
+            print(dt, n)
+            time_layer_norm(n, dt)
@@ -11,13 +11,14 @@ include(CMakeParseArguments)
 # Args: TARGET: Custom target to be added for the metal library TITLE: Name of
 # the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
 # of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
-# files (like headers)
+# files (like headers) DEBUG: Boolean, if true, enables debug compile options
+# for this specific library. If not provided, uses global MLX_METAL_DEBUG.
 #
 # clang format on

 macro(mlx_build_metallib)
   # Parse args
-  set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
+  set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG)
   set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
   cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

@@ -26,6 +27,10 @@ macro(mlx_build_metallib)

   # Collect compile options
   set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
+  if(MLX_METAL_DEBUG OR MTLLIB_DEBUG)
+    set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only
+                               -frecord-sources)
+  endif()

   # Prepare metallib build command
   add_custom_command(
@@ -10,7 +10,7 @@ import mlx.core as mx
 # -- Project information -----------------------------------------------------

 project = "MLX"
-copyright = "2023, MLX Contributors"
+copyright = "2023, Apple"
 author = "MLX Contributors"
 version = ".".join(mx.__version__.split(".")[:3])
 release = version
@@ -8,23 +8,26 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs.
 Simple Example
 --------------

+.. currentmodule:: mlx.core
+
 Let's write a custom kernel that computes ``exp`` elementwise:

 .. code-block:: python

   source = """
       uint elem = thread_position_in_grid.x;
       T tmp = inp[elem];
       out[elem] = metal::exp(tmp);
   """

   kernel = mx.fast.metal_kernel(
       name="myexp",
       input_names=["inp"],
       output_names=["out"],
       source=source,
   )

   def exp_elementwise(a: mx.array):
       outputs = kernel(
           inputs=[a],
           template=[("T", mx.float32)],
@@ -39,8 +42,13 @@ Let's write a custom kernel that computes ``exp`` elementwise:
   b = exp_elementwise(a)
   assert mx.allclose(b, mx.exp(a))

+Every time you make a kernel, a new Metal library is created and possibly
+JIT compiled. To reduce the overhead from that, build the kernel once with
+:func:`fast.metal_kernel` and then use it many times.
+
 .. note::
-   We are only required to pass the body of the Metal kernel in ``source``.
+   Only pass the body of the Metal kernel in ``source``. The function
+   signature is generated automatically.

 The full function signature will be generated using:
@@ -78,44 +86,51 @@ Putting this all together, the generated function signature for ``myexp`` is as

   template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;

 Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
 <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
 function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
 ``threadgroup`` size threadgroups. For optimal performance, each thread group
 dimension should be less than or equal to the corresponding grid dimension.

 Passing ``verbose=True`` to :func:`fast.metal_kernel.__call__` will print the
 generated code for debugging purposes.

 Using Shape/Strides
 -------------------

 :func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
 is ``True`` by default. This will copy the array inputs if needed
 before the kernel is launched to ensure that the memory layout is row
 contiguous. Generally this makes writing the kernel easier, since we don't
 have to worry about gaps or the ordering of the dims when indexing.

 If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
 ``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
 present in ``source``. We can then use MLX's built in indexing utils to fetch
 the right elements for each thread.

 Let's convert ``myexp`` above to support arbitrarily strided arrays without
 relying on a copy from ``ensure_row_contiguous``:

 .. code-block:: python

   source = """
       uint elem = thread_position_in_grid.x;
       // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
       uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
       T tmp = inp[loc];
       // Output arrays are always row contiguous
       out[elem] = metal::exp(tmp);
   """

   kernel = mx.fast.metal_kernel(
       name="myexp_strided",
       input_names=["inp"],
       output_names=["out"],
       source=source
   )

   def exp_elementwise(a: mx.array):
       outputs = kernel(
           inputs=[a],
           template=[("T", mx.float32)],
@@ -142,137 +157,139 @@ We'll start with the following MLX implementation using standard ops:

 .. code-block:: python

   def grid_sample_ref(x, grid):
       N, H_in, W_in, _ = x.shape
       ix = ((grid[..., 0] + 1) * W_in - 1) / 2
       iy = ((grid[..., 1] + 1) * H_in - 1) / 2

       ix_nw = mx.floor(ix).astype(mx.int32)
       iy_nw = mx.floor(iy).astype(mx.int32)

       ix_ne = ix_nw + 1
       iy_ne = iy_nw

       ix_sw = ix_nw
       iy_sw = iy_nw + 1

       ix_se = ix_nw + 1
       iy_se = iy_nw + 1

       nw = (ix_se - ix) * (iy_se - iy)
       ne = (ix - ix_sw) * (iy_sw - iy)
       sw = (ix_ne - ix) * (iy - iy_ne)
       se = (ix - ix_nw) * (iy - iy_nw)

       I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
       I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
       I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
       I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]

       mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
       mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
       mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
       mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)

       I_nw *= mask_nw[..., None]
       I_ne *= mask_ne[..., None]
       I_sw *= mask_sw[..., None]
       I_se *= mask_se[..., None]

       output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se

       return output

-Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
+Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
 to write a fast GPU kernel for both the forward and backward passes.

 First we'll implement the forward pass as a fused kernel:

 .. code-block:: python

   source = """
       uint elem = thread_position_in_grid.x;
       int H = x_shape[1];
       int W = x_shape[2];
       int C = x_shape[3];
       int gH = grid_shape[1];
       int gW = grid_shape[2];

       int w_stride = C;
       int h_stride = W * w_stride;
       int b_stride = H * h_stride;

       uint grid_idx = elem / C * 2;
       float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
       float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

       int ix_nw = floor(ix);
       int iy_nw = floor(iy);

       int ix_ne = ix_nw + 1;
       int iy_ne = iy_nw;

       int ix_sw = ix_nw;
       int iy_sw = iy_nw + 1;

       int ix_se = ix_nw + 1;
       int iy_se = iy_nw + 1;

       T nw = (ix_se - ix) * (iy_se - iy);
       T ne = (ix - ix_sw) * (iy_sw - iy);
       T sw = (ix_ne - ix) * (iy - iy_ne);
       T se = (ix - ix_nw) * (iy - iy_nw);

       int batch_idx = elem / C / gH / gW * b_stride;
       int channel_idx = elem % C;
       int base_idx = batch_idx + channel_idx;

       T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
       T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
       T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
       T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];

       I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
       I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
       I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
       I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;

       out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
   """

   kernel = mx.fast.metal_kernel(
       name="grid_sample",
       input_names=["x", "grid"],
       output_names=["out"],
       source=source,
   )

   @mx.custom_function
   def grid_sample(x, grid):

       assert x.ndim == 4, "`x` must be 4D."
       assert grid.ndim == 4, "`grid` must be 4D."

       B, _, _, C = x.shape
       _, gN, gM, D = grid.shape
       out_shape = (B, gN, gM, C)

       assert D == 2, "Last dim of `grid` must be size 2."

       outputs = kernel(
           inputs=[x, grid],
           template=[("T", x.dtype)],
           output_shapes=[out_shape],
           output_dtypes=[x.dtype],
           grid=(np.prod(out_shape), 1, 1),
           threadgroup=(256, 1, 1),
       )
       return outputs[0]

 For a reasonably sized input such as:

 .. code-block:: python

   x.shape = (8, 1024, 1024, 64)
   grid.shape = (8, 256, 256, 2)

 On an M1 Max, we see a big performance improvement:
@@ -281,11 +298,11 @@ On an M1 Max, we see a big performance improvement:
 Grid Sample VJP
 ---------------

-Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
-its custom vjp transform so MLX can differentiate it.
+Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
+define its custom vjp transform so MLX can differentiate it.

 The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
-requires a few extra ``mx.fast.metal_kernel`` features:
+requires a few extra :func:`fast.metal_kernel` features:

 * ``init_value=0``
   Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
@@ -299,128 +316,129 @@ We can then implement the backwards pass as follows:

 .. code-block:: python

   source = """
       uint elem = thread_position_in_grid.x;
       int H = x_shape[1];
       int W = x_shape[2];
       int C = x_shape[3];
       // Pad C to the nearest larger simdgroup size multiple
       int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;

       int gH = grid_shape[1];
       int gW = grid_shape[2];

       int w_stride = C;
       int h_stride = W * w_stride;
       int b_stride = H * h_stride;

       uint grid_idx = elem / C_padded * 2;
       float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
       float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

       int ix_nw = floor(ix);
       int iy_nw = floor(iy);

       int ix_ne = ix_nw + 1;
       int iy_ne = iy_nw;

       int ix_sw = ix_nw;
       int iy_sw = iy_nw + 1;

       int ix_se = ix_nw + 1;
       int iy_se = iy_nw + 1;

       T nw = (ix_se - ix) * (iy_se - iy);
       T ne = (ix - ix_sw) * (iy_sw - iy);
       T sw = (ix_ne - ix) * (iy - iy_ne);
       T se = (ix - ix_nw) * (iy - iy_nw);

       int batch_idx = elem / C_padded / gH / gW * b_stride;
       int channel_idx = elem % C_padded;
       int base_idx = batch_idx + channel_idx;

       T gix = T(0);
       T giy = T(0);
       if (channel_idx < C) {
           int cot_index = elem / C_padded * C + channel_idx;
           T cot = cotangent[cot_index];
           if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
               int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
               atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);

               T I_nw = x[offset];
               gix -= I_nw * (iy_se - iy) * cot;
               giy -= I_nw * (ix_se - ix) * cot;
           }
           if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
               int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
               atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);

               T I_ne = x[offset];
               gix += I_ne * (iy_sw - iy) * cot;
               giy -= I_ne * (ix - ix_sw) * cot;
           }
           if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
               int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
               atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);

               T I_sw = x[offset];
               gix -= I_sw * (iy - iy_ne) * cot;
               giy += I_sw * (ix_ne - ix) * cot;
           }
           if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
               int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
               atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);

               T I_se = x[offset];
               gix += I_se * (iy - iy_nw) * cot;
               giy += I_se * (ix - ix_nw) * cot;
           }
       }

       T gix_mult = W / 2;
       T giy_mult = H / 2;

       // Reduce across each simdgroup first.
       // This is much faster than relying purely on atomics.
       gix = simd_sum(gix);
       giy = simd_sum(giy);

       if (thread_index_in_simdgroup == 0) {
           atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
           atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
       }
   """
   kernel = mx.fast.metal_kernel(
       name="grid_sample_grad",
       input_names=["x", "grid", "cotangent"],
       output_names=["x_grad", "grid_grad"],
       source=source,
       atomic_outputs=True,
   )

   @grid_sample.vjp
   def grid_sample_vjp(primals, cotangent, _):
       x, grid = primals
       B, _, _, C = x.shape
       _, gN, gM, D = grid.shape

       assert D == 2, "Last dim of `grid` must be size 2."

       # pad the output channels to simd group size
       # so that our `simd_sum`s don't overlap.
       simdgroup_size = 32
       C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
       grid_size = B * gN * gM * C_padded
       outputs = kernel(
           inputs=[x, grid, cotangent],
           template=[("T", x.dtype)],
           output_shapes=[x.shape, grid.shape],
           output_dtypes=[x.dtype, x.dtype],
           grid=(grid_size, 1, 1),
           threadgroup=(256, 1, 1),
           init_value=0,
       )
       return outputs[0], outputs[1]

 There's an even larger speed up for the vjp:
@@ -397,11 +397,11 @@ below.
       std::ostringstream kname;
       kname << "axpby_" << "general_" << type_to_name(out);

-      // Make sure the metal library is available
-      d.register_library("mlx_ext");
+      // Load the metal library
+      auto lib = d.get_library("mlx_ext");

       // Make a kernel from this metal library
-      auto kernel = d.get_kernel(kname.str(), "mlx_ext");
+      auto kernel = d.get_kernel(kname.str(), lib);

       // Prepare to encode kernel
       auto& compute_encoder = d.get_command_encoder(s.index);
@@ -19,6 +19,8 @@ Array
    array.ndim
    array.shape
    array.size
+   array.real
+   array.imag
    array.abs
    array.all
    array.any
@@ -16,6 +16,8 @@ Linear Algebra
    cross
    qr
    svd
+   eigvals
+   eig
    eigvalsh
    eigh
    lu
@@ -172,11 +172,11 @@ void Axpby::eval_gpu(
   kname << (contiguous_kernel ? "contiguous_" : "general_");
   kname << type_to_name(out);

-  // Make sure the metal library is available
-  d.register_library("mlx_ext");
+  // Load the metal library
+  auto lib = d.get_library("mlx_ext");

   // Make a kernel from this metal library
-  auto kernel = d.get_kernel(kname.str(), "mlx_ext");
+  auto kernel = d.get_kernel(kname.str(), lib);

   // Prepare to encode kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -21,7 +21,7 @@ target_sources(
     ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)

 # Define MLX_VERSION only in the version.cpp file.
-add_library(mlx_version STATIC ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
+add_library(mlx_version OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
 target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
 target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)
@@ -55,6 +55,9 @@ endif()

 if(MLX_BUILD_CUDA)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
+else()
+  target_sources(mlx
+                 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
 endif()

 if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
@@ -224,6 +224,10 @@ class array {
     // Not copyable
     Data(const Data& d) = delete;
     Data& operator=(const Data& d) = delete;
+    Data(Data&& o) : buffer(o.buffer), d(o.d) {
+      o.buffer = allocator::Buffer(nullptr);
+      o.d = [](allocator::Buffer) {};
+    }
     ~Data() {
       d(buffer);
     }
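The move constructor added to array::Data above transfers ownership of the buffer: it copies the buffer handle and the deleter, then resets the moved-from object to a null buffer and a no-op deleter so that its destructor releases nothing. A minimal standalone sketch of the same ownership-transfer pattern (the Buffer and Holder names below are illustrative stand-ins, not MLX types):

    #include <functional>
    #include <utility>

    struct Buffer { void* ptr{nullptr}; };  // illustrative stand-in

    struct Holder {
      Buffer buffer;
      std::function<void(Buffer)> d;  // deleter run by the destructor

      Holder(Buffer b, std::function<void(Buffer)> del)
          : buffer(b), d(std::move(del)) {}

      // Move: take the buffer, then neutralize the source so destroying the
      // moved-from object does not free the same buffer a second time.
      Holder(Holder&& o) : buffer(o.buffer), d(o.d) {
        o.buffer = Buffer{nullptr};
        o.d = [](Buffer) {};
      }

      ~Holder() { d(buffer); }  // each buffer is released by exactly one owner
    };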
mlx/backend/common/buffer_cache.h (new file, 157 lines)
@@ -0,0 +1,157 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include <cassert>
#include <functional>
#include <map>

namespace mlx::core {

template <typename T>
class BufferCache {
 public:
  BufferCache(
      size_t page_size,
      std::function<size_t(T*)> get_size,
      std::function<void(T*)> free)
      : page_size_(page_size),
        get_size_(std::move(get_size)),
        free_(std::move(free)) {}

  ~BufferCache() {
    clear();
  }

  BufferCache(const BufferCache&) = delete;
  BufferCache& operator=(const BufferCache&) = delete;

  T* reuse_from_cache(size_t size) {
    // Find the closest buffer in pool.
    auto it = buffer_pool_.lower_bound(size);
    if (it == buffer_pool_.end() ||
        it->first >= std::min(2 * size, size + 2 * page_size_)) {
      return nullptr;
    }

    // Collect from the cache.
    T* buf = it->second->buf;
    pool_size_ -= it->first;

    // Remove from record.
    remove_from_list(it->second);
    buffer_pool_.erase(it);
    return buf;
  }

  void recycle_to_cache(T* buf) {
    assert(buf);
    // Add to cache.
    BufferHolder* bh = new BufferHolder(buf);
    add_at_head(bh);
    size_t size = get_size_(buf);
    pool_size_ += size;
    buffer_pool_.emplace(size, bh);
  }

  int release_cached_buffers(size_t min_bytes_to_free) {
    if (min_bytes_to_free >= 0.9 * pool_size_) {
      return clear();
    } else {
      int n_release = 0;
      size_t total_bytes_freed = 0;

      while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
        // Release buffer.
        size_t size = get_size_(tail_->buf);
        total_bytes_freed += size;
        free_(tail_->buf);
        n_release++;

        // Remove from record.
        auto its = buffer_pool_.equal_range(size);
        auto it = std::find_if(its.first, its.second, [this](const auto& el) {
          return el.second == tail_;
        });
        assert(it != buffer_pool_.end());
        buffer_pool_.erase(it);
        remove_from_list(tail_);
      }

      pool_size_ -= total_bytes_freed;
      return n_release;
    }
  }

  int clear() {
    int n_release = 0;
    for (auto& [size, holder] : buffer_pool_) {
      free_(holder->buf);
      n_release++;
      delete holder;
    }
    buffer_pool_.clear();
    pool_size_ = 0;
    head_ = nullptr;
    tail_ = nullptr;
    return n_release;
  }

  size_t cache_size() const {
    return pool_size_;
  }

  size_t page_size() const {
    return page_size_;
  }

 private:
  struct BufferHolder {
   public:
    explicit BufferHolder(T* buf_) : buf(buf_) {}

    BufferHolder* prev{nullptr};
    BufferHolder* next{nullptr};
    T* buf;
  };

  void add_at_head(BufferHolder* to_add) {
    if (!head_) {
      head_ = to_add;
      tail_ = to_add;
    } else {
      head_->prev = to_add;
      to_add->next = head_;
      head_ = to_add;
    }
  }

  void remove_from_list(BufferHolder* to_remove) {
    if (to_remove->prev && to_remove->next) { // if middle
      to_remove->prev->next = to_remove->next;
      to_remove->next->prev = to_remove->prev;
    } else if (to_remove->prev && to_remove == tail_) { // if tail
      tail_ = to_remove->prev;
      tail_->next = nullptr;
    } else if (to_remove == head_ && to_remove->next) { // if head
      head_ = to_remove->next;
      head_->prev = nullptr;
    } else if (to_remove == head_ && to_remove == tail_) { // if only element
      head_ = nullptr;
      tail_ = nullptr;
    }

    delete to_remove;
  }

  std::multimap<size_t, BufferHolder*> buffer_pool_;
  BufferHolder* head_{nullptr};
  BufferHolder* tail_{nullptr};
  size_t pool_size_{0};

  const size_t page_size_;
  std::function<size_t(T*)> get_size_;
  std::function<void(T*)> free_;
};

} // namespace mlx::core
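A short usage sketch of the BufferCache template declared above, with a hypothetical RawBuf type standing in for a backend allocation; the type, sizes, and call pattern here are illustrative assumptions, not MLX's actual allocator code:

    #include <cstdlib>
    #include "mlx/backend/common/buffer_cache.h"

    struct RawBuf {  // hypothetical backend buffer
      void* data;
      size_t size;
    };

    int main() {
      // get_size reports a buffer's byte size; free releases it for real.
      mlx::core::BufferCache<RawBuf> cache(
          /* page_size */ 16384,
          [](RawBuf* b) { return b->size; },
          [](RawBuf* b) { std::free(b->data); delete b; });

      // Instead of freeing, hand a no-longer-needed buffer back to the cache.
      cache.recycle_to_cache(new RawBuf{std::malloc(1 << 20), 1 << 20});

      // A later request of a similar size can reuse the cached buffer.
      if (RawBuf* b = cache.reuse_from_cache(1 << 20)) {
        cache.recycle_to_cache(b);  // done with it; return it to the cache
      }

      // Under memory pressure, free at least this many cached bytes.
      cache.release_cached_buffers(/* min_bytes_to_free */ 1 << 19);
      return 0;
    }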
@@ -1,8 +1,7 @@
 // Copyright © 2023-2024 Apple Inc.

 #include "mlx/backend/common/compiled.h"
-#include "mlx/graph_utils.h"
-#include "mlx/primitives.h"
+#include "mlx/backend/common/utils.h"
 #include "mlx/utils.h"

 namespace mlx::core {
@@ -79,55 +78,6 @@ std::string get_type_string(Dtype d) {
   }
 }

-std::string build_lib_name(
-    const std::vector<array>& inputs,
-    const std::vector<array>& outputs,
-    const std::vector<array>& tape,
-    const std::unordered_set<uintptr_t>& constant_ids) {
-  NodeNamer namer;
-  std::ostringstream os;
-  std::ostringstream constant_hasher;
-
-  // Fill the input names. This is not really necessary, I just like having A,
-  // B, C, ... as the inputs.
-  for (auto& x : inputs) {
-    namer.get_name(x);
-  }
-
-  // The primitives describing the tape. For unary and binary primitives this
-  // must be enough to describe the full computation.
-  for (auto& a : tape) {
-    // name and type of output
-    os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
-    // computation performed
-    a.primitive().print(os);
-    // name of inputs to the function
-    for (auto& inp : a.inputs()) {
-      os << namer.get_name(inp);
-    }
-  }
-  os << "_";
-
-  for (auto& x : inputs) {
-    if (constant_ids.find(x.id()) != constant_ids.end()) {
-      os << "C";
-      print_constant(constant_hasher, x);
-    } else {
-      os << (is_scalar(x) ? "S" : "V");
-    }
-  }
-  os << "_";
-  for (auto& x : inputs) {
-    if (constant_ids.find(x.id()) != constant_ids.end()) {
-      continue;
-    }
-    os << kindof(x.dtype()) << x.itemsize();
-  }
-  os << "_" << std::hash<std::string>{}(constant_hasher.str());
-
-  return os.str();
-}
-
 bool compiled_check_contiguity(
     const std::vector<array>& inputs,
     const Shape& shape) {
@@ -159,8 +109,7 @@ bool compiled_check_contiguity(
 void compiled_allocate_outputs(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
-    const std::vector<array>& inputs_,
-    const std::unordered_set<uintptr_t>& constant_ids_,
+    const std::function<bool(size_t)>& is_constant,
     bool contiguous) {
   if (contiguous) {
     int o = 0;

@@ -175,8 +124,7 @@ void compiled_allocate_outputs(
       // - Donatable
       // - Not a constant
       if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
-          in.is_donatable() &&
-          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
+          in.is_donatable() && is_constant(i)) {
         outputs[o++].copy_shared_buffer(in);
       }
       // Get representative input flags to properly set non-donated outputs

@@ -204,7 +152,7 @@ void compiled_allocate_outputs(
       // - Not a constant
       if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
           in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
-          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
+          is_constant(i)) {
         outputs[o].copy_shared_buffer(
             in, outputs[o].strides(), in.flags(), in.data_size());
         o++;
@@ -216,4 +164,74 @@ void compiled_allocate_outputs(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const array& out,
|
||||||
|
const std::function<bool(size_t)>& is_constant) {
|
||||||
|
const Shape& shape = out.shape();
|
||||||
|
bool contiguous = compiled_check_contiguity(inputs, shape);
|
||||||
|
if (contiguous) {
|
||||||
|
return {true, shape, {}};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Strides> strides_vec{out.strides()};
|
||||||
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
|
// Skip constants.
|
||||||
|
if (is_constant(i)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip scalar inputs.
|
||||||
|
const auto& x = inputs[i];
|
||||||
|
if (is_scalar(x)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Broadcast the inputs to the output shape.
|
||||||
|
Strides xstrides;
|
||||||
|
size_t j = 0;
|
||||||
|
for (; j < shape.size() - x.ndim(); ++j) {
|
||||||
|
if (shape[j] == 1) {
|
||||||
|
xstrides.push_back(out.strides()[j]);
|
||||||
|
} else {
|
||||||
|
xstrides.push_back(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < x.ndim(); ++i, ++j) {
|
||||||
|
if (x.shape(i) == 1) {
|
||||||
|
if (shape[j] == 1) {
|
||||||
|
xstrides.push_back(out.strides()[j]);
|
||||||
|
} else {
|
||||||
|
xstrides.push_back(0);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
xstrides.push_back(x.strides()[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
strides_vec.push_back(std::move(xstrides));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto tup = collapse_contiguous_dims(shape, strides_vec, INT32_MAX);
|
||||||
|
return {false, std::move(std::get<0>(tup)), std::move(std::get<1>(tup))};
|
||||||
|
}
|
||||||
|
|
||||||
|
bool compiled_use_large_index(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<array>& outputs,
|
||||||
|
bool contiguous) {
|
||||||
|
if (contiguous) {
|
||||||
|
size_t max_size = 0;
|
||||||
|
for (const auto& in : inputs) {
|
||||||
|
max_size = std::max(max_size, in.data_size());
|
||||||
|
}
|
||||||
|
return max_size > UINT32_MAX;
|
||||||
|
} else {
|
||||||
|
size_t max_size = 0;
|
||||||
|
for (const auto& o : outputs) {
|
||||||
|
max_size = std::max(max_size, o.size());
|
||||||
|
}
|
||||||
|
return max_size > UINT32_MAX;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
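The new compiled_collapse_contiguous_dims above folds adjacent dimensions whose strides line up, so fully contiguous element-wise work can run as a flat loop. A minimal standalone sketch of the collapsing rule (not the MLX code; collapse_dims and the sample shapes are made up for illustration):

// Merge dimension i into the previous kept dimension when the previous stride
// equals strides[i] * shape[i], i.e. the two axes tile memory back to back.
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

using Shape = std::vector<int32_t>;
using Strides = std::vector<int64_t>;

std::pair<Shape, Strides> collapse_dims(const Shape& shape, const Strides& strides) {
  Shape out_shape;
  Strides out_strides;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (!out_shape.empty() && out_strides.back() == strides[i] * shape[i]) {
      out_shape.back() *= shape[i];   // axes are contiguous: merge them
      out_strides.back() = strides[i];
    } else {
      out_shape.push_back(shape[i]);  // gap or reorder: keep a new axis
      out_strides.push_back(strides[i]);
    }
  }
  return {out_shape, out_strides};
}

int main() {
  // Row-contiguous {2, 3, 4} with strides {12, 4, 1} collapses to {24} / {1}.
  auto [s, st] = collapse_dims({2, 3, 4}, {12, 4, 1});
  std::cout << s.size() << " dim(s), first extent " << s[0] << "\n";
  // A permuted middle axis ({12, 1, 3}) breaks contiguity, so nothing merges.
  auto [s2, st2] = collapse_dims({2, 3, 4}, {12, 1, 3});
  std::cout << s2.size() << " dim(s)\n";
}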
@@ -1,9 +1,8 @@
 // Copyright © 2023-2024 Apple Inc.
 #pragma once

+#include <functional>
 #include <iomanip>
-#include <sstream>
-#include <unordered_set>

 #include "mlx/array.h"
 #include "mlx/primitives.h"
@@ -14,12 +13,6 @@ inline bool is_static_cast(const Primitive& p) {
   return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
 }

-std::string build_lib_name(
-    const std::vector<array>& inputs,
-    const std::vector<array>& outputs,
-    const std::vector<array>& tape,
-    const std::unordered_set<uintptr_t>& constant_ids);
-
 std::string get_type_string(Dtype d);

 template <typename T>
@@ -60,8 +53,19 @@ bool compiled_check_contiguity(
 void compiled_allocate_outputs(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
-    const std::vector<array>& inputs_,
-    const std::unordered_set<uintptr_t>& constant_ids_,
+    const std::function<bool(size_t)>& is_constant,
+    bool contiguous);
+
+// Collapse contiguous dims ignoring scalars and constants.
+std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
+    const std::vector<array>& inputs,
+    const array& out,
+    const std::function<bool(size_t)>& is_constant);
+
+// Return whether the kernel should use large index.
+bool compiled_use_large_index(
+    const std::vector<array>& inputs,
+    const std::vector<array>& outputs,
     bool contiguous);

 } // namespace mlx::core
@@ -2,7 +2,7 @@

 #pragma once

-#include "mlx/array.h"
+#include "mlx/backend/common/utils.h"

 namespace mlx::core {

@@ -26,7 +26,7 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
   if (ctype == CopyType::Vector) {
     // If the input is donateable, we are doing a vector copy and the types
     // have the same size, then the input buffer can hold the output.
-    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
+    if (is_donatable(in, out)) {
       out.copy_shared_buffer(in);
       return true;
     } else {
mlx/backend/common/matmul.h (new file, 78 lines)
@@ -0,0 +1,78 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/common/utils.h"
+#include "mlx/utils.h"
+
+#include <sstream>
+
+namespace mlx::core {
+
+inline std::tuple<Shape, Strides, Strides> collapse_batches(
+    const array& a,
+    const array& b) {
+  // Get and check the shape for the batched dims
+  Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
+  Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
+  if (A_bshape != B_bshape) {
+    std::ostringstream msg;
+    msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A "
+        << a.shape() << ", B " << b.shape() << ".";
+    throw std::runtime_error(msg.str());
+  }
+
+  Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
+  Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
+
+  auto [batch_shape, batch_strides] =
+      collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
+
+  auto a_batch_strides = batch_strides[0];
+  auto b_batch_strides = batch_strides[1];
+
+  if (batch_shape.empty()) {
+    batch_shape.push_back(1);
+    a_batch_strides.push_back(0);
+    b_batch_strides.push_back(0);
+  }
+
+  return std::make_tuple(batch_shape, a_batch_strides, b_batch_strides);
+}
+
+inline std::tuple<Shape, Strides, Strides, Strides>
+collapse_batches(const array& a, const array& b, const array& c) {
+  // Get and check the shape for the batched dims
+  Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
+  Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
+  Shape C_bshape{c.shape().begin(), c.shape().end() - 2};
+  if (A_bshape != B_bshape || A_bshape != C_bshape) {
+    std::ostringstream msg;
+    msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A "
+        << a.shape() << ", B " << b.shape() << ", B " << c.shape() << ".";
+    throw std::runtime_error(msg.str());
+  }
+
+  Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
+  Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
+  Strides C_bstride{c.strides().begin(), c.strides().end() - 2};
+
+  auto [batch_shape, batch_strides] = collapse_contiguous_dims(
+      A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
+
+  auto A_batch_stride = batch_strides[0];
+  auto B_batch_stride = batch_strides[1];
+  auto C_batch_stride = batch_strides[2];
+
+  if (batch_shape.empty()) {
+    batch_shape.push_back(1);
+    A_batch_stride.push_back(0);
+    B_batch_stride.push_back(0);
+    C_batch_stride.push_back(0);
+  }
+
+  return std::make_tuple(
+      batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
+}
+
+} // namespace mlx::core
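collapse_batches above reduces the leading batch axes of a matmul to one collapsed shape plus per-operand strides. A standalone sketch of how such collapsed batch strides are typically consumed (the shapes, strides, and loop here are illustrative assumptions, not the MLX gemm driver):

// Walk every (A, B) matrix pair in the collapsed batch and compute the element
// offset of each matrix in its buffer; a stride of 0 expresses broadcasting.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Pretend A is {2, 3, M, K} row-contiguous (M*K = 16 elements per matrix)
  // and B is broadcast over the second batch axis.
  std::vector<int32_t> batch_shape = {2, 3};
  std::vector<int64_t> a_batch_strides = {3 * 16, 16};
  std::vector<int64_t> b_batch_strides = {8, 0};

  for (int i = 0; i < batch_shape[0]; ++i) {
    for (int j = 0; j < batch_shape[1]; ++j) {
      int64_t a_off = i * a_batch_strides[0] + j * a_batch_strides[1];
      int64_t b_off = i * b_batch_strides[0] + j * b_batch_strides[1];
      std::cout << "batch (" << i << ", " << j << ") -> A offset " << a_off
                << ", B offset " << b_off << "\n";
    }
  }
}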
mlx/backend/common/unary.h (new file, 26 lines)
@@ -0,0 +1,26 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/allocator.h"
+#include "mlx/backend/common/utils.h"
+
+namespace mlx::core {
+
+inline void set_unary_output_data(const array& in, array& out) {
+  if (in.flags().contiguous) {
+    if (is_donatable(in, out)) {
+      out.copy_shared_buffer(in);
+    } else {
+      out.set_data(
+          allocator::malloc(in.data_size() * out.itemsize()),
+          in.data_size(),
+          in.strides(),
+          in.flags());
+    }
+  } else {
+    out.set_data(allocator::malloc(out.nbytes()));
+  }
+}
+
+} // namespace mlx::core
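set_unary_output_data reuses the input buffer when it can be donated. The actual is_donatable predicate is defined elsewhere in the tree; the sketch below only illustrates the usual idea (sole reference to the storage plus a matching item size) with made-up types:

// FakeArray and can_donate are illustrative stand-ins, not MLX symbols.
#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

struct FakeArray {
  std::shared_ptr<std::vector<std::byte>> buffer; // refcounted storage
  size_t itemsize;
};

bool can_donate(const FakeArray& in, const FakeArray& out) {
  // Safe only if nobody else still reads `in` and the element sizes agree,
  // so the output can alias the allocation instead of copying.
  return in.buffer.use_count() == 1 && in.itemsize == out.itemsize;
}

int main() {
  FakeArray a{std::make_shared<std::vector<std::byte>>(64), 4};
  FakeArray b{nullptr, 4};
  std::cout << std::boolalpha << can_donate(a, b) << "\n"; // true
  FakeArray alias = a; // a second reference keeps the buffer alive
  std::cout << can_donate(a, b) << "\n";                   // false
}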
@@ -1,9 +1,16 @@
 // Copyright © 2023-2024 Apple Inc.

 #include "mlx/backend/common/utils.h"
+#include "mlx/primitives.h"

 namespace mlx::core {

+std::string get_primitive_string(Primitive* primitive) {
+  std::ostringstream op_t;
+  primitive->print(op_t);
+  return op_t.str();
+}
+
 std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
     const Shape& shape,
     const std::vector<Strides>& strides,
@@ -101,4 +108,105 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
   return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
 }

+Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
+  int pows[3] = {0, 0, 0};
+  int sum = 0;
+  while (true) {
+    int presum = sum;
+    // Check all the pows
+    if (dim0 >= (1 << (pows[0] + 1))) {
+      pows[0]++;
+      sum++;
+    }
+    if (sum == 10) {
+      break;
+    }
+    if (dim1 >= (1 << (pows[1] + 1))) {
+      pows[1]++;
+      sum++;
+    }
+    if (sum == 10) {
+      break;
+    }
+    if (dim2 >= (1 << (pows[2] + 1))) {
+      pows[2]++;
+      sum++;
+    }
+    if (sum == presum || sum == pow2) {
+      break;
+    }
+  }
+  return std::make_tuple(1ul << pows[0], 1ul << pows[1], 1ul << pows[2]);
+}
+
+Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides) {
+  // Dims with strides of 0 are ignored as they
+  // correspond to broadcasted dimensions
+  size_t grid_x = 1;
+  size_t grid_y = 1;
+  for (int i = 0; i < shape.size(); ++i) {
+    if (strides[i] == 0) {
+      continue;
+    }
+    if (grid_x * shape[i] < UINT32_MAX) {
+      grid_x *= shape[i];
+    } else {
+      grid_y *= shape[i];
+    }
+  }
+  if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
+    throw std::runtime_error("Unable to safely factor shape.");
+  }
+  if (grid_y > grid_x) {
+    std::swap(grid_x, grid_y);
+  }
+  return std::make_tuple(
+      static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
+}
+
+Dims get_2d_grid_dims_common(
+    const Shape& shape,
+    const Strides& strides,
+    size_t divisor) {
+  // Compute the 2d grid dimensions such that the total size of the grid is
+  // divided by divisor.
+  size_t grid_x = 1;
+  size_t grid_y = 1;
+  for (int i = 0; i < shape.size(); ++i) {
+    if (strides[i] == 0) {
+      continue;
+    }
+
+    // No need to add this shape we can just remove it from the divisor.
+    if (divisor % shape[i] == 0) {
+      divisor /= shape[i];
+      continue;
+    }
+
+    if (grid_x * shape[i] < UINT32_MAX) {
+      grid_x *= shape[i];
+    } else {
+      grid_y *= shape[i];
+    }
+
+    if (divisor > 1) {
+      if (grid_x % divisor == 0) {
+        grid_x /= divisor;
+        divisor = 1;
+      } else if (grid_y % divisor == 0) {
+        grid_y /= divisor;
+        divisor = 1;
+      }
+    }
+  }
+  if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) {
+    throw std::runtime_error("Unable to safely factor shape.");
+  }
+  if (grid_y > grid_x) {
+    std::swap(grid_x, grid_y);
+  }
+  return std::make_tuple(
+      static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
+}
+
 } // namespace mlx::core
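get_2d_grid_dims_common packs the non-broadcast dimensions into a two-axis grid so that each axis stays below UINT32_MAX. A standalone sketch of the greedy factoring (illustrative only; factor_grid and the sample shape are not MLX symbols):

// Pack dimension extents into grid_x while that still fits in 32 bits,
// spilling the remainder into grid_y.
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <utility>
#include <vector>

std::pair<uint64_t, uint64_t> factor_grid(const std::vector<int64_t>& dims) {
  uint64_t grid_x = 1, grid_y = 1;
  for (auto d : dims) {
    if (grid_x * d < UINT32_MAX) {
      grid_x *= d; // keep packing into x while it fits
    } else {
      grid_y *= d; // overflow dims spill into y
    }
  }
  if (grid_x > UINT32_MAX || grid_y > UINT32_MAX) {
    throw std::runtime_error("shape too large for a 2D grid");
  }
  return {grid_x, grid_y};
}

int main() {
  // ~8.6 billion elements cannot fit on one 32-bit grid axis, so it splits.
  auto [x, y] = factor_grid({4096, 4096, 512});
  std::cout << x << " x " << y << "\n"; // 16777216 x 512
}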
@@ -2,12 +2,15 @@

 #pragma once

+#include <tuple>
 #include <vector>

 #include "mlx/array.h"

 namespace mlx::core {

+std::string get_primitive_string(Primitive* primitive);
+
 inline int64_t
 elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
   int64_t loc = 0;
@@ -70,6 +73,28 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
     const array& a,
     int64_t size_cap = std::numeric_limits<int32_t>::max());

+// Compute the thread block dimensions which fit the given
+// input dimensions.
+// - The thread block dimensions will be powers of two
+// - The thread block size will be less than 2^pow2
+using Dims = std::tuple<uint32_t, uint32_t, uint32_t>;
+Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 = 10);
+
+// Computes a 2D grid where each element is < UINT_MAX
+// Assumes:
+// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
+// - shape and strides correspond to a contiguous (no holes) but
+//   possibly broadcasted array
+Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides);
+
+// Same as above but we do an implicit division with divisor.
+// Basically, equivalent to factorizing
+// Prod(s \forall s in shape if strides[s] > 0) / divisor.
+Dims get_2d_grid_dims_common(
+    const Shape& shape,
+    const Strides& strides,
+    size_t divisor);
+
 struct ContiguousIterator {
   inline void step() {
     int dims = shape_.size();
@@ -165,4 +190,11 @@ void shared_buffer_reshape(
     const array& in,
     const Strides& out_strides,
     array& out);

+template <typename T>
+inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
+  vec.erase(std::next(vec.begin(), index));
+  return vec;
+}
+
 } // namespace mlx::core
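remove_index is used a little further down by arg_reduce to drop the reduction axis from a shape and its strides. A quick standalone usage example (the helper body mirrors the one added above):

#include <cstdint>
#include <iostream>
#include <iterator>
#include <vector>

template <typename T>
std::vector<T> remove_index(std::vector<T> vec, size_t index) {
  vec.erase(std::next(vec.begin(), index));
  return vec;
}

int main() {
  std::vector<int32_t> shape = {4, 5, 6};
  auto reduced = remove_index(shape, 1); // drop axis 1
  for (auto d : reduced) {
    std::cout << d << " "; // prints: 4 6
  }
  std::cout << "\n";
}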
@@ -46,6 +46,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/eig.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
@@ -14,10 +14,8 @@ template <typename InT, typename OpT>
 void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
   auto axis_size = in.shape()[axis];
   auto axis_stride = in.strides()[axis];
-  Strides strides = in.strides();
-  Shape shape = in.shape();
-  strides.erase(strides.begin() + axis);
-  shape.erase(shape.begin() + axis);
+  Strides strides = remove_index(in.strides(), axis);
+  Shape shape = remove_index(in.shape(), axis);
   auto in_ptr = in.data<InT>();
   auto out_ptr = out.data<uint32_t>();

@@ -146,18 +146,9 @@ inline void build_kernel(
     const std::vector<array>& inputs,
     const std::vector<array>& outputs,
     const std::vector<array>& tape,
-    const std::unordered_set<uintptr_t>& constant_ids,
+    const std::function<bool(size_t)>& is_constant,
     bool contiguous,
     int ndim) {
-  // All outputs should have the exact same shape and will be row contiguous
-  auto output_shape = outputs[0].shape();
-  auto output_strides = outputs[0].strides();
-
-  // Constants are scalars that are captured by value and cannot change
-  auto is_constant = [&constant_ids](const array& x) {
-    return constant_ids.find(x.id()) != constant_ids.end();
-  };
-
   NodeNamer namer;

 #ifdef _MSC_VER
@@ -170,14 +161,15 @@ inline void build_kernel(

   // Add the input arguments
   int cnt = 0;
-  for (auto& x : inputs) {
-    auto& xname = namer.get_name(x);
-
+  for (size_t i = 0; i < inputs.size(); ++i) {
     // Skip constants from the input list
-    if (is_constant(x)) {
+    if (is_constant(i)) {
       continue;
     }

+    const auto& x = inputs[i];
+    auto& xname = namer.get_name(x);
+
     auto tstr = get_type_string(x.dtype());
     os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
        << "];" << std::endl;
@@ -211,10 +203,11 @@ inline void build_kernel(
   }

   // Read the inputs in tmps
-  for (auto& x : inputs) {
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    const auto& x = inputs[i];
     auto& xname = namer.get_name(x);

-    if (is_constant(x)) {
+    if (is_constant(i)) {
       os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
       print_constant(os, x);
       os << ";" << std::endl;
@@ -264,8 +257,9 @@ inline void build_kernel(
   } else {
     for (int d = ndim - 1; d >= 0; --d) {
       // Update pointers
-      for (auto& x : inputs) {
-        if (is_constant(x) || is_scalar(x)) {
+      for (size_t i = 0; i < inputs.size(); ++i) {
+        const auto& x = inputs[i];
+        if (is_constant(i) || is_scalar(x)) {
           continue;
         }
         auto& xname = namer.get_name(x);
@@ -287,65 +281,37 @@ inline void build_kernel(
 void Compiled::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  if (kernel_lib_.empty()) {
-    kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
-  }
-
-  // Figure out which kernel we are using
-  auto& shape = outputs[0].shape();
-  auto contiguous = compiled_check_contiguity(inputs, shape);
   auto& encoder = cpu::get_command_encoder(stream());

-  // Handle all broadcasting and collect function input arguments
+  // Collapse contiguous dims to route to a faster kernel if possible. Also
+  // handle all broadcasting.
+  auto [contiguous, shape, strides] =
+      compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
+
+  // Collect function input arguments.
   std::vector<void*> args;
-  std::vector<std::vector<size_t>> strides;
-  for (int i = 0; i < inputs.size(); i++) {
-    // Skip constants.
-    if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
+  int strides_index = 1;
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    if (is_constant_(i)) {
       continue;
     }
-    auto& x = inputs[i];
+    const auto& x = inputs[i];
     encoder.set_input_array(x);
     args.push_back((void*)x.data<void>());
-    if (contiguous || is_scalar(x)) {
-      continue;
+    if (!contiguous && !is_scalar(x)) {
+      args.push_back(strides[strides_index++].data());
     }
-
-    // Broadcast the input to the output shape.
-    std::vector<size_t> xstrides;
-    int j = 0;
-    for (; j < shape.size() - x.ndim(); j++) {
-      if (shape[j] == 1) {
-        xstrides.push_back(outputs[0].strides()[j]);
-      } else {
-        xstrides.push_back(0);
-      }
-    }
-    for (int i = 0; i < x.ndim(); i++, j++) {
-      if (x.shape(i) == 1) {
-        if (shape[j] == 1) {
-          xstrides.push_back(outputs[0].strides()[j]);
-        } else {
-          xstrides.push_back(0);
-        }
-      } else {
-        xstrides.push_back(x.strides()[i]);
-      }
-    }
-    strides.push_back(std::move(xstrides));
-    args.push_back(strides.back().data());
   }

   // Get the kernel name from the lib
   int ndim = shape.size();
   auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
   if (!contiguous) {
-    kernel_name += std::to_string(shape.size());
+    kernel_name += std::to_string(ndim);
   }

   // Get the function
-  auto fn_ptr = compile(kernel_name, [&]() {
+  auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() {
     std::ostringstream kernel;
     kernel << get_kernel_preamble() << std::endl;
     kernel << "extern \"C\" {" << std::endl;
@@ -355,7 +321,7 @@ void Compiled::eval_cpu(
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         contiguous,
         ndim);
     // Close extern "C"
@@ -363,26 +329,22 @@ void Compiled::eval_cpu(
     return kernel.str();
   });

-  compiled_allocate_outputs(
-      inputs, outputs, inputs_, constant_ids_, contiguous);
+  compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);

   for (auto& x : outputs) {
     args.push_back(x.data<void>());
     encoder.set_output_array(x);
   }
-  Shape out_shape;
   if (!contiguous) {
-    out_shape = outputs[0].shape();
-    args.push_back((void*)out_shape.data());
+    args.push_back((void*)shape.data());
   } else {
     args.push_back((void*)outputs[0].data_size());
   }
   auto fun = (void (*)(void**))fn_ptr;
-  encoder.dispatch(
-      [fun,
-       args = std::move(args),
-       strides = std::move(strides),
-       out_shape = std::move(out_shape)]() mutable { fun(args.data()); });
+  encoder.dispatch([fun,
+                    args = std::move(args),
+                    strides = std::move(strides),
+                    shape = std::move(shape)]() mutable { fun(args.data()); });
 }

 } // namespace mlx::core
@@ -22,7 +22,8 @@ void slow_conv_1D(
     const array& in,
     const array& wt,
     array out,
-    const std::vector<int>& padding,
+    const std::vector<int>& padding_lo,
+    const std::vector<int>& padding_hi,
     const std::vector<int>& wt_strides,
     const std::vector<int>& wt_dilation,
     const std::vector<int>& in_dilation,
@@ -60,7 +61,8 @@ void slow_conv_1D(
                     out_stride_O = out.strides()[2],

                     flip,
-                    padding = padding[0],
+                    padding_lo = padding_lo[0],
+                    padding_hi = padding_hi[0],
                     wt_stride = wt_strides[0],
                     wt_dilation = wt_dilation[0],
                     in_dilation = in_dilation[0]]() mutable {
@@ -77,7 +79,7 @@ void slow_conv_1D(
         const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;

         int wh_flip = flip ? (wH - wh - 1) : wh;
-        int ih = oh * wt_stride - padding + wh_flip * wt_dilation;
+        int ih = oh * wt_stride - padding_lo + wh_flip * wt_dilation;

         auto ih_div = std::div(ih, in_dilation);

@@ -109,7 +111,8 @@ void slow_conv_2D(
     const array& in,
     const array& wt,
     array out,
-    const std::vector<int>& padding,
+    const std::vector<int>& padding_lo,
+    const std::vector<int>& padding_hi,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
@@ -120,230 +123,235 @@ void slow_conv_2D(
|
|||||||
encoder.set_input_array(wt);
|
encoder.set_input_array(wt);
|
||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
encoder.dispatch([st_wt_ptr = wt.data<T>(),
|
encoder.dispatch(
|
||||||
st_in_ptr = in.data<T>(),
|
[st_wt_ptr = wt.data<T>(),
|
||||||
st_out_ptr = out.data<T>(),
|
st_in_ptr = in.data<T>(),
|
||||||
|
st_out_ptr = out.data<T>(),
|
||||||
|
|
||||||
N = in.shape(
|
N = in.shape(0), // Batch size, should be the same as out.shape(0)
|
||||||
0), // Batch size, should be the same as out.shape(0)
|
iH = 1 + in_dilation[0] * (in.shape(1) - 1), // Input spatial dim
|
||||||
iH = 1 +
|
iW = 1 + in_dilation[1] * (in.shape(2) - 1), // Input spatial dim
|
||||||
in_dilation[0] * (in.shape(1) - 1), // Input spatial dim
|
C = in.shape(3), // In channels
|
||||||
iW = 1 +
|
oH = out.shape(1), // Output spatial dim
|
||||||
in_dilation[1] * (in.shape(2) - 1), // Input spatial dim
|
oW = out.shape(2), // Output spatial dim
|
||||||
C = in.shape(3), // In channels
|
O = wt.shape(0), // Out channels
|
||||||
oH = out.shape(1), // Output spatial dim
|
wH = wt.shape(1), // Weight spatial dim
|
||||||
oW = out.shape(2), // Output spatial dim
|
wW = wt.shape(2), // Weight spatial dim
|
||||||
O = wt.shape(0), // Out channels
|
|
||||||
wH = wt.shape(1), // Weight spatial dim
|
|
||||||
wW = wt.shape(2), // Weight spatial dim
|
|
||||||
|
|
||||||
groups = in.shape(3) / wt.shape(3),
|
groups = in.shape(3) / wt.shape(3),
|
||||||
C_per_group = wt.shape(3),
|
C_per_group = wt.shape(3),
|
||||||
|
|
||||||
in_stride_N = in.strides()[0],
|
in_stride_N = in.strides()[0],
|
||||||
in_stride_H = in.strides()[1],
|
in_stride_H = in.strides()[1],
|
||||||
in_stride_W = in.strides()[2],
|
in_stride_W = in.strides()[2],
|
||||||
in_stride_C = in.strides()[3],
|
in_stride_C = in.strides()[3],
|
||||||
|
|
||||||
wt_stride_O = wt.strides()[0],
|
wt_stride_O = wt.strides()[0],
|
||||||
wt_stride_H = wt.strides()[1],
|
wt_stride_H = wt.strides()[1],
|
||||||
wt_stride_W = wt.strides()[2],
|
wt_stride_W = wt.strides()[2],
|
||||||
wt_stride_C = wt.strides()[3],
|
wt_stride_C = wt.strides()[3],
|
||||||
|
|
||||||
out_stride_N = out.strides()[0],
|
out_stride_N = out.strides()[0],
|
||||||
out_stride_H = out.strides()[1],
|
out_stride_H = out.strides()[1],
|
||||||
out_stride_W = out.strides()[2],
|
out_stride_W = out.strides()[2],
|
||||||
out_stride_O = out.strides()[3],
|
out_stride_O = out.strides()[3],
|
||||||
|
|
||||||
padding,
|
padding_lo,
|
||||||
wt_strides,
|
padding_hi,
|
||||||
wt_dilation,
|
wt_strides,
|
||||||
in_dilation,
|
wt_dilation,
|
||||||
flip]() mutable {
|
in_dilation,
|
||||||
bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;
|
flip]() mutable {
|
||||||
|
bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;
|
||||||
|
|
||||||
const int O_per_group = O / groups;
|
const int O_per_group = O / groups;
|
||||||
auto pt_conv_no_checks = [&](const T* in_ptr,
|
auto pt_conv_no_checks =
|
||||||
const T* wt_ptr,
|
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
|
||||||
T* out_ptr,
|
out_ptr += oh * out_stride_H + ow * out_stride_W;
|
||||||
int oh,
|
int ih_base = oh * wt_strides[0] - padding_lo[0];
|
||||||
int ow) {
|
int iw_base = ow * wt_strides[1] - padding_lo[1];
|
||||||
out_ptr += oh * out_stride_H + ow * out_stride_W;
|
|
||||||
int ih_base = oh * wt_strides[0] - padding[0];
|
|
||||||
int iw_base = ow * wt_strides[1] - padding[1];
|
|
||||||
|
|
||||||
for (int g = 0; g < groups; ++g) {
|
for (int g = 0; g < groups; ++g) {
|
||||||
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
|
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
|
||||||
float r = 0.;
|
float r = 0.;
|
||||||
|
|
||||||
for (int wh = 0; wh < wH; ++wh) {
|
for (int wh = 0; wh < wH; ++wh) {
|
||||||
for (int ww = 0; ww < wW; ++ww) {
|
for (int ww = 0; ww < wW; ++ww) {
|
||||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||||
int ih = ih_base + wh_flip * wt_dilation[0];
|
int ih = ih_base + wh_flip * wt_dilation[0];
|
||||||
int iw = iw_base + ww_flip * wt_dilation[1];
|
int iw = iw_base + ww_flip * wt_dilation[1];
|
||||||
|
|
||||||
const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
const T* wt_ptr_pt =
|
||||||
const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;
|
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||||
|
const T* in_ptr_pt =
|
||||||
|
in_ptr + ih * in_stride_H + iw * in_stride_W;
|
||||||
|
|
||||||
for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
|
for (int c = g * C_per_group; c < (g + 1) * C_per_group;
|
||||||
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
|
++c) {
|
||||||
static_cast<float>(
|
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
|
||||||
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
|
static_cast<float>(
|
||||||
} // c
|
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
|
||||||
} // ww
|
} // c
|
||||||
} // wh
|
} // ww
|
||||||
|
} // wh
|
||||||
|
|
||||||
out_ptr[0] = static_cast<T>(r);
|
out_ptr[0] = static_cast<T>(r);
|
||||||
out_ptr += out_stride_O;
|
out_ptr += out_stride_O;
|
||||||
wt_ptr += wt_stride_O;
|
wt_ptr += wt_stride_O;
|
||||||
} // o
|
} // o
|
||||||
} // g
|
} // g
|
||||||
};
|
};
|
||||||
|
|
||||||
int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
|
int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
|
||||||
int jump_w = flip ? -wt_dilation[1] : wt_dilation[1];
|
int jump_w = flip ? -wt_dilation[1] : wt_dilation[1];
|
||||||
|
|
||||||
int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0);
|
int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0);
|
||||||
int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0);
|
int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0);
|
||||||
|
|
||||||
int f_wgt_jump_h =
|
int f_wgt_jump_h =
|
||||||
std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
|
std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
|
||||||
int f_wgt_jump_w =
|
int f_wgt_jump_w =
|
||||||
std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
|
std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
|
||||||
|
|
||||||
int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
|
int f_out_jump_h =
|
||||||
int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
|
std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
|
||||||
|
int f_out_jump_w =
|
||||||
|
std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
|
||||||
|
|
||||||
std::vector<int> base_h(f_out_jump_h);
|
std::vector<int> base_h(f_out_jump_h);
|
||||||
std::vector<int> base_w(f_out_jump_w);
|
std::vector<int> base_w(f_out_jump_w);
|
||||||
|
|
||||||
for (int i = 0; i < f_out_jump_h; ++i) {
|
for (int i = 0; i < f_out_jump_h; ++i) {
|
||||||
int ih_loop = i * wt_strides[0] - padding[0] + init_h;
|
int ih_loop = i * wt_strides[0] - padding_lo[0] + init_h;
|
||||||
|
|
||||||
int wh_base = 0;
|
int wh_base = 0;
|
||||||
while (wh_base < wH && ih_loop % in_dilation[0] != 0) {
|
while (wh_base < wH && ih_loop % in_dilation[0] != 0) {
|
||||||
wh_base++;
|
wh_base++;
|
||||||
ih_loop += jump_h;
|
ih_loop += jump_h;
|
||||||
}
|
}
|
||||||
|
|
||||||
base_h[i] = wh_base;
|
base_h[i] = wh_base;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int j = 0; j < f_out_jump_w; ++j) {
|
for (int j = 0; j < f_out_jump_w; ++j) {
|
||||||
int iw_loop = j * wt_strides[1] - padding[1] + init_w;
|
int iw_loop = j * wt_strides[1] - padding_lo[1] + init_w;
|
||||||
|
|
||||||
int ww_base = 0;
|
int ww_base = 0;
|
||||||
while (ww_base < wW && iw_loop % in_dilation[1] != 0) {
|
while (ww_base < wW && iw_loop % in_dilation[1] != 0) {
|
||||||
ww_base++;
|
ww_base++;
|
||||||
iw_loop += jump_w;
|
iw_loop += jump_w;
|
||||||
}
|
}
|
||||||
|
|
||||||
base_w[j] = ww_base;
|
base_w[j] = ww_base;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto pt_conv_all_checks =
|
auto pt_conv_all_checks =
|
||||||
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
|
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
|
||||||
out_ptr += oh * out_stride_H + ow * out_stride_W;
|
out_ptr += oh * out_stride_H + ow * out_stride_W;
|
||||||
|
|
||||||
int ih_base = oh * wt_strides[0] - padding[0];
|
int ih_base = oh * wt_strides[0] - padding_lo[0];
|
||||||
int iw_base = ow * wt_strides[1] - padding[1];
|
int iw_base = ow * wt_strides[1] - padding_lo[1];
|
||||||
|
|
||||||
int wh_base = base_h[oh % f_out_jump_h];
|
int wh_base = base_h[oh % f_out_jump_h];
|
||||||
int ww_base = base_w[ow % f_out_jump_w];
|
int ww_base = base_w[ow % f_out_jump_w];
|
||||||
|
|
||||||
for (int g = 0; g < groups; ++g) {
|
for (int g = 0; g < groups; ++g) {
|
||||||
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
|
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
|
||||||
float r = 0.;
|
float r = 0.;
|
||||||
|
|
||||||
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
|
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
|
||||||
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
|
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
|
||||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||||
int ih = ih_base + wh_flip * wt_dilation[0];
|
int ih = ih_base + wh_flip * wt_dilation[0];
|
||||||
int iw = iw_base + ww_flip * wt_dilation[1];
|
int iw = iw_base + ww_flip * wt_dilation[1];
|
||||||
|
|
||||||
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
|
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
|
||||||
const T* wt_ptr_pt =
|
const T* wt_ptr_pt =
|
||||||
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||||
|
|
||||||
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
|
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
|
||||||
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
|
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
|
||||||
|
|
||||||
const T* in_ptr_pt =
|
const T* in_ptr_pt = in_ptr + ih_dil * in_stride_H +
|
||||||
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;
|
iw_dil * in_stride_W;
|
||||||
|
|
||||||
for (int c = g * C_per_group; c < (g + 1) * C_per_group;
|
for (int c = g * C_per_group; c < (g + 1) * C_per_group;
|
||||||
++c) {
|
++c) {
|
||||||
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
|
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
|
||||||
static_cast<float>(
|
static_cast<float>(
|
||||||
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
|
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
|
||||||
} // c
|
} // c
|
||||||
|
|
||||||
} // ih, iw check
|
} // ih, iw check
|
||||||
} // ww
|
} // ww
|
||||||
} // wh
|
} // wh
|
||||||
|
|
||||||
out_ptr[0] = static_cast<T>(r);
|
out_ptr[0] = static_cast<T>(r);
|
||||||
out_ptr += out_stride_O;
|
out_ptr += out_stride_O;
|
||||||
wt_ptr += wt_stride_O;
|
wt_ptr += wt_stride_O;
|
||||||
} // o
|
} // o
|
||||||
} // g
|
} // g
|
||||||
};
|
};
|
||||||
|
|
||||||
int oH_border_0 = 0;
|
int oH_border_0 = 0;
|
||||||
int oH_border_1 =
|
int oH_border_1 = is_idil_one
|
||||||
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oH;
|
? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])
|
||||||
int oH_border_2 = std::max(
|
: oH;
|
||||||
oH_border_1, (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]);
|
int oH_border_2 = std::max(
|
||||||
int oH_border_3 = oH;
|
oH_border_1,
|
||||||
|
(iH + padding_lo[0] - wH * wt_dilation[0]) / wt_strides[0]);
|
||||||
|
int oH_border_3 = oH;
|
||||||
|
|
||||||
int oW_border_0 = 0;
|
int oW_border_0 = 0;
|
||||||
int oW_border_1 =
|
int oW_border_1 = is_idil_one
|
||||||
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oW;
|
? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])
|
||||||
int oW_border_2 = std::max(
|
: oW;
|
||||||
oW_border_1, (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]);
|
int oW_border_2 = std::max(
|
||||||
int oW_border_3 = oW;
|
oW_border_1,
|
||||||
|
(iW + padding_lo[1] - wW * wt_dilation[1]) / wt_strides[1]);
|
||||||
|
int oW_border_3 = oW;
|
||||||
|
|
||||||
for (int n = 0; n < N; ++n) {
|
for (int n = 0; n < N; ++n) {
|
||||||
// Case 1: oh might put us out of bounds
|
// Case 1: oh might put us out of bounds
|
||||||
for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
|
for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
|
||||||
for (int ow = 0; ow < oW; ++ow) {
|
for (int ow = 0; ow < oW; ++ow) {
|
||||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||||
} // ow
|
} // ow
|
||||||
} // oh
|
} // oh
|
||||||
|
|
||||||
// Case 2: oh in bounds
|
// Case 2: oh in bounds
|
||||||
for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
|
for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
|
||||||
// Case a: ow might put us out of bounds
|
// Case a: ow might put us out of bounds
|
||||||
for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
|
for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
|
||||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||||
} // ow
|
} // ow
|
||||||
|
|
||||||
// Case b: ow in bounds
|
// Case b: ow in bounds
|
||||||
for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
|
for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
|
||||||
pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||||
} // ow
|
} // ow
|
||||||
|
|
||||||
// Case c: ow might put us out of bounds
|
// Case c: ow might put us out of bounds
|
||||||
for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
|
for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
|
||||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||||
} // ow
|
} // ow
|
||||||
|
|
||||||
} // oh
|
} // oh
|
||||||
|
|
||||||
// Case 3: oh might put us out of bounds
|
// Case 3: oh might put us out of bounds
|
||||||
for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
|
for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
|
||||||
for (int ow = 0; ow < oW; ++ow) {
|
for (int ow = 0; ow < oW; ++ow) {
|
||||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||||
} // ow
|
} // ow
|
||||||
} // oh
|
} // oh
|
||||||
|
|
||||||
st_in_ptr += in_stride_N;
|
st_in_ptr += in_stride_N;
|
||||||
st_out_ptr += out_stride_N;
|
st_out_ptr += out_stride_N;
|
||||||
|
|
||||||
} // n
|
} // n
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@@ -351,7 +359,8 @@ void slow_conv_3D(
|
|||||||
const array& in,
|
const array& in,
|
||||||
const array& wt,
|
const array& wt,
|
||||||
array out,
|
array out,
|
||||||
const std::vector<int>& padding,
|
const std::vector<int>& padding_lo,
|
||||||
|
const std::vector<int>& padding_hi,
|
||||||
const std::vector<int>& wt_strides,
|
const std::vector<int>& wt_strides,
|
||||||
const std::vector<int>& wt_dilation,
|
const std::vector<int>& wt_dilation,
|
||||||
const std::vector<int>& in_dilation,
|
const std::vector<int>& in_dilation,
|
||||||
@@ -400,7 +409,8 @@ void slow_conv_3D(
|
|||||||
out_stride_H = out.strides()[2],
|
out_stride_H = out.strides()[2],
|
||||||
out_stride_W = out.strides()[3],
|
out_stride_W = out.strides()[3],
|
||||||
out_stride_O = out.strides()[4],
|
out_stride_O = out.strides()[4],
|
||||||
padding,
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
wt_strides,
|
wt_strides,
|
||||||
wt_dilation,
|
wt_dilation,
|
||||||
in_dilation,
|
in_dilation,
|
||||||
@@ -415,9 +425,9 @@ void slow_conv_3D(
|
|||||||
int oh,
|
int oh,
|
||||||
int ow) {
|
int ow) {
|
||||||
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
|
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
|
||||||
int id_base = od * wt_strides[0] - padding[0];
|
int id_base = od * wt_strides[0] - padding_lo[0];
|
||||||
int ih_base = oh * wt_strides[1] - padding[1];
|
int ih_base = oh * wt_strides[1] - padding_lo[1];
|
||||||
int iw_base = ow * wt_strides[2] - padding[2];
|
int iw_base = ow * wt_strides[2] - padding_lo[2];
|
||||||
|
|
||||||
for (int o = 0; o < O; ++o) {
|
for (int o = 0; o < O; ++o) {
|
||||||
float r = 0.;
|
float r = 0.;
|
||||||
@@ -478,7 +488,7 @@ void slow_conv_3D(
|
|||||||
std::vector<int> base_w(f_out_jump_w);
|
std::vector<int> base_w(f_out_jump_w);
|
||||||
|
|
||||||
for (int i = 0; i < f_out_jump_d; ++i) {
|
for (int i = 0; i < f_out_jump_d; ++i) {
|
||||||
int id_loop = i * wt_strides[0] - padding[0] + init_d;
|
int id_loop = i * wt_strides[0] - padding_lo[0] + init_d;
|
||||||
|
|
||||||
int wd_base = 0;
|
int wd_base = 0;
|
||||||
while (wd_base < wD && id_loop % in_dilation[0] != 0) {
|
while (wd_base < wD && id_loop % in_dilation[0] != 0) {
|
||||||
@@ -490,7 +500,7 @@ void slow_conv_3D(
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < f_out_jump_h; ++i) {
|
for (int i = 0; i < f_out_jump_h; ++i) {
|
||||||
int ih_loop = i * wt_strides[1] - padding[1] + init_h;
|
int ih_loop = i * wt_strides[1] - padding_lo[1] + init_h;
|
||||||
|
|
||||||
int wh_base = 0;
|
int wh_base = 0;
|
||||||
while (wh_base < wH && ih_loop % in_dilation[1] != 0) {
|
while (wh_base < wH && ih_loop % in_dilation[1] != 0) {
|
||||||
@@ -502,7 +512,7 @@ void slow_conv_3D(
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (int j = 0; j < f_out_jump_w; ++j) {
|
for (int j = 0; j < f_out_jump_w; ++j) {
|
||||||
int iw_loop = j * wt_strides[2] - padding[2] + init_w;
|
int iw_loop = j * wt_strides[2] - padding_lo[2] + init_w;
|
||||||
|
|
||||||
int ww_base = 0;
|
int ww_base = 0;
|
||||||
while (ww_base < wW && iw_loop % in_dilation[2] != 0) {
|
while (ww_base < wW && iw_loop % in_dilation[2] != 0) {
|
||||||
@@ -521,9 +531,9 @@ void slow_conv_3D(
|
|||||||
int ow) {
|
int ow) {
|
||||||
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
|
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
|
||||||
|
|
||||||
int id_base = od * wt_strides[0] - padding[0];
|
int id_base = od * wt_strides[0] - padding_lo[0];
|
||||||
int ih_base = oh * wt_strides[1] - padding[1];
|
int ih_base = oh * wt_strides[1] - padding_lo[1];
|
||||||
int iw_base = ow * wt_strides[2] - padding[2];
|
int iw_base = ow * wt_strides[2] - padding_lo[2];
|
||||||
|
|
||||||
int wd_base = base_d[od % f_out_jump_d];
|
int wd_base = base_d[od % f_out_jump_d];
|
||||||
int wh_base = base_h[oh % f_out_jump_h];
|
int wh_base = base_h[oh % f_out_jump_h];
|
||||||
@@ -573,24 +583,30 @@ void slow_conv_3D(
|
|||||||
};
|
};
|
||||||
|
|
||||||
int oD_border_0 = 0;
|
int oD_border_0 = 0;
|
||||||
int oD_border_1 =
|
int oD_border_1 = is_idil_one
|
||||||
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD;
|
? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])
|
||||||
|
: oD;
|
||||||
int oD_border_2 = std::max(
|
int oD_border_2 = std::max(
|
||||||
oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]);
|
oD_border_1,
|
||||||
|
(iD + padding_lo[0] - wD * wt_dilation[0]) / wt_strides[0]);
|
||||||
int oD_border_3 = oD;
|
int oD_border_3 = oD;
|
||||||
|
|
||||||
int oH_border_0 = 0;
|
int oH_border_0 = 0;
|
||||||
int oH_border_1 =
|
int oH_border_1 = is_idil_one
|
||||||
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH;
|
? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])
|
||||||
|
: oH;
|
||||||
int oH_border_2 = std::max(
|
int oH_border_2 = std::max(
|
||||||
oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]);
|
oH_border_1,
|
||||||
|
(iH + padding_lo[1] - wH * wt_dilation[1]) / wt_strides[1]);
|
||||||
int oH_border_3 = oH;
|
int oH_border_3 = oH;
|
||||||
|
|
||||||
int oW_border_0 = 0;
|
int oW_border_0 = 0;
|
||||||
int oW_border_1 =
|
int oW_border_1 = is_idil_one
|
||||||
is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW;
|
? ((padding_lo[2] + wt_strides[2] - 1) / wt_strides[2])
|
||||||
|
: oW;
|
||||||
int oW_border_2 = std::max(
|
int oW_border_2 = std::max(
|
||||||
oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]);
|
oW_border_1,
|
||||||
|
(iW + padding_lo[2] - wW * wt_dilation[2]) / wt_strides[2]);
|
||||||
int oW_border_3 = oW;
|
int oW_border_3 = oW;
|
||||||
|
|
||||||
for (int n = 0; n < N; ++n) {
|
for (int n = 0; n < N; ++n) {
|
||||||
@@ -658,7 +674,8 @@ void dispatch_slow_conv_1D(
|
|||||||
const array& in,
|
const array& in,
|
||||||
const array& wt,
|
const array& wt,
|
||||||
array out,
|
array out,
|
||||||
const std::vector<int>& padding,
|
const std::vector<int>& padding_lo,
|
||||||
|
const std::vector<int>& padding_hi,
|
||||||
const std::vector<int>& wt_strides,
|
const std::vector<int>& wt_strides,
|
||||||
const std::vector<int>& wt_dilation,
|
const std::vector<int>& wt_dilation,
|
||||||
const std::vector<int>& in_dilation,
|
const std::vector<int>& in_dilation,
|
||||||
@@ -669,7 +686,8 @@ void dispatch_slow_conv_1D(
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding,
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
wt_strides,
|
wt_strides,
|
||||||
wt_dilation,
|
wt_dilation,
|
||||||
in_dilation,
|
in_dilation,
|
||||||
@@ -680,7 +698,8 @@ void dispatch_slow_conv_1D(
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding,
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
wt_strides,
|
wt_strides,
|
||||||
wt_dilation,
|
wt_dilation,
|
||||||
in_dilation,
|
in_dilation,
|
||||||
@@ -691,7 +710,8 @@ void dispatch_slow_conv_1D(
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding,
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
wt_strides,
|
wt_strides,
|
||||||
wt_dilation,
|
wt_dilation,
|
||||||
in_dilation,
|
in_dilation,
|
||||||
@@ -707,7 +727,8 @@ void dispatch_slow_conv_2D(
|
|||||||
const array& in,
|
const array& in,
|
||||||
const array& wt,
|
const array& wt,
|
||||||
array out,
|
array out,
|
||||||
const std::vector<int>& padding,
|
const std::vector<int>& padding_lo,
|
||||||
|
const std::vector<int>& padding_hi,
|
||||||
const std::vector<int>& wt_strides,
|
const std::vector<int>& wt_strides,
|
||||||
const std::vector<int>& wt_dilation,
|
const std::vector<int>& wt_dilation,
|
||||||
const std::vector<int>& in_dilation,
|
const std::vector<int>& in_dilation,
|
||||||
@@ -718,7 +739,8 @@ void dispatch_slow_conv_2D(
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding,
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
wt_strides,
|
wt_strides,
|
||||||
wt_dilation,
|
wt_dilation,
|
||||||
in_dilation,
|
in_dilation,
|
||||||
@@ -729,7 +751,8 @@ void dispatch_slow_conv_2D(
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding,
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
wt_strides,
|
wt_strides,
|
||||||
wt_dilation,
|
wt_dilation,
|
||||||
in_dilation,
|
in_dilation,
|
||||||
@@ -740,7 +763,8 @@ void dispatch_slow_conv_2D(
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding,
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
wt_strides,
|
wt_strides,
|
||||||
wt_dilation,
|
wt_dilation,
|
||||||
in_dilation,
|
in_dilation,
|
||||||
@@ -756,7 +780,8 @@ void dispatch_slow_conv_3D(
|
|||||||
const array& in,
|
const array& in,
|
||||||
const array& wt,
|
const array& wt,
|
||||||
array out,
|
array out,
|
||||||
const std::vector<int>& padding,
|
const std::vector<int>& padding_lo,
|
||||||
|
const std::vector<int>& padding_hi,
|
||||||
const std::vector<int>& wt_strides,
|
const std::vector<int>& wt_strides,
|
||||||
const std::vector<int>& wt_dilation,
|
const std::vector<int>& wt_dilation,
|
||||||
const std::vector<int>& in_dilation,
|
const std::vector<int>& in_dilation,
|
||||||
@@ -767,7 +792,8 @@ void dispatch_slow_conv_3D(
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding,
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
wt_strides,
|
wt_strides,
|
||||||
wt_dilation,
|
wt_dilation,
|
||||||
in_dilation,
|
in_dilation,
|
||||||
@@ -778,7 +804,8 @@ void dispatch_slow_conv_3D(
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding,
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
wt_strides,
|
wt_strides,
|
||||||
wt_dilation,
|
wt_dilation,
|
||||||
in_dilation,
|
in_dilation,
|
||||||
@@ -789,7 +816,8 @@ void dispatch_slow_conv_3D(
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding,
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
wt_strides,
|
wt_strides,
|
||||||
wt_dilation,
|
wt_dilation,
|
||||||
in_dilation,
|
in_dilation,
|
||||||
@@ -829,7 +857,8 @@ void explicit_gemm_conv_1D_cpu(
     const array& in,
     const array& wt,
     array out,
-    const std::vector<int>& padding,
+    const std::vector<int>& padding_lo,
+    const std::vector<int>& padding_hi,
     const std::vector<int>& wt_strides,
     const std::vector<int>& wt_dilation,
     Stream stream) {
@@ -848,7 +877,7 @@ void explicit_gemm_conv_1D_cpu(
   auto& encoder = cpu::get_command_encoder(stream);

   // Pad input
-  Shape padded_shape = {N, iH + 2 * padding[0], C};
+  Shape padded_shape = {N, iH + padding_lo[0] + padding_hi[0], C};
   array in_padded(padded_shape, conv_dtype, nullptr, {});

   // Fill with zeros
@@ -857,7 +886,7 @@ void explicit_gemm_conv_1D_cpu(
   copy(temps.back(), in_padded, CopyType::Scalar, stream);

   // Pick input slice from padded
-  size_t data_offset = padding[0] * in_padded.strides()[1];
+  size_t data_offset = padding_lo[0] * in_padded.strides()[1];
   array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
   in_padded_slice.copy_shared_buffer(
       in_padded,
@@ -971,7 +1000,8 @@ void explicit_gemm_conv_2D_cpu(
     const array& in,
     const array& wt,
     array out,
-    const std::vector<int>& padding,
+    const std::vector<int>& padding_lo,
+    const std::vector<int>& padding_hi,
     const std::vector<int>& wt_strides,
     const std::vector<int>& wt_dilation,
     Stream stream) {
@@ -989,7 +1019,11 @@ void explicit_gemm_conv_2D_cpu(
   auto& encoder = cpu::get_command_encoder(stream);

   // Pad input
-  Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C};
+  Shape padded_shape = {
+      N,
+      iH + padding_lo[0] + padding_hi[0],
+      iW + padding_lo[1] + padding_hi[1],
+      C};
   array in_padded(padded_shape, conv_dtype, nullptr, {});

   // Fill with zeros
@@ -998,8 +1032,8 @@ void explicit_gemm_conv_2D_cpu(
   copy(temps.back(), in_padded, CopyType::Scalar, stream);

   // Pick input slice from padded
-  size_t data_offset =
-      padding[0] * in_padded.strides()[1] + padding[1] * in_padded.strides()[2];
+  size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
+      padding_lo[1] * in_padded.strides()[2];
   array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
   in_padded_slice.copy_shared_buffer(
       in_padded,
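With padding split into padding_lo and padding_hi, the padded extent becomes iH + padding_lo[0] + padding_hi[0] rather than iH + 2 * padding[0]. A small standalone arithmetic check of the resulting output size (the concrete values are made up):

#include <iostream>

int main() {
  int iH = 10;                        // input spatial size
  int padding_lo = 2, padding_hi = 1; // asymmetric padding
  int wH = 3;                         // kernel size
  int wt_stride = 1, wt_dilation = 1;

  int padded = iH + padding_lo + padding_hi;            // 13
  int effective_kernel = wt_dilation * (wH - 1) + 1;    // 3
  int oH = (padded - effective_kernel) / wt_stride + 1; // 11
  std::cout << "padded " << padded << ", output " << oH << "\n";
}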
@@ -1091,7 +1125,8 @@ void explicit_gemm_conv_ND_cpu(
|
|||||||
const array& in,
|
const array& in,
|
||||||
const array& wt,
|
const array& wt,
|
||||||
array out,
|
array out,
|
||||||
const std::vector<int>& padding,
|
const std::vector<int>& padding_lo,
|
||||||
|
const std::vector<int>& padding_hi,
|
||||||
const std::vector<int>& wt_strides,
|
const std::vector<int>& wt_strides,
|
||||||
const std::vector<int>& wt_dilation,
|
const std::vector<int>& wt_dilation,
|
||||||
const bool flip,
|
const bool flip,
|
||||||
@@ -1114,7 +1149,7 @@ void explicit_gemm_conv_ND_cpu(
|
|||||||
Shape padded_shape(in.shape().size());
|
Shape padded_shape(in.shape().size());
|
||||||
padded_shape.front() = N;
|
padded_shape.front() = N;
|
||||||
for (size_t i = 0; i < iDim.size(); i++) {
|
for (size_t i = 0; i < iDim.size(); i++) {
|
||||||
padded_shape[i + 1] = iDim[i] + 2 * padding[i];
|
padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i];
|
||||||
}
|
}
|
||||||
padded_shape.back() = C;
|
padded_shape.back() = C;
|
||||||
array in_padded(padded_shape, conv_dtype, nullptr, {});
|
array in_padded(padded_shape, conv_dtype, nullptr, {});
|
||||||
@@ -1125,9 +1160,10 @@ void explicit_gemm_conv_ND_cpu(
|
|||||||
|
|
||||||
// Pick input slice from padded
|
// Pick input slice from padded
|
||||||
size_t data_offset = 0;
|
size_t data_offset = 0;
|
||||||
for (size_t i = 0; i < padding.size(); i++) {
|
for (size_t i = 0; i < padding_lo.size(); i++) {
|
||||||
data_offset += padding[i] * in_padded.strides()[i + 1];
|
data_offset += padding_lo[i] * in_padded.strides()[i + 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
||||||
in_padded_slice.copy_shared_buffer(
|
in_padded_slice.copy_shared_buffer(
|
||||||
in_padded,
|
in_padded,
|
||||||
@@ -1261,7 +1297,8 @@ void conv_1D_cpu(
|
|||||||
const array& in,
|
const array& in,
|
||||||
const array& wt,
|
const array& wt,
|
||||||
array out,
|
array out,
|
||||||
const std::vector<int>& padding,
|
const std::vector<int>& padding_lo,
|
||||||
|
const std::vector<int>& padding_hi,
|
||||||
const std::vector<int>& wt_strides,
|
const std::vector<int>& wt_strides,
|
||||||
const std::vector<int>& wt_dilation,
|
const std::vector<int>& wt_dilation,
|
||||||
const std::vector<int>& in_dilation,
|
const std::vector<int>& in_dilation,
|
||||||
@@ -1270,22 +1307,40 @@ void conv_1D_cpu(
|
|||||||
const int groups = in.shape().back() / wt.shape().back();
|
const int groups = in.shape().back() / wt.shape().back();
|
||||||
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
|
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
|
||||||
return explicit_gemm_conv_1D_cpu(
|
return explicit_gemm_conv_1D_cpu(
|
||||||
in, wt, out, padding, wt_strides, wt_dilation, stream);
|
in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, stream);
|
||||||
}
|
}
|
||||||
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) {
|
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) {
|
||||||
return explicit_gemm_conv_ND_cpu(
|
return explicit_gemm_conv_ND_cpu(
|
||||||
in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
|
in,
|
||||||
|
wt,
|
||||||
|
out,
|
||||||
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
|
wt_strides,
|
||||||
|
wt_dilation,
|
||||||
|
flip,
|
||||||
|
stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
return dispatch_slow_conv_1D(
|
return dispatch_slow_conv_1D(
|
||||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
|
in,
|
||||||
|
wt,
|
||||||
|
out,
|
||||||
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
|
wt_strides,
|
||||||
|
wt_dilation,
|
||||||
|
in_dilation,
|
||||||
|
flip,
|
||||||
|
stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
void conv_2D_cpu(
|
void conv_2D_cpu(
|
||||||
const array& in,
|
const array& in,
|
||||||
const array& wt,
|
const array& wt,
|
||||||
array out,
|
array out,
|
||||||
const std::vector<int>& padding,
|
const std::vector<int>& padding_lo,
|
||||||
|
const std::vector<int>& padding_hi,
|
||||||
const std::vector<int>& wt_strides,
|
const std::vector<int>& wt_strides,
|
||||||
const std::vector<int>& wt_dilation,
|
const std::vector<int>& wt_dilation,
|
||||||
const std::vector<int>& in_dilation,
|
const std::vector<int>& in_dilation,
|
||||||
@@ -1295,18 +1350,35 @@ void conv_2D_cpu(
|
|||||||
if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 &&
|
if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 &&
|
||||||
in_dilation[1] == 1 && groups == 1) {
|
in_dilation[1] == 1 && groups == 1) {
|
||||||
return explicit_gemm_conv_ND_cpu(
|
return explicit_gemm_conv_ND_cpu(
|
||||||
in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
|
in,
|
||||||
|
wt,
|
||||||
|
out,
|
||||||
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
|
wt_strides,
|
||||||
|
wt_dilation,
|
||||||
|
flip,
|
||||||
|
stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
return dispatch_slow_conv_2D(
|
return dispatch_slow_conv_2D(
|
||||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
|
in,
|
||||||
|
wt,
|
||||||
|
out,
|
||||||
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
|
wt_strides,
|
||||||
|
wt_dilation,
|
||||||
|
in_dilation,
|
||||||
|
flip,
|
||||||
|
stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
void conv_3D_cpu(
|
void conv_3D_cpu(
|
||||||
const array& in,
|
const array& in,
|
||||||
const array& wt,
|
const array& wt,
|
||||||
array out,
|
array out,
|
||||||
const std::vector<int>& padding,
|
const std::vector<int>& padding_lo,
|
||||||
|
const std::vector<int>& padding_hi,
|
||||||
const std::vector<int>& wt_strides,
|
const std::vector<int>& wt_strides,
|
||||||
const std::vector<int>& wt_dilation,
|
const std::vector<int>& wt_dilation,
|
||||||
const std::vector<int>& in_dilation,
|
const std::vector<int>& in_dilation,
|
||||||
@@ -1317,11 +1389,28 @@ void conv_3D_cpu(
|
|||||||
in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 &&
|
in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 &&
|
||||||
groups == 1) {
|
groups == 1) {
|
||||||
return explicit_gemm_conv_ND_cpu(
|
return explicit_gemm_conv_ND_cpu(
|
||||||
in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
|
in,
|
||||||
|
wt,
|
||||||
|
out,
|
||||||
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
|
wt_strides,
|
||||||
|
wt_dilation,
|
||||||
|
flip,
|
||||||
|
stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
return dispatch_slow_conv_3D(
|
return dispatch_slow_conv_3D(
|
||||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
|
in,
|
||||||
|
wt,
|
||||||
|
out,
|
||||||
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
|
wt_strides,
|
||||||
|
wt_dilation,
|
||||||
|
in_dilation,
|
||||||
|
flip,
|
||||||
|
stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@@ -1338,7 +1427,8 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding_,
|
padding_lo_,
|
||||||
|
padding_hi_,
|
||||||
kernel_strides_,
|
kernel_strides_,
|
||||||
kernel_dilation_,
|
kernel_dilation_,
|
||||||
input_dilation_,
|
input_dilation_,
|
||||||
@@ -1351,7 +1441,8 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding_,
|
padding_lo_,
|
||||||
|
padding_hi_,
|
||||||
kernel_strides_,
|
kernel_strides_,
|
||||||
kernel_dilation_,
|
kernel_dilation_,
|
||||||
input_dilation_,
|
input_dilation_,
|
||||||
@@ -1364,7 +1455,8 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding_,
|
padding_lo_,
|
||||||
|
padding_hi_,
|
||||||
kernel_strides_,
|
kernel_strides_,
|
||||||
kernel_dilation_,
|
kernel_dilation_,
|
||||||
input_dilation_,
|
input_dilation_,
|
||||||
|
|||||||
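Editor's note: the diff above threads separate low/high padding through every CPU conv path instead of assuming symmetric padding. A minimal sketch of the shape arithmetic, with hypothetical helper names (not MLX APIs):

    #include <cstdio>

    // Padded size along one spatial axis and the resulting conv output size,
    // assuming the usual cross-correlation stride/dilation semantics.
    int padded_dim(int in_dim, int pad_lo, int pad_hi) {
      return in_dim + pad_lo + pad_hi; // previously in_dim + 2 * pad
    }

    int out_dim(int in_dim, int w_dim, int pad_lo, int pad_hi, int stride, int dilation) {
      int effective_w = dilation * (w_dim - 1) + 1;
      return (padded_dim(in_dim, pad_lo, pad_hi) - effective_w) / stride + 1;
    }

    int main() {
      // Input length 10, kernel 3, padding_lo 0, padding_hi 2, stride 1, dilation 1.
      std::printf("%d\n", out_dim(10, 3, 0, 2, 1, 1)); // prints 10
      return 0;
    }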
mlx/backend/cpu/eig.cpp (new file, 174 lines)
@@ -0,0 +1,174 @@
// Copyright © 2025 Apple Inc.

#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"

namespace mlx::core {

namespace {

template <typename T>
void eig_impl(
array& a,
array& vectors,
array& values,
bool compute_eigenvectors,
Stream stream) {
using OT = std::complex<T>;
auto a_ptr = a.data<T>();
auto eig_ptr = values.data<OT>();

auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(values);
OT* vec_ptr = nullptr;
if (compute_eigenvectors) {
encoder.set_output_array(vectors);
vec_ptr = vectors.data<OT>();
}
encoder.dispatch([a_ptr,
vec_ptr,
eig_ptr,
compute_eigenvectors,
N = vectors.shape(-1),
size = vectors.size()]() mutable {
// Work query
char jobr = 'N';
char jobl = compute_eigenvectors ? 'V' : 'N';
int n_vecs_r = 1;
int n_vecs_l = compute_eigenvectors ? N : 1;
int lwork = -1;
int info;
{
T work;
int iwork;
geev<T>(
&jobl,
&jobr,
&N,
nullptr,
&N,
nullptr,
nullptr,
nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
&work,
&lwork,
&info);
lwork = static_cast<int>(work);
}

auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)};
auto vec_tmp_data =
array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)};
auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
for (size_t i = 0; i < size / (N * N); ++i) {
geev<T>(
&jobl,
&jobr,
&N,
a_ptr,
&N,
eig_tmp,
eig_tmp + N,
vec_tmp,
&n_vecs_l,
nullptr,
&n_vecs_r,
static_cast<T*>(work_buf.buffer.raw_ptr()),
&lwork,
&info);
for (int i = 0; i < N; ++i) {
eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]};
}
if (vec_ptr) {
for (int i = 0; i < N; ++i) {
if (eig_ptr[i].imag() != 0) {
// This vector and the next are a pair
for (int j = 0; j < N; ++j) {
vec_ptr[i * N + j] = {
vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]};
vec_ptr[(i + 1) * N + j] = {
vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]};
}
i += 1;
} else {
for (int j = 0; j < N; ++j) {
vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0};
}
}
}
vec_ptr += N * N;
}
a_ptr += N * N;
eig_ptr += N;
if (info != 0) {
std::stringstream msg;
msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
}
});
encoder.add_temporary(a);
}

} // namespace

void Eig::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
const auto& a = inputs[0];
auto& values = outputs[0];

auto vectors = compute_eigenvectors_
? outputs[1]
: array(a.shape(), complex64, nullptr, {});

auto a_copy = array(a.shape(), a.dtype(), nullptr, {});
copy(
a,
a_copy,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream());

values.set_data(allocator::malloc(values.nbytes()));

if (compute_eigenvectors_) {
// Set the strides and flags so the eigenvectors
// are in the columns of the output
auto flags = vectors.flags();
auto strides = vectors.strides();
auto ndim = a.ndim();
std::swap(strides[ndim - 1], strides[ndim - 2]);

if (a.size() > 1) {
flags.row_contiguous = false;
if (ndim > 2) {
flags.col_contiguous = false;
} else {
flags.col_contiguous = true;
}
}
vectors.set_data(
allocator::malloc(vectors.nbytes()), vectors.size(), strides, flags);
}
switch (a.dtype()) {
case float32:
eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
break;
default:
throw std::runtime_error("[Eig::eval_cpu] only supports float32.");
}
}

} // namespace mlx::core
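Editor's note: LAPACK geev returns eigenvalues as split real/imaginary arrays (here the first and second halves of eig_tmp), and the loop above merges them into std::complex values. That packing convention in isolation, as a standalone sketch independent of the MLX wrappers:

    #include <complex>
    #include <vector>

    // Merge geev's split output: wr holds the real parts, wi the imaginary parts.
    std::vector<std::complex<float>> merge_eigenvalues(
        const float* wr, const float* wi, int N) {
      std::vector<std::complex<float>> out(N);
      for (int i = 0; i < N; ++i) {
        out[i] = {wr[i], wi[i]};
      }
      return out;
    }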
@@ -12,6 +12,133 @@ namespace mlx::core {

namespace {

+template <typename T, class Enable = void>
+struct EighWork {};
+
+template <typename T>
+struct EighWork<
+    T,
+    typename std::enable_if<std::is_floating_point<T>::value>::type> {
+using R = T;
+
+char jobz;
+char uplo;
+int N;
+int lwork;
+int liwork;
+int info;
+std::vector<array::Data> buffers;
+
+EighWork(char jobz_, char uplo_, int N_)
+    : jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), liwork(-1) {
+T work;
+int iwork;
+syevd<T>(
+    &jobz,
+    &uplo,
+    &N,
+    nullptr,
+    &N,
+    nullptr,
+    &work,
+    &lwork,
+    &iwork,
+    &liwork,
+    &info);
+lwork = static_cast<int>(work);
+liwork = iwork;
+buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
+buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
+}
+
+void run(T* vectors, T* values) {
+syevd<T>(
+    &jobz,
+    &uplo,
+    &N,
+    vectors,
+    &N,
+    values,
+    static_cast<T*>(buffers[0].buffer.raw_ptr()),
+    &lwork,
+    static_cast<int*>(buffers[1].buffer.raw_ptr()),
+    &liwork,
+    &info);
+}
+};
+
+template <>
+struct EighWork<std::complex<float>> {
+using T = std::complex<float>;
+using R = float;
+
+char jobz;
+char uplo;
+int N;
+int lwork;
+int lrwork;
+int liwork;
+int info;
+std::vector<array::Data> buffers;
+
+EighWork(char jobz_, char uplo_, int N_)
+    : jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), lrwork(-1), liwork(-1) {
+T work;
+R rwork;
+int iwork;
+heevd<T>(
+    &jobz,
+    &uplo,
+    &N,
+    nullptr,
+    &N,
+    nullptr,
+    &work,
+    &lwork,
+    &rwork,
+    &lrwork,
+    &iwork,
+    &liwork,
+    &info);
+lwork = static_cast<int>(work.real());
+lrwork = static_cast<int>(rwork);
+liwork = iwork;
+buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
+buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
+buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
+}
+
+void run(T* vectors, R* values) {
+heevd<T>(
+    &jobz,
+    &uplo,
+    &N,
+    vectors,
+    &N,
+    values,
+    static_cast<T*>(buffers[0].buffer.raw_ptr()),
+    &lwork,
+    static_cast<R*>(buffers[1].buffer.raw_ptr()),
+    &lrwork,
+    static_cast<int*>(buffers[2].buffer.raw_ptr()),
+    &liwork,
+    &info);
+if (jobz == 'V') {
+// We have pre-transposed the vectors but we also must conjugate them
+// when they are complex.
+//
+// We could vectorize this but it is so fast in comparison to heevd that
+// it doesn't really matter.
+for (int i = 0; i < N; i++) {
+for (int j = 0; j < N; j++) {
+*vectors = std::conj(*vectors);
+vectors++;
+}
+}
+}
+}
+};
+
template <typename T>
void eigh_impl(
array& vectors,
@@ -19,8 +146,10 @@ void eigh_impl(
const std::string& uplo,
bool compute_eigenvectors,
Stream stream) {
+using R = typename EighWork<T>::R;
+
auto vec_ptr = vectors.data<T>();
-auto eig_ptr = values.data<T>();
+auto eig_ptr = values.data<R>();
char jobz = compute_eigenvectors ? 'V' : 'N';

auto& encoder = cpu::get_command_encoder(stream);
@@ -33,49 +162,17 @@ void eigh_impl(
N = vectors.shape(-1),
size = vectors.size()]() mutable {
// Work query
-int lwork = -1;
-int liwork = -1;
-int info;
-{
-T work;
-int iwork;
-syevd<T>(
-    &jobz,
-    &uplo,
-    &N,
-    nullptr,
-    &N,
-    nullptr,
-    &work,
-    &lwork,
-    &iwork,
-    &liwork,
-    &info);
-lwork = static_cast<int>(work);
-liwork = iwork;
-}
+EighWork<T> work(jobz, uplo, N);

-auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
-auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)};
+// Work loop
for (size_t i = 0; i < size / (N * N); ++i) {
-syevd<T>(
-    &jobz,
-    &uplo,
-    &N,
-    vec_ptr,
-    &N,
-    eig_ptr,
-    static_cast<T*>(work_buf.buffer.raw_ptr()),
-    &lwork,
-    static_cast<int*>(iwork_buf.buffer.raw_ptr()),
-    &liwork,
-    &info);
+work.run(vec_ptr, eig_ptr);
vec_ptr += N * N;
eig_ptr += N;
-if (info != 0) {
+if (work.info != 0) {
std::stringstream msg;
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
-    << info;
+    << work.info;
throw std::runtime_error(msg.str());
}
}
@@ -131,6 +228,10 @@ void Eigh::eval_cpu(
eigh_impl<double>(
vectors, values, uplo_, compute_eigenvectors_, stream());
break;
+case complex64:
+eigh_impl<std::complex<float>>(
+    vectors, values, uplo_, compute_eigenvectors_, stream());
+break;
default:
throw std::runtime_error(
"[Eigh::eval_cpu] only supports float32 or float64.");
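Editor's note: EighWork wraps LAPACK's two-phase workspace query, calling the routine once with lwork == -1 to learn the optimal buffer sizes and then reusing those buffers for every matrix in the batch. The pattern in isolation, with the routine abstracted as a callable (a sketch, not the MLX API):

    #include <vector>

    // `routine(work, lwork)` writes the optimal size into *work when *lwork == -1.
    template <typename F>
    std::vector<float> query_and_allocate(F&& routine) {
      float probe = 0.0f;
      int lwork = -1;
      routine(&probe, &lwork);           // size query only, no factorization
      lwork = static_cast<int>(probe);
      return std::vector<float>(lwork);  // reused for every matrix in the batch
    }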
@@ -257,15 +257,11 @@ void gather_axis(
const array& ind,
array& out,
const int axis) {
-auto strides = ind.strides();
-strides.erase(strides.begin() + axis);
-auto shape = ind.shape();
-shape.erase(shape.begin() + axis);
-ContiguousIterator ind_it(shape, strides, src.ndim() - 1);
-
-strides = src.strides();
-strides.erase(strides.begin() + axis);
-ContiguousIterator src_it(shape, strides, src.ndim() - 1);
+auto shape = remove_index(ind.shape(), axis);
+ContiguousIterator ind_it(
+    shape, remove_index(ind.strides(), axis), src.ndim() - 1);
+ContiguousIterator src_it(
+    shape, remove_index(src.strides(), axis), src.ndim() - 1);

auto ind_ptr = ind.data<IdxT>();
auto src_ptr = src.data<T>();
@@ -585,15 +581,11 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {

template <typename T, typename IdxT, typename OpT>
void scatter_axis(array& out, const array idx, const array& upd, int axis) {
-auto strides = idx.strides();
-strides.erase(strides.begin() + axis);
-auto shape = idx.shape();
-shape.erase(shape.begin() + axis);
-ContiguousIterator idx_it(shape, strides, upd.ndim() - 1);
-
-strides = upd.strides();
-strides.erase(strides.begin() + axis);
-ContiguousIterator upd_it(shape, strides, upd.ndim() - 1);
+auto shape = remove_index(idx.shape(), axis);
+ContiguousIterator idx_it(
+    shape, remove_index(idx.strides(), axis), upd.ndim() - 1);
+ContiguousIterator upd_it(
+    shape, remove_index(upd.strides(), axis), upd.ndim() - 1);

auto idx_ptr = idx.data<IdxT>();
auto upd_ptr = upd.data<T>();
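Editor's note: both loops now build their iterators from a shared remove_index helper instead of erasing the axis in place twice. A plausible implementation of such a helper, assumed rather than copied from the MLX sources:

    #include <vector>

    // Return a copy of `v` with the entry at `axis` dropped.
    template <typename T>
    std::vector<T> remove_index(std::vector<T> v, int axis) {
      v.erase(v.begin() + axis);
      return v;
    }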
@@ -2,14 +2,14 @@

#pragma once

-// Required for Visual Studio.
-// https://github.com/OpenMathLib/OpenBLAS/blob/develop/docs/install.md
-#ifdef _MSC_VER
#include <complex>
#define LAPACK_COMPLEX_CUSTOM
#define lapack_complex_float std::complex<float>
#define lapack_complex_double std::complex<double>
-#endif
+#define lapack_complex_float_real(z) ((z).real())
+#define lapack_complex_float_imag(z) ((z).imag())
+#define lapack_complex_double_real(z) ((z).real())
+#define lapack_complex_double_imag(z) ((z).imag())

#ifdef MLX_USE_ACCELERATE
#include <Accelerate/Accelerate.h>
@@ -32,7 +32,7 @@

#endif

-#define INSTANTIATE_LAPACK_TYPES(FUNC) \
+#define INSTANTIATE_LAPACK_REAL(FUNC) \
template <typename T, typename... Args> \
void FUNC(Args... args) { \
if constexpr (std::is_same_v<T, float>) { \
@@ -42,11 +42,24 @@
} \
}

-INSTANTIATE_LAPACK_TYPES(geqrf)
-INSTANTIATE_LAPACK_TYPES(orgqr)
-INSTANTIATE_LAPACK_TYPES(syevd)
-INSTANTIATE_LAPACK_TYPES(potrf)
-INSTANTIATE_LAPACK_TYPES(gesvdx)
-INSTANTIATE_LAPACK_TYPES(getrf)
-INSTANTIATE_LAPACK_TYPES(getri)
-INSTANTIATE_LAPACK_TYPES(trtri)
+INSTANTIATE_LAPACK_REAL(geqrf)
+INSTANTIATE_LAPACK_REAL(orgqr)
+INSTANTIATE_LAPACK_REAL(syevd)
+INSTANTIATE_LAPACK_REAL(geev)
+INSTANTIATE_LAPACK_REAL(potrf)
+INSTANTIATE_LAPACK_REAL(gesvdx)
+INSTANTIATE_LAPACK_REAL(getrf)
+INSTANTIATE_LAPACK_REAL(getri)
+INSTANTIATE_LAPACK_REAL(trtri)
+
+#define INSTANTIATE_LAPACK_COMPLEX(FUNC) \
+template <typename T, typename... Args> \
+void FUNC(Args... args) { \
+if constexpr (std::is_same_v<T, std::complex<float>>) { \
+MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...); \
+} else if constexpr (std::is_same_v<T, std::complex<double>>) { \
+MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...); \
+} \
+}
+
+INSTANTIATE_LAPACK_COMPLEX(heevd)
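Editor's note: the complex macro dispatches on std::complex<float> versus std::complex<double> and prefixes the LAPACK routine name with c or z. Roughly, INSTANTIATE_LAPACK_COMPLEX(heevd) expands to something like the following sketch; the real macro routes through MLX_LAPACK_FUNC, and the underscore-suffixed symbol names here are an assumption about how the LAPACK routines are declared:

    #include <complex>
    #include <type_traits>
    #include <utility>

    extern "C" void cheevd_(...); // assumed LAPACK declarations
    extern "C" void zheevd_(...);

    template <typename T, typename... Args>
    void heevd(Args... args) {
      if constexpr (std::is_same_v<T, std::complex<float>>) {
        cheevd_(std::forward<Args>(args)...);   // single-precision complex
      } else if constexpr (std::is_same_v<T, std::complex<double>>) {
        zheevd_(std::forward<Args>(args)...);   // double-precision complex
      }
    }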
@@ -132,6 +132,10 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error(
"[AddMM::eval_cpu] Currently only supports float32.");
}
+if (out.size() == 0) {
+out.set_data(allocator::malloc(out.nbytes()));
+return;
+}

// Fill output with C
auto& c = inputs[2];
@@ -139,7 +143,9 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
? CopyType::Scalar
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy(c, out, ctype, stream());
+if (inputs[0].shape(-1) == 0) {
+return;
+}
matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
}

@@ -13,9 +13,18 @@ namespace mlx::core {

namespace {

+inline constexpr short get_pack_factor(int bits, int wsize = 8) {
+return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
+}
+
+inline constexpr short get_bytes_per_pack(int bits, int wsize = 8) {
+auto power_of_2_bits = (bits & (bits - 1)) == 0;
+return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
+}
+
template <typename T, int bits>
void extract_bits(const uint8_t* w_in, T* w_out) {
-assert(bits == 3 || bits == 6);
+static_assert(bits == 3 || bits == 5 || bits == 6);
if (bits == 3) {
w_out[0] = static_cast<T>(w_in[0] & 0x7);
w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
@@ -25,6 +34,16 @@ void extract_bits(const uint8_t* w_in, T* w_out) {
w_out[5] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1));
w_out[6] = static_cast<T>((w_in[2] & 0x1c) >> 2);
w_out[7] = static_cast<T>((w_in[2] & 0xe0) >> 5);
+} else if (bits == 5) {
+w_out[0] = static_cast<T>(w_in[0] & 0x1f);
+w_out[1] = static_cast<T>(((w_in[0] & 0xe0) >> 5) + ((w_in[1] & 0x3) << 3));
+w_out[2] = static_cast<T>((w_in[1] & 0x7c) >> 2);
+w_out[3] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0xf) << 1));
+w_out[4] = static_cast<T>(((w_in[2] & 0xf0) >> 4) + ((w_in[3] & 0x1) << 4));
+w_out[5] = static_cast<T>((w_in[3] & 0x3e) >> 1);
+w_out[6] = static_cast<T>(((w_in[3] & 0xc0) >> 6) + ((w_in[4] & 0x7) << 2));
+w_out[7] = static_cast<T>((w_in[4] & 0xf8) >> 3);
+
} else if (bits == 6) {
w_out[0] = static_cast<T>(w_in[0] & 0x3f);
w_out[1] =
@@ -46,8 +65,8 @@ void _qmm(
int N,
int K) {
constexpr int bitmask = (1 << bits) - 1;
-constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
-constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
+constexpr int pack_factor = get_pack_factor(bits, 8);
+constexpr int bytes_per_pack = get_bytes_per_pack(bits);
constexpr int packs_in_group = group_size / pack_factor;

for (int m = 0; m < M; m++) {
@@ -65,7 +84,7 @@ void _qmm(
T scale = *scales_local++;
T bias = *biases_local++;
for (int ng = 0; ng < packs_in_group; ng++) {
-if (bits == 3 || bits == 6) {
+if constexpr (bits == 3 || bits == 5 || bits == 6) {
T wl[pack_factor];
extract_bits<T, bits>(w_local, wl);
#pragma clang loop unroll(full)
@@ -104,8 +123,9 @@ void _qmm_t(
int N,
int K) {
constexpr int bitmask = (1 << bits) - 1;
-constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
-constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
+constexpr int pack_factor = get_pack_factor(bits, 8);
+constexpr int bytes_per_pack = get_bytes_per_pack(bits);
constexpr int packs_in_group = group_size / pack_factor;

for (int m = 0; m < M; m++) {
@@ -121,7 +141,7 @@ void _qmm_t(
T bias = *biases_local++;

for (int kw = 0; kw < packs_in_group; kw++) {
-if (bits == 3 || bits == 6) {
+if constexpr (bits == 3 || bits == 5 || bits == 6) {
T wl[pack_factor];
extract_bits<T, bits>(w_local, wl);
#pragma clang loop unroll(full)
@@ -304,6 +324,10 @@ void _qmm_dispatch_typed(
_qmm_dispatch_group<T, 4>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
+case 5:
+_qmm_dispatch_group<T, 5>(
+    result, x, w, scales, biases, M, N, K, group_size, transposed_w);
+break;
case 6:
_qmm_dispatch_group<T, 6>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
@@ -613,9 +637,8 @@ void quantize(
float eps = 1e-7;

bool power_of_2_bits = is_power_of_2(bits);
-int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
-// For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
-int bytes_per_pack = power_of_2_bits ? 1 : 3;
+int el_per_int = get_pack_factor(bits, 32);
+int bytes_per_pack = get_bytes_per_pack(bits);
int int_per_group = group_size * bytes_per_pack / el_per_int;
size_t n_groups = w_size / group_size;

@@ -640,15 +663,21 @@ void quantize(
}
size_t out_idx = i * int_per_group;
for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
-uint32_t out_el = 0;
+uint64_t out_el = 0;
for (int k = 0; k < el_per_int; ++k) {
float w_el = w[w_idx + j * el_per_int + k];
w_el = std::rint((w_el - bias) / scale);
w_el = std::min(std::max(w_el, 0.0f), n_bins);
-out_el |= static_cast<uint32_t>(w_el) << (k * bits);
+out_el |= static_cast<uint64_t>(w_el) << (k * bits);
}
if (power_of_2_bits) {
out[out_idx + j] = out_el;
+} else if (bits == 5) {
+out[out_idx + bytes_per_pack * j] = out_el & 0xff;
+out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
+out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
+out[out_idx + bytes_per_pack * j + 3] = (out_el & 0xff000000) >> 24;
+out[out_idx + bytes_per_pack * j + 4] = (out_el & 0xff00000000) >> 32;
} else {
out[out_idx + bytes_per_pack * j] = out_el & 0xff;
out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
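Editor's note: with 5-bit quantization, eight 5-bit values occupy 40 bits, so each pack of eight elements is written as five bytes; that is exactly what get_pack_factor and get_bytes_per_pack encode, and why out_el grows to uint64_t. A standalone round trip of that packing, independent of the MLX kernels:

    #include <cassert>
    #include <cstdint>

    int main() {
      // Pack eight 5-bit values into a 40-bit word, then split into 5 bytes.
      uint8_t vals[8] = {1, 31, 7, 0, 16, 9, 30, 5};
      uint64_t packed = 0;
      for (int k = 0; k < 8; ++k) {
        packed |= static_cast<uint64_t>(vals[k] & 0x1f) << (k * 5);
      }
      uint8_t bytes[5];
      for (int b = 0; b < 5; ++b) {
        bytes[b] = (packed >> (b * 8)) & 0xff;
      }
      // Unpack and check the round trip.
      uint64_t repacked = 0;
      for (int b = 0; b < 5; ++b) {
        repacked |= static_cast<uint64_t>(bytes[b]) << (b * 8);
      }
      for (int k = 0; k < 8; ++k) {
        assert(((repacked >> (k * 5)) & 0x1f) == vals[k]);
      }
      return 0;
    }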
@@ -2,32 +2,13 @@

#pragma once

-#include "mlx/allocator.h"
-#include "mlx/array.h"
-#include "mlx/backend/common/utils.h"
+#include "mlx/backend/common/unary.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/utils.h"

namespace mlx::core {

-void set_unary_output_data(const array& in, array& out) {
-if (in.flags().contiguous) {
-if (is_donatable(in, out)) {
-out.copy_shared_buffer(in);
-} else {
-auto size = in.data_size();
-out.set_data(
-    allocator::malloc(size * out.itemsize()),
-    size,
-    in.strides(),
-    in.flags());
-}
-} else {
-out.set_data(allocator::malloc(out.nbytes()));
-}
-}
-
template <typename T, typename U = T, typename Op>
void unary_op(const T* a, U* out, size_t shape, size_t stride) {
for (size_t i = 0; i < shape; i += 1) {
@@ -6,26 +6,46 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
-        ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
+        ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
+        ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
+        ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
+        ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
+        ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
+        ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
+        ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
        ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
        ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
        ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
-        ${CMAKE_CURRENT_SOURCE_DIR}/fence.cu
+        ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
+        ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
+        ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
        ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
+        ${CMAKE_CURRENT_SOURCE_DIR}/random.cu
        ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
+        ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
+        ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
        ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
        ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)

-target_compile_definitions(mlx PUBLIC MLX_USE_CUDA)
+target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)

# Enable defining device lambda functions.
target_compile_options(mlx
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")

+# CUDA 12.8 emits warning #20280-D for copy kernels which is a false positive.
+# Explicitly pass this flag to suppress the warning, it is safe to set it to
+# true but the warning wouldn't be suppressed.
+if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
+target_compile_options(
+    mlx
+    PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--static-global-template-stub=false>")
+endif()
+
# Compute capability 7 is required for synchronization between CPU/GPU with
# managed memory. TODO: Add more architectures for potential performance gain.
set(MLX_CUDA_ARCHITECTURES
-    "75;80"
+    "70;80"
    CACHE STRING "CUDA architectures")
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
@@ -36,7 +56,7 @@ FetchContent_Declare(
cccl
URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip")
FetchContent_MakeAvailable(cccl)
-target_include_directories(mlx PRIVATE BEFORE "${cccl_SOURCE_DIR}/include")
+target_include_directories(mlx BEFORE PRIVATE "${cccl_SOURCE_DIR}/include")

# Use fixed version of NVTX.
FetchContent_Declare(
@@ -52,6 +72,9 @@ target_link_libraries(mlx PUBLIC $<BUILD_INTERFACE:nvtx3-cpp>)
find_package(CUDAToolkit REQUIRED)
target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})

+# Use cublasLt.
+target_link_libraries(mlx PRIVATE CUDA::cublasLt)
+
# Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>)
@@ -6,6 +6,7 @@

#include <cuda_runtime.h>
#include <fmt/format.h>
+#include <unistd.h>

#include <cassert>

@@ -13,24 +14,50 @@ namespace mlx::core {

namespace cu {

-CudaAllocator::CudaAllocator() {
+CudaAllocator::CudaAllocator()
+    : buffer_cache_(
+        getpagesize(),
+        [](CudaBuffer* buf) { return buf->size; },
+        [this](CudaBuffer* buf) {
+          cuda_free(buf->data);
+          delete buf;
+        }) {
// TODO: Set memory limit for multi-device.
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.8;
+max_pool_size_ = memory_limit_;
}

Buffer CudaAllocator::malloc(size_t size) {
-// TODO: Check memory limit.
-auto* buf = new CudaBuffer{nullptr, size};
-cudaError_t err = cudaMallocManaged(&buf->data, size);
-if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
-throw std::runtime_error(
-    fmt::format("cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
+// Find available buffer from cache.
+std::unique_lock lock(mutex_);
+CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
+if (!buf) {
+// If we have a lot of memory pressure or are over the maximum cache size,
+// try to reclaim memory from the cache.
+size_t mem_required = get_active_memory() + get_cache_memory() + size;
+if (mem_required >= memory_limit_) {
+buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
+}
+
+lock.unlock();
+buf = new CudaBuffer{nullptr, size};
+cudaError_t err = cudaMallocManaged(&buf->data, size);
+if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
+throw std::runtime_error(fmt::format(
+    "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
+}
+lock.lock();
}
-std::lock_guard lock(mutex_);
active_memory_ += size;
peak_memory_ = std::max(active_memory_, peak_memory_);

+// Maintain the cache below the requested limit.
+if (get_cache_memory() > max_pool_size_) {
+buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
+}
+
return Buffer{buf};
}

@@ -40,26 +67,15 @@ void CudaAllocator::free(Buffer buffer) {
return;
}

-// If free() is called from a unregistered thread, reschedule the call to
-// worker.
-{
-std::lock_guard lock(worker_mutex_);
-if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
-if (!worker_) {
-worker_.reset(new Worker);
-}
-worker_->add_task([buffer]() { allocator().free(buffer); });
-worker_->end_batch();
-worker_->commit();
-return;
-}
+std::unique_lock lock(mutex_);
+active_memory_ -= buf->size;
+if (get_cache_memory() < max_pool_size_) {
+buffer_cache_.recycle_to_cache(buf);
+} else {
+lock.unlock();
+cuda_free(buf->data);
+delete buf;
}
-
-size_t size = buf->size;
-cudaFree(buf->data);
-delete buf;
-std::lock_guard lock(mutex_);
-active_memory_ -= size;
}

size_t CudaAllocator::size(Buffer buffer) const {
@@ -75,6 +91,25 @@ void CudaAllocator::register_this_thread() {
allowed_threads_.insert(std::this_thread::get_id());
}

+void CudaAllocator::cuda_free(void* buf) {
+// If cuda_free() is called from a unregistered thread, reschedule the call to
+// worker.
+{
+std::lock_guard lock(worker_mutex_);
+if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
+if (!worker_) {
+worker_.reset(new Worker);
+}
+worker_->add_task([this, buf]() { this->cuda_free(buf); });
+worker_->end_batch();
+worker_->commit();
+return;
+}
+}
+
+cudaFree(buf);
+}
+
size_t CudaAllocator::get_active_memory() const {
return active_memory_;
}
@@ -98,6 +133,21 @@ size_t CudaAllocator::set_memory_limit(size_t limit) {
return limit;
}

+size_t CudaAllocator::get_cache_memory() const {
+return buffer_cache_.cache_size();
+}
+
+size_t CudaAllocator::set_cache_limit(size_t limit) {
+std::lock_guard lk(mutex_);
+std::swap(limit, max_pool_size_);
+return limit;
+}
+
+void CudaAllocator::clear_cache() {
+std::lock_guard lk(mutex_);
+buffer_cache_.clear();
+}
+
CudaAllocator& allocator() {
// By creating the |allocator_| on heap, the destructor of CudaAllocator
// will not be called on exit and buffers in the cache will be leaked. This
@@ -138,17 +188,19 @@ size_t set_memory_limit(size_t limit) {
size_t get_memory_limit() {
return cu::allocator().get_memory_limit();
}

-// TODO: Implement buffer cache.
size_t get_cache_memory() {
-return 0;
+return cu::allocator().get_cache_memory();
}
-size_t set_cache_limit(size_t) {
-return 0;
+size_t set_cache_limit(size_t limit) {
+return cu::allocator().set_cache_limit(limit);
}
+void clear_cache() {
+cu::allocator().clear_cache();
+}

+// Not supported in CUDA.
size_t set_wired_limit(size_t) {
return 0;
}
-void clear_cache() {}

} // namespace mlx::core
@@ -3,6 +3,7 @@
#pragma once

#include "mlx/allocator.h"
+#include "mlx/backend/common/buffer_cache.h"

#include <mutex>
#include <set>
@@ -33,11 +34,17 @@ class CudaAllocator : public allocator::Allocator {
// buffers there would result in dead lock.
void register_this_thread();

+// Call cudaFree in the safe thread.
+void cuda_free(void* buf);
+
size_t get_active_memory() const;
size_t get_peak_memory() const;
void reset_peak_memory();
size_t get_memory_limit();
size_t set_memory_limit(size_t limit);
+size_t get_cache_memory() const;
+size_t set_cache_limit(size_t limit);
+void clear_cache();

private:
CudaAllocator();
@@ -49,6 +56,8 @@ class CudaAllocator : public allocator::Allocator {

std::mutex mutex_;
size_t memory_limit_;
+size_t max_pool_size_;
+BufferCache<CudaBuffer> buffer_cache_;
size_t active_memory_{0};
size_t peak_memory_{0};
};
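Editor's note: the CUDA allocator now fronts cudaMallocManaged with the shared BufferCache: frees are recycled into the cache while it is under max_pool_size_, and malloc tries to reuse a cached buffer before allocating fresh memory. A size-keyed cache in miniature, purely illustrative and much simpler than the real BufferCache (no page rounding, no LRU eviction):

    #include <cstdlib>
    #include <map>

    struct TinyCache {
      std::multimap<size_t, void*> free_list;

      void* acquire(size_t size) {
        auto it = free_list.find(size);
        if (it != free_list.end()) {
          void* p = it->second;
          free_list.erase(it);
          return p;                 // cache hit: no new allocation
        }
        return std::malloc(size);   // cache miss: allocate fresh
      }

      void release(size_t size, void* p) {
        free_list.emplace(size, p); // recycle instead of freeing immediately
      }
    };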
305
mlx/backend/cuda/binary.cu
Normal file
305
mlx/backend/cuda/binary.cu
Normal file
@@ -0,0 +1,305 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/common/binary.h"
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/backend/cuda/kernels/binary_ops.cuh"
|
||||||
|
#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
template <typename Op, typename In, typename Out, typename IdxT>
|
||||||
|
__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
|
||||||
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
|
if (index < size) {
|
||||||
|
out[index] = Op{}(a[0], b[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op, typename In, typename Out, typename IdxT>
|
||||||
|
__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
|
||||||
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
|
if (index < size) {
|
||||||
|
out[index] = Op{}(a[0], b[index]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op, typename In, typename Out, typename IdxT>
|
||||||
|
__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
|
||||||
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
|
if (index < size) {
|
||||||
|
out[index] = Op{}(a[index], b[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op, typename In, typename Out, typename IdxT>
|
||||||
|
__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
|
||||||
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
|
if (index < size) {
|
||||||
|
out[index] = Op{}(a[index], b[index]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
|
||||||
|
__global__ void binary_g_nd(
|
||||||
|
const In* a,
|
||||||
|
const In* b,
|
||||||
|
Out* out,
|
||||||
|
IdxT size,
|
||||||
|
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||||
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
|
||||||
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
|
||||||
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
|
if (index < size) {
|
||||||
|
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
|
||||||
|
index, shape.data(), a_strides.data(), b_strides.data());
|
||||||
|
out[index] = Op{}(a[a_idx], b[b_idx]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op, typename In, typename Out, typename IdxT>
|
||||||
|
__global__ void binary_g(
|
||||||
|
const In* a,
|
||||||
|
const In* b,
|
||||||
|
Out* out,
|
||||||
|
IdxT size,
|
||||||
|
const __grid_constant__ Shape shape,
|
||||||
|
const __grid_constant__ Strides a_strides,
|
||||||
|
const __grid_constant__ Strides b_strides,
|
||||||
|
int ndim) {
|
||||||
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
|
if (index < size) {
|
||||||
|
auto [a_idx, b_idx] = elem_to_loc_4d(
|
||||||
|
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
|
||||||
|
out[index] = Op{}(a[a_idx], b[b_idx]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op, typename In, typename Out>
|
||||||
|
constexpr bool supports_binary_op() {
|
||||||
|
if (std::is_same_v<Op, Add> || std::is_same_v<Op, Divide> ||
|
||||||
|
std::is_same_v<Op, Maximum> || std::is_same_v<Op, Minimum> ||
|
||||||
|
std::is_same_v<Op, Multiply> || std::is_same_v<Op, Subtract> ||
|
||||||
|
std::is_same_v<Op, Power> || std::is_same_v<Op, Remainder>) {
|
||||||
|
return std::is_same_v<In, Out>;
|
||||||
|
}
|
||||||
|
if (std::is_same_v<Op, Equal> || std::is_same_v<Op, Greater> ||
|
||||||
|
std::is_same_v<Op, GreaterEqual> || std::is_same_v<Op, Less> ||
|
||||||
|
std::is_same_v<Op, LessEqual> || std::is_same_v<Op, NotEqual>) {
|
||||||
|
return std::is_same_v<Out, bool>;
|
||||||
|
}
|
||||||
|
if (std::is_same_v<Op, LogicalAnd> || std::is_same_v<Op, LogicalOr>) {
|
||||||
|
return std::is_same_v<Out, bool> && std::is_same_v<In, bool>;
|
||||||
|
}
|
||||||
|
if (std::is_same_v<Op, NaNEqual>) {
|
||||||
|
return std::is_same_v<Out, bool> &&
|
||||||
|
(is_floating_v<In> || std::is_same_v<In, complex64_t>);
|
||||||
|
}
|
||||||
|
if (std::is_same_v<Op, LogAddExp> || std::is_same_v<Op, ArcTan2>) {
|
||||||
|
return std::is_same_v<In, Out> && is_floating_v<In>;
|
||||||
|
}
|
||||||
|
if (std::is_same_v<Op, BitwiseAnd> || std::is_same_v<Op, BitwiseOr> ||
|
||||||
|
std::is_same_v<Op, BitwiseXor>) {
|
||||||
|
return std::is_same_v<In, Out> && std::is_integral_v<In>;
|
||||||
|
}
|
||||||
|
if (std::is_same_v<Op, LeftShift> || std::is_same_v<Op, RightShift>) {
|
||||||
|
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
|
||||||
|
!std::is_same_v<In, bool>;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
void binary_op_gpu_inplace(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs,
|
||||||
|
std::string_view op,
|
||||||
|
const Stream& s) {
|
||||||
|
assert(inputs.size() > 1);
|
||||||
|
const auto& a = inputs[0];
|
||||||
|
const auto& b = inputs[1];
|
||||||
|
auto& out = outputs[0];
|
||||||
|
if (out.size() == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
encoder.set_input_array(a);
|
||||||
  encoder.set_input_array(b);
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
      MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
        if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
          using InType = cuda_type_t<CTYPE_IN>;
          using OutType = cuda_type_t<CTYPE_OUT>;

          auto bopt = get_binary_op_type(a, b);
          if (bopt == BinaryOpType::General) {
            auto [shape, strides] = collapse_contiguous_dims(a, b, out);
            auto& a_strides = strides[0];
            auto& b_strides = strides[1];
            bool large = a.data_size() > UINT32_MAX ||
                b.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
            MLX_SWITCH_BOOL(large, LARGE, {
              using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
              int ndim = shape.size();
              if (ndim <= 3) {
                MLX_SWITCH_1_2_3(ndim, NDIM, {
                  auto kernel = &cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
                  auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
                  kernel<<<num_blocks, block_dims, 0, stream>>>(
                      a.data<InType>(), b.data<InType>(), out.data<OutType>(),
                      out.data_size(), const_param<NDIM>(shape),
                      const_param<NDIM>(a_strides), const_param<NDIM>(b_strides));
                });
              } else {
                auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
                auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
                kernel<<<num_blocks, block_dims, 0, stream>>>(
                    a.data<InType>(), b.data<InType>(), out.data<OutType>(),
                    out.data_size(), const_param(shape), const_param(a_strides),
                    const_param(b_strides), ndim);
              }
            });
          } else {
            MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
              using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
              auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
              if (bopt == BinaryOpType::ScalarVector) {
                kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
              } else if (bopt == BinaryOpType::VectorScalar) {
                kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
              } else if (bopt == BinaryOpType::VectorVector) {
                kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
              }
              auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
              kernel<<<num_blocks, block_dims, 0, stream>>>(
                  a.data<InType>(), b.data<InType>(), out.data<OutType>(),
                  out.data_size());
            });
          }
        } else {
          throw std::runtime_error(fmt::format(
              "Can not do binary op {} on inputs of {} with result of {}.",
              op, dtype_to_string(a.dtype()), dtype_to_string(out.dtype())));
        }
      });
    });
  });
}

template <typename Op>
void binary_op_gpu(
    const std::vector<array>& inputs, std::vector<array>& outputs,
    std::string_view op, const Stream& s) {
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, outputs[0], bopt);
  set_binary_op_output_data(a, b, outputs[1], bopt);
  binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
}

template <typename Op>
void binary_op_gpu(
    const std::vector<array>& inputs, array& out, std::string_view op,
    const Stream& s) {
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out, bopt);
  std::vector<array> outputs{out};
  binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
}

#define BINARY_GPU(func)                                                   \
  void func::eval_gpu(const std::vector<array>& inputs, array& out) {     \
    nvtx3::scoped_range r(#func "::eval_gpu");                            \
    auto& s = out.primitive().stream();                                   \
    binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s);  \
  }

#define BINARY_GPU_MULTI(func)                                                \
  void func::eval_gpu(                                                        \
      const std::vector<array>& inputs, std::vector<array>& outputs) {        \
    nvtx3::scoped_range r(#func "::eval_gpu");                                \
    auto& s = outputs[0].primitive().stream();                                \
    binary_op_gpu<cu::func>(inputs, outputs, get_primitive_string(this), s);  \
  }

BINARY_GPU(Add)
BINARY_GPU(ArcTan2)
BINARY_GPU(Divide)
BINARY_GPU(Remainder)
BINARY_GPU(Equal)
BINARY_GPU(Greater)
BINARY_GPU(GreaterEqual)
BINARY_GPU(Less)
BINARY_GPU(LessEqual)
BINARY_GPU(LogicalAnd)
BINARY_GPU(LogicalOr)
BINARY_GPU(LogAddExp)
BINARY_GPU(Maximum)
BINARY_GPU(Minimum)
BINARY_GPU(Multiply)
BINARY_GPU(NotEqual)
BINARY_GPU(Power)
BINARY_GPU(Subtract)

void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
  auto& s = out.primitive().stream();
  auto op = get_primitive_string(this);
  switch (op_) {
    case BitwiseBinary::And:
      binary_op_gpu<cu::BitwiseAnd>(inputs, out, op, s);
      break;
    case BitwiseBinary::Or:
      binary_op_gpu<cu::BitwiseOr>(inputs, out, op, s);
      break;
    case BitwiseBinary::Xor:
      binary_op_gpu<cu::BitwiseXor>(inputs, out, op, s);
      break;
    case BitwiseBinary::LeftShift:
      binary_op_gpu<cu::LeftShift>(inputs, out, op, s);
      break;
    case BitwiseBinary::RightShift:
      binary_op_gpu<cu::RightShift>(inputs, out, op, s);
      break;
  }
}

} // namespace mlx::core
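For readers tracing the dispatch above, this is what BINARY_GPU(Add) expands to once the preprocessor substitutes func = Add; it is a hand-expanded sketch of the macro, not additional code from the diff:

// Hand-expanded form of BINARY_GPU(Add); equivalent to the macro above.
void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Add::eval_gpu");
  auto& s = out.primitive().stream();
  binary_op_gpu<cu::Add>(inputs, out, get_primitive_string(this), s);
}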
@@ -1,26 +0,0 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/gpu/copy.h"

namespace mlx::core {

void copy_gpu_inplace(
    const array& in,
    array& out,
    const Shape& data_shape,
    const Strides& strides_in_pre,
    const Strides& strides_out_pre,
    int64_t inp_offset,
    int64_t out_offset,
    CopyType ctype,
    const Stream& s,
    const std::optional<array>& dynamic_i_offset /* = std::nullopt */,
    const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
  throw std::runtime_error("copy_gpu_inplace not implemented in CUDA backend.");
}

void fill_gpu(const array& val, array& out, const Stream& s) {
  throw std::runtime_error("fill_gpu not implemented in CUDA backend.");
}

} // namespace mlx::core
mlx/backend/cuda/copy.cu (new file)
@@ -0,0 +1,89 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/copy/copy.cuh"

namespace mlx::core {

void copy_gpu_inplace(
    const array& in_, array& out, const Shape& shape,
    const Strides& strides_in, const Strides& strides_out,
    int64_t offset_in, int64_t offset_out, CopyType ctype, const Stream& s,
    const std::optional<array>& dynamic_offset_in,
    const std::optional<array>& dynamic_offset_out) {
  if (out.size() == 0) {
    return;
  }
  const array& in = in_.data_shared_ptr() ? in_ : out;

  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);

  if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
    copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
    return;
  }

  if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
    auto [shape_collapsed, strides_vec] = collapse_contiguous_dims(
        shape, std::vector{strides_in, strides_out}, INT32_MAX);
    if (ctype == CopyType::General) {
      copy_general_input(
          encoder, ctype, in, out, offset_in, offset_out, shape_collapsed,
          strides_vec[0]);
    } else {
      if (dynamic_offset_in || dynamic_offset_out) {
        copy_general_dynamic(
            encoder, ctype, in, out, offset_in, offset_out, shape_collapsed,
            strides_vec[0], strides_vec[1],
            dynamic_offset_in ? *dynamic_offset_in : array(0, int64),
            dynamic_offset_out ? *dynamic_offset_out : array(0, int64));
      } else {
        copy_general(
            encoder, ctype, in, out, offset_in, offset_out, shape_collapsed,
            strides_vec[0], strides_vec[1]);
      }
    }
    return;
  }
}

void fill_gpu(const array& in, array& out, const Stream& s) {
  if (out.size() == 0) {
    return;
  }
  out.set_data(allocator::malloc(out.nbytes()));
  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
}

} // namespace mlx::core

mlx/backend/cuda/copy/copy.cuh (new file)
@@ -0,0 +1,71 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/cast_op.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"

namespace mlx::core {

#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...)     \
  MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, {                   \
    MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {               \
      using InType = cuda_type_t<CTYPE_IN>;                      \
      using OutType = cuda_type_t<CTYPE_OUT>;                    \
      if constexpr (cu::CastOp<InType, OutType>::is_castable) {  \
        __VA_ARGS__;                                             \
      } else {                                                   \
        throw std::runtime_error(fmt::format(                    \
            "Can not copy data from dtype {} to {}.",            \
            dtype_to_string(out.dtype()),                        \
            dtype_to_string(in.dtype())));                       \
      }                                                          \
    });                                                          \
  })

void copy_contiguous(
    cu::CommandEncoder& encoder, CopyType ctype, const array& in, array& out,
    int64_t offset_in, int64_t offset_out);

void copy_general(
    cu::CommandEncoder& encoder, CopyType ctype, const array& in, array& out,
    int64_t offset_in, int64_t offset_out, const Shape& shape,
    const Strides& strides_in, const Strides& strides_out);

void copy_general_dynamic(
    cu::CommandEncoder& encoder, CopyType ctype, const array& in, array& out,
    int64_t offset_in, int64_t offset_out, const Shape& shape,
    const Strides& strides_in, const Strides& strides_out,
    const array& dynamic_offset_in, const array& dynamic_offset_out);

void copy_general_input(
    cu::CommandEncoder& encoder, CopyType ctype, const array& in, array& out,
    int64_t offset_in, int64_t offset_out, const Shape& shape,
    const Strides& strides_in);

} // namespace mlx::core

mlx/backend/cuda/copy/copy_contiguous.cu (new file)
@@ -0,0 +1,56 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/copy/copy.cuh"

#include <cooperative_groups.h>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

template <typename In, typename Out, typename IdxT>
__global__ void copy_s(const In* in, Out* out, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    out[index] = CastOp<In, Out>{}(in[0]);
  }
}

template <typename In, typename Out, typename IdxT>
__global__ void copy_v(const In* in, Out* out, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    out[index] = CastOp<In, Out>{}(in[index]);
  }
}

} // namespace cu

void copy_contiguous(
    cu::CommandEncoder& encoder, CopyType ctype, const array& in, array& out,
    int64_t in_offset, int64_t out_offset) {
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
      MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
        using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
        auto kernel = cu::copy_s<InType, OutType, IdxT>;
        if (ctype == CopyType::Vector) {
          kernel = cu::copy_v<InType, OutType, IdxT>;
        }
        auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
        kernel<<<num_blocks, block_dims, 0, stream>>>(
            in.data<InType>() + in_offset,
            out.data<OutType>() + out_offset,
            out.data_size());
      });
    });
  });
}

} // namespace mlx::core

mlx/backend/cuda/copy/copy_general.cu (new file)
@@ -0,0 +1,95 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/copy/copy.cuh"

#include <cooperative_groups.h>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

template <typename In, typename Out, typename IdxT, int NDIM>
__global__ void copy_gg_nd(
    const In* in, Out* out, IdxT size,
    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
        index, shape.data(), strides_in.data(), strides_out.data());
    out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
  }
}

template <typename In, typename Out, typename IdxT>
__global__ void copy_gg(
    const In* in, Out* out, IdxT size,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides strides_in,
    const __grid_constant__ Strides strides_out,
    int ndim) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    auto [idx_in, idx_out] = elem_to_loc_4d(
        index, shape.data(), strides_in.data(), strides_out.data(), ndim);
    out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
  }
}

} // namespace cu

void copy_general(
    cu::CommandEncoder& encoder, CopyType ctype, const array& in, array& out,
    int64_t offset_in, int64_t offset_out, const Shape& shape,
    const Strides& strides_in, const Strides& strides_out) {
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
      const InType* in_ptr = in.data<InType>() + offset_in;
      OutType* out_ptr = out.data<OutType>() + offset_out;
      bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
      MLX_SWITCH_BOOL(large, LARGE, {
        using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
        int ndim = shape.size();
        if (ndim <= 3) {
          MLX_SWITCH_1_2_3(ndim, NDIM, {
            auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
            auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
            kernel<<<num_blocks, block_dims, 0, stream>>>(
                in_ptr, out_ptr, out.data_size(), const_param<NDIM>(shape),
                const_param<NDIM>(strides_in), const_param<NDIM>(strides_out));
          });
        } else { // ndim >= 4
          auto kernel = cu::copy_gg<InType, OutType, IdxT>;
          auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
          kernel<<<num_blocks, block_dims, 0, stream>>>(
              in_ptr, out_ptr, out.data_size(), const_param(shape),
              const_param(strides_in), const_param(strides_out), ndim);
        }
      });
    });
  });
}

} // namespace mlx::core

mlx/backend/cuda/copy/copy_general_dynamic.cu (new file)
@@ -0,0 +1,105 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/copy/copy.cuh"

#include <cooperative_groups.h>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

template <typename In, typename Out, typename IdxT, int NDIM>
__global__ void copy_gg_dynamic_nd(
    const In* in, Out* out, IdxT size,
    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out,
    const int64_t* offset_in,
    const int64_t* offset_out) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
        index, shape.data(), strides_in.data(), strides_out.data());
    out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
  }
}

template <typename In, typename Out, typename IdxT>
__global__ void copy_gg_dynamic(
    const In* in, Out* out, IdxT size,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides strides_in,
    const __grid_constant__ Strides strides_out,
    int ndim,
    const int64_t* offset_in,
    const int64_t* offset_out) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    auto [idx_in, idx_out] = elem_to_loc_4d(
        index, shape.data(), strides_in.data(), strides_out.data(), ndim);
    out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
  }
}

} // namespace cu

void copy_general_dynamic(
    cu::CommandEncoder& encoder, CopyType ctype, const array& in, array& out,
    int64_t offset_in, int64_t offset_out, const Shape& shape,
    const Strides& strides_in, const Strides& strides_out,
    const array& dynamic_offset_in, const array& dynamic_offset_out) {
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
      const InType* in_ptr = in.data<InType>() + offset_in;
      OutType* out_ptr = out.data<OutType>() + offset_out;
      bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
      MLX_SWITCH_BOOL(large, LARGE, {
        using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
        int ndim = shape.size();
        if (ndim <= 3) {
          MLX_SWITCH_1_2_3(ndim, NDIM, {
            auto kernel = cu::copy_gg_dynamic_nd<InType, OutType, IdxT, NDIM>;
            auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
            kernel<<<num_blocks, block_dims, 0, stream>>>(
                in_ptr, out_ptr, out.data_size(), const_param<NDIM>(shape),
                const_param<NDIM>(strides_in), const_param<NDIM>(strides_out),
                dynamic_offset_in.data<int64_t>(),
                dynamic_offset_out.data<int64_t>());
          });
        } else { // ndim >= 4
          auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>;
          auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
          kernel<<<num_blocks, block_dims, 0, stream>>>(
              in_ptr, out_ptr, out.data_size(), const_param(shape),
              const_param(strides_in), const_param(strides_out), ndim,
              dynamic_offset_in.data<int64_t>(),
              dynamic_offset_out.data<int64_t>());
        }
      });
    });
  });
}

} // namespace mlx::core

mlx/backend/cuda/copy/copy_general_input.cu (new file)
@@ -0,0 +1,88 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/copy/copy.cuh"

#include <cooperative_groups.h>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

template <typename In, typename Out, typename IdxT, int NDIM>
__global__ void copy_g_nd(
    const In* in, Out* out, IdxT size,
    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    IdxT idx_in = elem_to_loc_nd<NDIM>(index, shape.data(), strides_in.data());
    out[index] = CastOp<In, Out>{}(in[idx_in]);
  }
}

template <typename In, typename Out, typename IdxT>
__global__ void copy_g(
    const In* in, Out* out, IdxT size,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides strides_in,
    int ndim) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    IdxT idx_in = elem_to_loc_4d(index, shape.data(), strides_in.data(), ndim);
    out[index] = CastOp<In, Out>{}(in[idx_in]);
  }
}

} // namespace cu

void copy_general_input(
    cu::CommandEncoder& encoder, CopyType ctype, const array& in, array& out,
    int64_t offset_in, int64_t offset_out, const Shape& shape,
    const Strides& strides_in) {
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
      const InType* in_ptr = in.data<InType>() + offset_in;
      OutType* out_ptr = out.data<OutType>() + offset_out;
      bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
      MLX_SWITCH_BOOL(large, LARGE, {
        using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
        int ndim = shape.size();
        if (ndim <= 3) {
          MLX_SWITCH_1_2_3(ndim, NDIM, {
            auto kernel = cu::copy_g_nd<InType, OutType, IdxT, NDIM>;
            auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
            kernel<<<num_blocks, block_dims, 0, stream>>>(
                in_ptr, out_ptr, out.data_size(), const_param<NDIM>(shape),
                const_param<NDIM>(strides_in));
          });
        } else { // ndim >= 4
          auto kernel = cu::copy_g<InType, OutType, IdxT>;
          auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
          kernel<<<num_blocks, block_dims, 0, stream>>>(
              in_ptr, out_ptr, out.data_size(), const_param(shape),
              const_param(strides_in), ndim);
        }
      });
    });
  });
}

} // namespace mlx::core

mlx/backend/cuda/cuda.cpp (new file)
@@ -0,0 +1,11 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/cuda.h"

namespace mlx::core::cu {

bool is_available() {
  return true;
}

} // namespace mlx::core::cu

mlx/backend/cuda/cuda.h (new file)
@@ -0,0 +1,10 @@
// Copyright © 2025 Apple Inc.

#pragma once

namespace mlx::core::cu {

/* Check if the CUDA backend is available. */
bool is_available();

} // namespace mlx::core::cu
@@ -34,14 +34,26 @@ CommandEncoder& DeviceStream::get_encoder() {
 }

 Device::Device(int device) : device_(device) {
+  CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
+      &compute_capability_major_, cudaDevAttrComputeCapabilityMajor, device_));
+  CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
+      &compute_capability_minor_, cudaDevAttrComputeCapabilityMinor, device_));
   // Validate the requirements of device.
   int attr = 0;
-  cudaDeviceGetAttribute(&attr, cudaDevAttrConcurrentManagedAccess, device_);
+  CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
+      &attr, cudaDevAttrConcurrentManagedAccess, device_));
   if (attr != 1) {
     throw std::runtime_error(fmt::format(
         "Device {} does not support synchronization in managed memory.",
         device_));
   }
+  // The cublasLt handle is used by matmul.
+  make_current();
+  cublasLtCreate(&lt_);
+}
+
+Device::~Device() {
+  cublasLtDestroy(lt_);
 }

 void Device::make_current() {
@@ -6,6 +6,7 @@
 #include "mlx/backend/cuda/worker.h"
 #include "mlx/stream.h"

+#include <cublasLt.h>
 #include <thrust/execution_policy.h>

 #include <unordered_map>
@@ -46,6 +47,7 @@ class DeviceStream {
 class Device {
  public:
   explicit Device(int device);
+  ~Device();

   Device(const Device&) = delete;
   Device& operator=(const Device&) = delete;
@@ -58,9 +60,21 @@ class Device {
   int cuda_device() const {
     return device_;
   }
+  int compute_capability_major() const {
+    return compute_capability_major_;
+  }
+  int compute_capability_minor() const {
+    return compute_capability_minor_;
+  }
+  cublasLtHandle_t lt_handle() const {
+    return lt_;
+  }

  private:
   int device_;
+  int compute_capability_major_;
+  int compute_capability_minor_;
+  cublasLtHandle_t lt_;
   std::unordered_map<int, DeviceStream> streams_;
 };

@@ -1,35 +0,0 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include <cuComplex.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

namespace mlx::core {

// Maps CPU types to CUDA types.
template <typename T>
struct CTypeToCudaType {
  using type = T;
};

template <>
struct CTypeToCudaType<float16_t> {
  using type = __half;
};

template <>
struct CTypeToCudaType<bfloat16_t> {
  using type = __nv_bfloat16;
};

template <>
struct CTypeToCudaType<complex64_t> {
  using type = cuComplex;
};

template <typename T>
using cuda_type_t = typename CTypeToCudaType<T>::type;

} // namespace mlx::core
@@ -1,5 +1,6 @@
 // Copyright © 2024 Apple Inc.

+#include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/event.h"
 #include "mlx/backend/cuda/utils.h"
@@ -111,12 +112,12 @@ __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {

 SharedEvent::SharedEvent() {
   // Allocate cuda::atomic on managed memory.
-  allocator::Buffer buffer = allocator::malloc(sizeof(Atomic));
-  Atomic* ac = static_cast<Atomic*>(buffer.raw_ptr());
+  Atomic* ac;
+  CHECK_CUDA_ERROR(cudaMallocManaged(&ac, sizeof(Atomic)));
   new (ac) Atomic(0);
-  ac_ = std::shared_ptr<Atomic>(ac, [buffer](Atomic* ptr) {
+  ac_ = std::shared_ptr<Atomic>(ac, [](Atomic* ptr) {
     ptr->~Atomic();
-    allocator::free(buffer);
+    allocator().cuda_free(ptr);
   });
 }

@@ -155,7 +156,10 @@ void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
 void SharedEvent::signal(Stream s, uint64_t value) {
   nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
   if (s.device == mlx::core::Device::cpu) {
-    scheduler::enqueue(s, [*this, value]() mutable { signal(value); });
+    // Signal through a GPU stream so the atomic is updated in GPU - updating
+    // the atomic in CPU sometimes does not get GPU notified.
+    static CudaStream stream(device(mlx::core::Device::gpu));
+    scheduler::enqueue(s, [*this, value]() mutable { signal(stream, value); });
   } else {
     auto& encoder = get_command_encoder(s);
     encoder.launch_kernel(
mlx/backend/cuda/fence.cpp (new file)
@@ -0,0 +1,29 @@
// Copyright © 2025 Apple Inc.

#include "mlx/fence.h"
#include "mlx/backend/cuda/event.h"

namespace mlx::core {

struct FenceImpl {
  uint32_t count;
  cu::SharedEvent event;
};

Fence::Fence(Stream s) {
  fence_ = std::shared_ptr<void>(
      new FenceImpl{0}, [](void* ptr) { delete static_cast<FenceImpl*>(ptr); });
}

void Fence::wait(Stream s, const array&) {
  auto* fence = static_cast<FenceImpl*>(fence_.get());
  fence->event.wait(fence->count);
}

void Fence::update(Stream s, const array&) {
  auto* fence = static_cast<FenceImpl*>(fence_.get());
  fence->count++;
  fence->event.signal(s, fence->count);
}

} // namespace mlx::core

@@ -1,70 +0,0 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/event.h"
#include "mlx/fence.h"
#include "mlx/scheduler.h"

#include <nvtx3/nvtx3.hpp>

namespace mlx::core {

namespace {

__host__ __device__ void busy_wait(cuda::atomic<uint64_t>* ac, uint64_t value) {
  while (true) {
    // In theory the atomic_thread_fence is not needed, but for CUDA 11 without
    // it the load() may never return new value.
    cuda::atomic_thread_fence(cuda::memory_order_seq_cst);
    uint64_t current = ac->load();
    if (current >= value) {
      break;
    }
  }
}

__global__ void busy_wait_kernel(cuda::atomic<uint64_t>* ac, uint64_t value) {
  busy_wait(ac, value);
}

} // namespace

struct FenceImpl {
  uint32_t count;
  cu::SharedEvent event;
};

Fence::Fence(Stream s) {
  fence_ = std::shared_ptr<void>(
      new FenceImpl{0}, [](void* ptr) { delete static_cast<FenceImpl*>(ptr); });
}

void Fence::wait(Stream s, const array&) {
  auto* fence = static_cast<FenceImpl*>(fence_.get());
  // We can't use SharedEvent::wait because it could hang in CUDA 11, see also:
  // https://github.com/ml-explore/mlx/issues/2137
  const auto& ac = fence->event.atomic();
  if (s.device == mlx::core::Device::cpu) {
    scheduler::enqueue(s, [ac, count = fence->count]() {
      nvtx3::scoped_range r("Fence::wait()");
      busy_wait(ac.get(), count);
    });
  } else {
    nvtx3::scoped_range r("Fence::wait(s)");
    auto& encoder = cu::get_command_encoder(s);
    encoder.launch_kernel(
        encoder.stream().last_cuda_stream(), [&](cudaStream_t stream) {
          busy_wait_kernel<<<1, 1, 0>>>(ac.get(), fence->count);
        });
    encoder.add_completed_handler([ac]() {});
    encoder.end_encoding();
  }
}

void Fence::update(Stream s, const array&) {
  auto* fence = static_cast<FenceImpl*>(fence_.get());
  fence->count++;
  fence->event.signal(s, fence->count);
}

} // namespace mlx::core
mlx/backend/cuda/iterators/general_iterator.cuh (new file)
@@ -0,0 +1,121 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include <thrust/iterator/iterator_adaptor.h>
#include <cuda/std/utility>

#include "mlx/backend/cuda/kernel_utils.cuh"

namespace mlx::core::cu {

// Iterating non-contiguous array.
template <typename Iterator, typename IdxT = int64_t>
class general_iterator
    : public thrust::
          iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator> {
 public:
  using super_t =
      thrust::iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator>;

  using reference = typename super_t::reference;
  using difference_type = typename super_t::difference_type;

  __host__ __device__ general_iterator(
      Iterator it, IdxT index, int ndim, Shape shape, Strides strides)
      : super_t(it),
        index_(index),
        ndim_(ndim),
        shape_(cuda::std::move(shape)),
        strides_(cuda::std::move(strides)) {}

  __host__ __device__ IdxT index() const {
    return index_;
  }

  __host__ __device__ const Shape& shape() const {
    return shape_;
  }

  __host__ __device__ const Strides& strides() const {
    return strides_;
  }

 private:
  friend class thrust::iterator_core_access;

  __host__ __device__ bool equal(const general_iterator& other) const {
    return this->base() == other.base() && this->index() == other.index();
  }

  __host__ __device__ void advance(difference_type n) {
    this->index_ += n;
  }

  __host__ __device__ void increment() {
    this->index_ += 1;
  }

  __host__ __device__ void decrement() {
    this->index_ -= 1;
  }

  __host__ __device__ difference_type
  distance_to(const general_iterator& other) const {
    _CCCL_ASSERT(
        this->base() == other.base(),
        "Underlying iterator must point to same base iterator");
    return other.index() - this->index();
  }

  // The dereference is device-only to avoid accidental running in host.
  __device__ typename super_t::reference dereference() const {
    IdxT offset = elem_to_loc(index_, shape_.data(), strides_.data(), ndim_);
    return *(this->base() + offset);
  }

  IdxT index_;
  int ndim_;
  Shape shape_;
  Strides strides_;
};

template <typename IdxT, typename Iterator>
__host__ __device__ auto make_general_iterator(
    Iterator it, IdxT index, int ndim, Shape shape, Strides strides) {
  return general_iterator<Iterator, IdxT>(
      it, index, ndim, cuda::std::move(shape), cuda::std::move(strides));
}

template <typename IdxT, typename Iterator>
auto make_general_iterator(
    Iterator it, const std::vector<int32_t>& shape,
    const std::vector<int64_t>& strides) {
  return make_general_iterator<IdxT>(
      it, 0, shape.size(), const_param(shape), const_param(strides));
}

template <typename IdxT, typename Iterator>
auto make_general_iterators(
    Iterator it, IdxT size, const std::vector<int32_t>& shape,
    const std::vector<int64_t>& strides) {
  auto ndim = shape.size();
  auto shape_arg = const_param(shape);
  auto strides_arg = const_param(strides);
  return std::make_pair(
      make_general_iterator<IdxT>(it, 0, ndim, shape_arg, strides_arg),
      make_general_iterator<IdxT>(it, size, ndim, shape_arg, strides_arg));
}

} // namespace mlx::core::cu
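A hedged usage sketch of the iterator above, not taken from the diff: it gathers a strided 4x3 view of a device buffer into a contiguous output with Thrust. The pointers, shape, and strides are illustrative assumptions.

#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>

// Illustrative only: copy a non-contiguous 4x3 view (strides {6, 2}) of `in`
// into the contiguous buffer `out` on the given stream.
void gather_view_example(const float* in, float* out, cudaStream_t stream) {
  auto [begin, end] = mlx::core::cu::make_general_iterators<int64_t>(
      thrust::device_pointer_cast(in),
      int64_t{4 * 3},
      std::vector<int32_t>{4, 3},
      std::vector<int64_t>{6, 2});
  thrust::copy(
      thrust::cuda::par.on(stream), begin, end, thrust::device_pointer_cast(out));
}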
mlx/backend/cuda/kernel_utils.cu (new file)
@@ -0,0 +1,26 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/kernel_utils.cuh"

namespace mlx::core {

dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2) {
  Dims dims = get_block_dims_common(dim0, dim1, dim2, pow2);
  return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
}

dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides) {
  Dims dims = get_2d_grid_dims_common(shape, strides);
  return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
}

dim3 get_2d_grid_dims(
    const Shape& shape, const Strides& strides, size_t divisor) {
  Dims dims = get_2d_grid_dims_common(shape, strides, divisor);
  return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
}

} // namespace mlx::core

mlx/backend/cuda/kernel_utils.cuh (new file)
@@ -0,0 +1,131 @@
// Copyright © 2025 Apple Inc.

// This file includes host-only utilities for writing CUDA kernels; the
// difference from backend/cuda/kernels/utils.cuh is that the latter file only
// includes device-only code.

#pragma once

#include "mlx/array.h"
#include "mlx/backend/cuda/kernels/utils.cuh"

#include <cuComplex.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <fmt/format.h>
#include <cuda/cmath>

namespace mlx::core {

// Convert a number between 1~3 to constexpr.
#define MLX_SWITCH_1_2_3(N, NDIM, ...) \
  switch (N) {                         \
    case 1: {                          \
      constexpr int NDIM = 1;          \
      __VA_ARGS__;                     \
      break;                           \
    }                                  \
    case 2: {                          \
      constexpr int NDIM = 2;          \
      __VA_ARGS__;                     \
      break;                           \
    }                                  \
    case 3: {                          \
      constexpr int NDIM = 3;          \
      __VA_ARGS__;                     \
      break;                           \
    }                                  \
  }

// Like MLX_SWITCH_ALL_TYPES but for booleans.
#define MLX_SWITCH_BOOL(BOOL, BOOL_ALIAS, ...) \
  if (BOOL) {                                  \
    constexpr bool BOOL_ALIAS = true;          \
    __VA_ARGS__;                               \
  } else {                                     \
    constexpr bool BOOL_ALIAS = false;         \
    __VA_ARGS__;                               \
  }

// Maps CPU types to CUDA types.
template <typename T>
struct CTypeToCudaType {
  using type = T;
};

template <>
struct CTypeToCudaType<float16_t> {
  using type = __half;
};

template <>
struct CTypeToCudaType<bfloat16_t> {
  using type = __nv_bfloat16;
};

template <>
struct CTypeToCudaType<complex64_t> {
  using type = cuComplex;
};

template <typename T>
using cuda_type_t = typename CTypeToCudaType<T>::type;

// Type traits for detecting floating numbers.
template <typename T>
inline constexpr bool is_floating_v =
    cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
    cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;

// Utility to copy data from vector to array in host.
template <int NDIM = MAX_NDIM, typename T = int32_t>
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
  if (vec.size() > NDIM) {
    throw std::runtime_error(
        fmt::format("ndim can not be larger than {}.", NDIM));
  }
  cuda::std::array<T, NDIM> result;
  std::copy_n(vec.begin(), vec.size(), result.begin());
  return result;
}

// Compute the grid and block dimensions, check backend/common/utils.h for docs.
dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides);
dim3 get_2d_grid_dims(
    const Shape& shape, const Strides& strides, size_t divisor);

// Return a block size that achieves maximum potential occupancy for kernel.
template <typename T>
inline uint max_occupancy_block_dim(T kernel) {
  int _, block_dim;
  CHECK_CUDA_ERROR(cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
  return block_dim;
}

// Get the num_blocks and block_dims that maximize occupancy for |kernel|,
// assuming each thread handles |work_per_thread| elements of |arr|.
template <typename T>
inline std::tuple<dim3, uint> get_launch_args(
    T kernel, const array& arr, bool large, int work_per_thread = 1) {
  size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread);
  uint block_dim = max_occupancy_block_dim(kernel);
  if (block_dim > nthreads) {
    block_dim = nthreads;
  }
  dim3 num_blocks;
  if (large) {
    num_blocks = get_2d_grid_dims(arr.shape(), arr.strides(), work_per_thread);
    num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
  } else {
    num_blocks.x = cuda::ceil_div(nthreads, block_dim);
  }
  return std::make_tuple(num_blocks, block_dim);
}

} // namespace mlx::core
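A hedged sketch of how the helpers above are typically combined when launching a simple elementwise kernel; it mirrors the pattern already used in copy_contiguous.cu. The kernel name `scale_kernel`, the `scale` wrapper, and the float-only dtype are illustrative assumptions, not part of the diff.

#include <cooperative_groups.h>

namespace mlx::core {

// Illustrative elementwise kernel: out[i] = alpha * in[i].
template <typename T, typename IdxT>
__global__ void scale_kernel(const T* in, T* out, T alpha, IdxT size) {
  IdxT index = cooperative_groups::this_grid().thread_rank();
  if (index < size) {
    out[index] = alpha * in[index];
  }
}

// Illustrative host-side launcher using MLX_SWITCH_BOOL and get_launch_args.
void scale(cu::CommandEncoder& encoder, const array& in, array& out, float alpha) {
  encoder.launch_kernel([&](cudaStream_t stream) {
    // Pick a 32- or 64-bit index type depending on the array size.
    MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
      using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
      auto kernel = scale_kernel<float, IdxT>;
      auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
      kernel<<<num_blocks, block_dims, 0, stream>>>(
          in.data<float>(), out.data<float>(), alpha, out.data_size());
    });
  });
}

} // namespace mlx::core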
mlx/backend/cuda/kernels/binary_ops.cuh (new file)
@@ -0,0 +1,278 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/kernels/fp16_math.cuh"

#include <cuComplex.h>
#include <cuda/std/array>

namespace mlx::core::cu {

struct Add {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x + y;
  }
};

struct FloorDivide {
  template <typename T>
  __device__ T operator()(T x, T y) {
    if constexpr (cuda::std::is_integral_v<T>) {
      return x / y;
    } else {
      return trunc(x / y);
    }
  }
};

struct Divide {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x / y;
  }
};

struct Remainder {
  template <typename T>
  __device__ T operator()(T x, T y) {
    if constexpr (cuda::std::is_integral_v<T>) {
      if constexpr (cuda::std::is_signed_v<T>) {
        auto r = x % y;
        if (r != 0 && (r < 0 != y < 0)) {
          r += y;
        }
        return r;
      } else {
        return x % y;
      }
    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      return x % y;
    } else {
      T r = fmod(x, y);
      if (r != 0 && (r < 0 != y < 0)) {
        r = r + y;
      }
      return r;
    }
  }
};

struct Equal {
  template <typename T>
  __device__ bool operator()(T x, T y) {
    return x == y;
  }
};

struct NaNEqual {
  template <typename T>
  __device__ bool operator()(T x, T y) {
    if constexpr (std::is_same_v<T, cuComplex>) {
      return x == y ||
          (isnan(cuCrealf(x)) && isnan(cuCrealf(y)) && isnan(cuCimagf(x)) &&
           isnan(cuCimagf(y))) ||
          (cuCrealf(x) == cuCrealf(y) && isnan(cuCimagf(x)) &&
           isnan(cuCimagf(y))) ||
          (isnan(cuCrealf(x)) && isnan(cuCrealf(y)) &&
           cuCimagf(x) == cuCimagf(y));
    } else {
      return x == y || (isnan(x) && isnan(y));
    }
  }
};

struct Greater {
  template <typename T>
  __device__ bool operator()(T x, T y) {
    return x > y;
  }
};

struct GreaterEqual {
  template <typename T>
  __device__ bool operator()(T x, T y) {
    return x >= y;
  }
};

struct Less {
  template <typename T>
  __device__ bool operator()(T x, T y) {
    return x < y;
  }
};

struct LessEqual {
  template <typename T>
  __device__ bool operator()(T x, T y) {
    return x <= y;
  }
};

struct LogAddExp {
  template <typename T>
  __device__ T operator()(T x, T y) {
    if (isnan(x) || isnan(y)) {
      return cuda::std::numeric_limits<T>::quiet_NaN();
    }
    T maxval = max(x, y);
    T minval = min(x, y);
    return (minval == -cuda::std::numeric_limits<T>::infinity() ||
            maxval == cuda::std::numeric_limits<T>::infinity())
        ? maxval
        : T(float(maxval) + log1p(expf(minval - maxval)));
  }
};

struct Maximum {
  template <typename T>
  __device__ T operator()(T x, T y) {
    if constexpr (cuda::std::is_integral_v<T>) {
      return max(x, y);
    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
        return x;
      }
      return x > y ? x : y;
    } else {
      if (isnan(x)) {
        return x;
      }
      return x > y ? x : y;
    }
  }
};

struct Minimum {
  template <typename T>
  __device__ T operator()(T x, T y) {
    if constexpr (cuda::std::is_integral_v<T>) {
      return min(x, y);
    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
        return x;
      }
      return x < y ? x : y;
    } else {
      if (isnan(x)) {
        return x;
      }
      return x < y ? x : y;
    }
  }
};

struct Multiply {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x * y;
  }
};

struct NotEqual {
  template <typename T>
  __device__ bool operator()(T x, T y) {
    if constexpr (std::is_same_v<T, cuComplex>) {
      return cuCrealf(x) != cuCrealf(y) || cuCimagf(x) != cuCimagf(y);
    } else {
      return x != y;
    }
  }
};

struct Power {
  template <typename T>
  __device__ T operator()(T base, T exp) {
    if constexpr (cuda::std::is_integral_v<T>) {
      T res = 1;
      while (exp) {
        if (exp & 1) {
          res *= base;
        }
        exp >>= 1;
        base *= base;
      }
      return res;
    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      auto x_theta = atan2f(base.y, base.x);
      auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y);
      auto mag = expf(exp.x * x_ln_r - exp.y * x_theta);
      auto phase = exp.y * x_ln_r + exp.x * x_theta;
      return make_cuFloatComplex(mag * cosf(phase), mag * sinf(phase));
    } else {
      return powf(base, exp);
    }
  }
};

struct Subtract {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x - y;
  }
};

struct LogicalAnd {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x && y;
  }
};

struct LogicalOr {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x || y;
  }
};

struct BitwiseAnd {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x & y;
  }
};

struct BitwiseOr {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x | y;
  }
};

struct BitwiseXor {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x ^ y;
  }
};

struct LeftShift {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x << y;
  }
};

struct RightShift {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x >> y;
  }
};

struct ArcTan2 {
  template <typename T>
  __device__ T operator()(T y, T x) {
    return atan2f(y, x);
  }
};

struct DivMod {
  template <typename T>
  __device__ cuda::std::array<T, 2> operator()(T x, T y) {
    return {FloorDivide{}(x, y), Remainder{}(x, y)};
  }
};

} // namespace mlx::core::cu
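As a quick worked check of the integer branch of Power above (exponentiation by squaring): for base = 3, exp = 5 (binary 101), bit 0 is set so res = 3 and base becomes 9; bit 1 is clear and base becomes 81; bit 2 is set so res = 3 * 81 = 243, which is indeed 3^5.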
mlx/backend/cuda/kernels/cast_op.cuh (new file)
@@ -0,0 +1,59 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include <cuComplex.h>
#include <thrust/iterator/transform_iterator.h>

namespace mlx::core::cu {

// An op that does static_cast, with custom conversions for some types.
template <typename SrcT, typename DstT, typename = void>
struct CastOp {
  static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, DstT>;

  __device__ DstT operator()(SrcT x) {
    return static_cast<DstT>(x);
  }
};

// Converting a complex number to real number discards the imaginary part.
template <typename DstT>
struct CastOp<
    cuComplex,
    DstT,
    cuda::std::enable_if_t<!cuda::std::is_same_v<cuComplex, DstT>>> {
  static constexpr bool is_castable = cuda::std::is_convertible_v<float, DstT>;

  __device__ DstT operator()(cuComplex x) {
    static_assert(!cuda::std::is_same_v<cuComplex, DstT>);
    return static_cast<DstT>(cuCrealf(x));
  }
};

// Allow converting a real number to complex number.
template <typename SrcT>
struct CastOp<
    SrcT,
    cuComplex,
    cuda::std::enable_if_t<!cuda::std::is_same_v<SrcT, cuComplex>>> {
  static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, float>;

  __device__ cuComplex operator()(SrcT x) {
    static_assert(!cuda::std::is_same_v<SrcT, cuComplex>);
    return cuComplex{static_cast<float>(x), 0};
  }
};

// Return an iterator that cast the value to DstT using CastOp.
template <typename DstT, typename Iterator>
__host__ __device__ auto make_cast_iterator(Iterator it) {
  using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type;
  if constexpr (std::is_same_v<SrcT, DstT>) {
    return it;
  } else {
    return thrust::make_transform_iterator(it, CastOp<SrcT, DstT>{});
  }
}

} // namespace mlx::core::cu
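A hedged usage sketch of make_cast_iterator, not taken from the diff: it converts a float device buffer to __half while copying with Thrust. The function name, buffer types, and sizes are illustrative assumptions.

#include <cuda_fp16.h>
#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>

// Illustrative only: cast-copy n floats into a __half buffer on the stream.
void cast_copy_example(const float* in, __half* out, size_t n, cudaStream_t stream) {
  auto first =
      mlx::core::cu::make_cast_iterator<__half>(thrust::device_pointer_cast(in));
  thrust::copy_n(
      thrust::cuda::par.on(stream), first, n, thrust::device_pointer_cast(out));
}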
||||||
mlx/backend/cuda/kernels/cucomplex_math.cuh (new file, 240 lines)
@@ -0,0 +1,240 @@
// Copyright © 2025 Apple Inc.
// Copyright © 2017-2024 The Simons Foundation, Inc.
//
// FINUFFT is licensed under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance with the
// License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Forked from
// https://github.com/flatironinstitute/finufft/blob/main/include/cufinufft/contrib/helper_math.h

#pragma once

#include <cuComplex.h>

// This header provides some helper functions for cuComplex types.
// It mainly wraps existing CUDA implementations to provide operator overloads
// e.g. cuAdd, cuSub, cuMul, cuDiv, cuCreal, cuCimag, cuCabs, cuCarg, cuConj are
// all provided by CUDA

__forceinline__ __host__ __device__ cuDoubleComplex
operator+(const cuDoubleComplex& a, const cuDoubleComplex& b) {
  return cuCadd(a, b);
}

__forceinline__ __host__ __device__ cuDoubleComplex
operator-(const cuDoubleComplex& a, const cuDoubleComplex& b) {
  return cuCsub(a, b);
}

__forceinline__ __host__ __device__ cuDoubleComplex
operator*(const cuDoubleComplex& a, const cuDoubleComplex& b) {
  return cuCmul(a, b);
}

__forceinline__ __host__ __device__ cuDoubleComplex
operator/(const cuDoubleComplex& a, const cuDoubleComplex& b) {
  return cuCdiv(a, b);
}

__forceinline__ __host__ __device__ cuDoubleComplex
operator%(const cuDoubleComplex& a, const cuDoubleComplex& b) {
  double r = cuCreal(a) - (floorf(cuCreal(a) / cuCreal(b)) * cuCreal(b));
  double i = cuCimag(a) - (floorf(cuCimag(a) / cuCimag(b)) * cuCimag(b));
  return make_cuDoubleComplex(r, i);
}

__forceinline__ __host__ __device__ bool operator==(
    const cuDoubleComplex& a,
    const cuDoubleComplex& b) {
  return cuCreal(a) == cuCreal(b) && cuCimag(a) == cuCimag(b);
}

__forceinline__ __host__ __device__ bool operator!=(
    const cuDoubleComplex& a,
    const cuDoubleComplex& b) {
  return !(a == b);
}

__forceinline__ __host__ __device__ bool operator>(
    const cuDoubleComplex& a,
    const cuDoubleComplex& b) {
  double mag_a = sqrt(cuCreal(a) * cuCreal(a) + cuCimag(a) * cuCimag(a));
  double mag_b = sqrt(cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b));
  return mag_a > mag_b;
}

__forceinline__ __host__ __device__ bool operator>=(
    const cuDoubleComplex& a,
    const cuDoubleComplex& b) {
  return a > b || a == b;
}

__forceinline__ __host__ __device__ bool operator<(
    const cuDoubleComplex& a,
    const cuDoubleComplex& b) {
  return b > a;
}

__forceinline__ __host__ __device__ bool operator<=(
    const cuDoubleComplex& a,
    const cuDoubleComplex& b) {
  return b > a || a == b;
}

__forceinline__ __host__ __device__ cuDoubleComplex
operator+(const cuDoubleComplex& a, double b) {
  return make_cuDoubleComplex(cuCreal(a) + b, cuCimag(a));
}

__forceinline__ __host__ __device__ cuDoubleComplex
operator+(double a, const cuDoubleComplex& b) {
  return make_cuDoubleComplex(a + cuCreal(b), cuCimag(b));
}

__forceinline__ __host__ __device__ cuDoubleComplex
operator-(const cuDoubleComplex& a, double b) {
  return make_cuDoubleComplex(cuCreal(a) - b, cuCimag(a));
}

__forceinline__ __host__ __device__ cuDoubleComplex
operator-(double a, const cuDoubleComplex& b) {
  return make_cuDoubleComplex(a - cuCreal(b), -cuCimag(b));
}

__forceinline__ __host__ __device__ cuDoubleComplex
operator*(const cuDoubleComplex& a, double b) {
  return make_cuDoubleComplex(cuCreal(a) * b, cuCimag(a) * b);
}

__forceinline__ __host__ __device__ cuDoubleComplex
operator*(double a, const cuDoubleComplex& b) {
  return make_cuDoubleComplex(a * cuCreal(b), a * cuCimag(b));
}

__forceinline__ __host__ __device__ cuDoubleComplex
operator/(const cuDoubleComplex& a, double b) {
  return make_cuDoubleComplex(cuCreal(a) / b, cuCimag(a) / b);
}

__forceinline__ __host__ __device__ cuDoubleComplex
operator/(double a, const cuDoubleComplex& b) {
  double denom = cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b);
  return make_cuDoubleComplex(
      (a * cuCreal(b)) / denom, (-a * cuCimag(b)) / denom);
}

__forceinline__ __host__ __device__ cuFloatComplex
operator+(const cuFloatComplex& a, const cuFloatComplex& b) {
  return cuCaddf(a, b);
}

__forceinline__ __host__ __device__ cuFloatComplex
operator-(const cuFloatComplex& a, const cuFloatComplex& b) {
  return cuCsubf(a, b);
}

__forceinline__ __host__ __device__ cuFloatComplex
operator*(const cuFloatComplex& a, const cuFloatComplex& b) {
  return cuCmulf(a, b);
}

__forceinline__ __host__ __device__ cuFloatComplex
operator/(const cuFloatComplex& a, const cuFloatComplex& b) {
  return cuCdivf(a, b);
}

__forceinline__ __host__ __device__ cuFloatComplex
operator%(const cuFloatComplex& a, const cuFloatComplex& b) {
  float r = cuCrealf(a) - (floorf(cuCrealf(a) / cuCrealf(b)) * cuCrealf(b));
  float i = cuCimagf(a) - (floorf(cuCimagf(a) / cuCimagf(b)) * cuCimagf(b));
  return make_cuFloatComplex(r, i);
}

__forceinline__ __host__ __device__ bool operator==(
    const cuFloatComplex& a,
    const cuFloatComplex& b) {
  return cuCrealf(a) == cuCrealf(b) && cuCimagf(a) == cuCimagf(b);
}

__forceinline__ __host__ __device__ bool operator!=(
    const cuFloatComplex& a,
    const cuFloatComplex& b) {
  return !(a == b);
}

__forceinline__ __host__ __device__ bool operator>(
    const cuFloatComplex& a,
    const cuFloatComplex& b) {
  float mag_a = sqrt(cuCrealf(a) * cuCrealf(a) + cuCimagf(a) * cuCimagf(a));
  float mag_b = sqrt(cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b));
  return mag_a > mag_b;
}

__forceinline__ __host__ __device__ bool operator>=(
    const cuFloatComplex& a,
    const cuFloatComplex& b) {
  return a > b || a == b;
}

__forceinline__ __host__ __device__ bool operator<(
    const cuFloatComplex& a,
    const cuFloatComplex& b) {
  return b > a;
}

__forceinline__ __host__ __device__ bool operator<=(
    const cuFloatComplex& a,
    const cuFloatComplex& b) {
  return b > a || a == b;
}

__forceinline__ __host__ __device__ cuFloatComplex
operator+(const cuFloatComplex& a, float b) {
  return make_cuFloatComplex(cuCrealf(a) + b, cuCimagf(a));
}

__forceinline__ __host__ __device__ cuFloatComplex
operator+(float a, const cuFloatComplex& b) {
  return make_cuFloatComplex(a + cuCrealf(b), cuCimagf(b));
}

__forceinline__ __host__ __device__ cuFloatComplex
operator-(const cuFloatComplex& a, float b) {
  return make_cuFloatComplex(cuCrealf(a) - b, cuCimagf(a));
}

__forceinline__ __host__ __device__ cuFloatComplex
operator-(float a, const cuFloatComplex& b) {
  return make_cuFloatComplex(a - cuCrealf(b), -cuCimagf(b));
}

__forceinline__ __host__ __device__ cuFloatComplex
operator*(const cuFloatComplex& a, float b) {
  return make_cuFloatComplex(cuCrealf(a) * b, cuCimagf(a) * b);
}

__forceinline__ __host__ __device__ cuFloatComplex
operator*(float a, const cuFloatComplex& b) {
  return make_cuFloatComplex(a * cuCrealf(b), a * cuCimagf(b));
}

__forceinline__ __host__ __device__ cuFloatComplex
operator/(const cuFloatComplex& a, float b) {
  return make_cuFloatComplex(cuCrealf(a) / b, cuCimagf(a) / b);
}

__forceinline__ __host__ __device__ cuFloatComplex
operator/(float a, const cuFloatComplex& b) {
  float denom = cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b);
  return make_cuFloatComplex(
      (a * cuCrealf(b)) / denom, (-a * cuCimagf(b)) / denom);
}
@@ -2,6 +2,7 @@
 
 #pragma once
 
+#include <cuda_bf16.h>
 #include <cuda_fp16.h>
 #include <cuda/std/limits>
 #include <cuda/std/type_traits>
@@ -9,36 +10,122 @@
 namespace mlx::core::cu {
 
 ///////////////////////////////////////////////////////////////////////////////
-// Missing C++ operator overrides for CUDA 7.
+// Unary ops for half types.
 ///////////////////////////////////////////////////////////////////////////////
 
 #if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
+#define MLX_DEFINE_UNARY_OP(NAME, HALF_OP)           \
+  template <typename T>                              \
+  __forceinline__ __device__ auto NAME(T x) {        \
+    if constexpr (cuda::std::is_same_v<T, __half>) { \
+      return HALF_OP(x);                             \
+    } else {                                         \
+      return ::NAME(x);                              \
+    }                                                \
+  }
+#else
+#define MLX_DEFINE_UNARY_OP(NAME, HALF_OP)                          \
+  template <typename T>                                             \
+  __forceinline__ __device__ auto NAME(T x) {                       \
+    if constexpr (cuda::std::is_same_v<T, __half>) {                \
+      return HALF_OP(x);                                            \
+    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {  \
+      return HALF_OP(x);                                            \
+    } else {                                                        \
+      return ::NAME(x);                                             \
+    }                                                               \
+  }
+#endif
 
-#define MLX_DEFINE_BF16_OP(OP)                                            \
-  __forceinline__ __device__ __nv_bfloat16 operator OP(                   \
-      __nv_bfloat16 x, __nv_bfloat16 y) {                                 \
-    return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y));  \
+#define MLX_DEFINE_UNARY_OP_FALLBCK(NAME)                           \
+  template <typename T>                                             \
+  __forceinline__ __device__ auto NAME(T x) {                       \
+    if constexpr (cuda::std::is_same_v<T, __half>) {                \
+      return ::NAME(__half2float(x));                               \
+    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {  \
+      return ::NAME(__bfloat162float(x));                           \
+    } else {                                                        \
+      return ::NAME(x);                                             \
+    }                                                               \
   }
 
-#define MLX_DEFINE_BF16_CMP(OP)                                           \
-  __forceinline__ __device__ bool operator OP(                            \
-      __nv_bfloat16 x, __nv_bfloat16 y) {                                 \
-    return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y));  \
-  }
+MLX_DEFINE_UNARY_OP(abs, __habs)
+MLX_DEFINE_UNARY_OP(ceil, hceil)
+MLX_DEFINE_UNARY_OP(cos, hcos)
+MLX_DEFINE_UNARY_OP(exp, hexp)
+MLX_DEFINE_UNARY_OP(floor, hfloor)
+MLX_DEFINE_UNARY_OP(isnan, __hisnan)
+MLX_DEFINE_UNARY_OP(log, hlog)
+MLX_DEFINE_UNARY_OP(log2, hlog2)
+MLX_DEFINE_UNARY_OP(log10, hlog10)
+MLX_DEFINE_UNARY_OP(rint, hrint)
+MLX_DEFINE_UNARY_OP(rsqrt, hrsqrt)
+MLX_DEFINE_UNARY_OP(sin, hsin)
+MLX_DEFINE_UNARY_OP(sqrt, hsqrt)
+MLX_DEFINE_UNARY_OP_FALLBCK(acos)
+MLX_DEFINE_UNARY_OP_FALLBCK(acosh)
+MLX_DEFINE_UNARY_OP_FALLBCK(asin)
+MLX_DEFINE_UNARY_OP_FALLBCK(asinh)
+MLX_DEFINE_UNARY_OP_FALLBCK(atan)
+MLX_DEFINE_UNARY_OP_FALLBCK(atanh)
+MLX_DEFINE_UNARY_OP_FALLBCK(cosh)
+MLX_DEFINE_UNARY_OP_FALLBCK(log1p)
+MLX_DEFINE_UNARY_OP_FALLBCK(sinh)
+MLX_DEFINE_UNARY_OP_FALLBCK(tan)
+#if __CUDA_ARCH__ >= 1280
+MLX_DEFINE_UNARY_OP(tanh, htanh)
+#else
+MLX_DEFINE_UNARY_OP_FALLBCK(tanh)
+#endif
+
+#undef MLX_DEFINE_UNARY_OP
+#undef MLX_DEFINE_UNARY_OP_FALLBCK
+
+///////////////////////////////////////////////////////////////////////////////
+// Binary ops for half types.
+///////////////////////////////////////////////////////////////////////////////
+
+#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
+#define MLX_DEFINE_BINARY_OP(NAME, HALF_OP)          \
+  template <typename T>                              \
+  __forceinline__ __device__ auto NAME(T x, T y) {   \
+    if constexpr (cuda::std::is_same_v<T, __half>) { \
+      return HALF_OP(x, y);                          \
+    } else {                                         \
+      return ::NAME(x, y);                           \
+    }                                                \
+  }
+#else
+#define MLX_DEFINE_BINARY_OP(NAME, HALF_OP)                         \
+  template <typename T>                                             \
+  __forceinline__ __device__ auto NAME(T x, T y) {                  \
+    if constexpr (cuda::std::is_same_v<T, __half>) {                \
+      return HALF_OP(x, y);                                         \
+    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {  \
+      return HALF_OP(x, y);                                         \
+    } else {                                                        \
+      return ::NAME(x, y);                                          \
+    }                                                               \
+  }
+#endif
 
-MLX_DEFINE_BF16_OP(+)
-MLX_DEFINE_BF16_OP(-)
-MLX_DEFINE_BF16_OP(*)
-MLX_DEFINE_BF16_OP(/)
-MLX_DEFINE_BF16_CMP(>)
-MLX_DEFINE_BF16_CMP(<)
-MLX_DEFINE_BF16_CMP(>=)
-MLX_DEFINE_BF16_CMP(<=)
+MLX_DEFINE_BINARY_OP(max, __hmax)
+MLX_DEFINE_BINARY_OP(min, __hmin)
 
-#undef MLX_DEFINE_BF16_OP
-#undef MLX_DEFINE_BF16_CMP
+#undef MLX_DEFINE_BINARY_OP
 
-#endif // CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
+template <typename T>
+__forceinline__ __device__ T fmod(T x, T y) {
+  if constexpr (cuda::std::is_same_v<T, __half>) {
+    return __float2half(::fmod(__half2float(x), __half2float(y)));
+#if CUDART_VERSION >= 12000 || __CUDA_ARCH__ >= 800
+  } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
+    return __float2bfloat16(::fmod(__bfloat162float(x), __bfloat162float(y)));
+#endif
+  } else {
+    return ::fmod(x, y);
+  }
+}
 
 ///////////////////////////////////////////////////////////////////////////////
 // Additional C++ operator overrides between half types and native types.
mlx/backend/cuda/kernels/unary_ops.cuh (new file, 349 lines)
@@ -0,0 +1,349 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/cuda/kernels/fp16_math.cuh"
#include "mlx/backend/cuda/kernels/utils.cuh"

namespace mlx::core::cu {

struct Abs {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_unsigned_v<T>) {
      return x;
    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      return {sqrt(cuCrealf(x) * cuCrealf(x) + cuCimagf(x) * cuCimagf(x)), 0};
    } else {
      return abs(x);
    }
  }
};

struct ArcCos {
  template <typename T>
  __device__ T operator()(T x) {
    return acos(x);
  }
};

struct ArcCosh {
  template <typename T>
  __device__ T operator()(T x) {
    return acosh(x);
  }
};

struct ArcSin {
  template <typename T>
  __device__ T operator()(T x) {
    return asin(x);
  }
};

struct ArcSinh {
  template <typename T>
  __device__ T operator()(T x) {
    return asinh(x);
  }
};

struct ArcTan {
  template <typename T>
  __device__ T operator()(T x) {
    return atan(x);
  }
};

struct ArcTanh {
  template <typename T>
  __device__ T operator()(T x) {
    return atanh(x);
  }
};

struct BitwiseInvert {
  template <typename T>
  __device__ T operator()(T x) {
    return ~x;
  }
};

struct Ceil {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_integral_v<T>) {
      return x;
    } else {
      return ceil(x);
    }
  }
};

struct Conjugate {
  __device__ cuComplex operator()(cuComplex x) {
    return {cuCrealf(x), -cuCimagf(x)};
  }
};

struct Cos {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      return {
          cos(cuCrealf(x)) * cosh(cuCimagf(x)),
          -sin(cuCrealf(x)) * sinh(cuCimagf(x))};
    } else {
      return cos(x);
    }
  }
};

struct Cosh {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      return {
          cosh(cuCrealf(x)) * cos(cuCimagf(x)),
          sinh(cuCrealf(x)) * sin(cuCimagf(x))};
    } else {
      return cosh(x);
    }
  }
};

struct Erf {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, __half>) {
      return erf(__half2float(x));
    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
      return erf(__bfloat162float(x));
    } else {
      return erf(x);
    }
  }
};

struct ErfInv {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, __half>) {
      return erfinv(__half2float(x));
    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
      return erfinv(__bfloat162float(x));
    } else {
      return erfinv(x);
    }
  }
};

struct Exp {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      auto m = exp(cuCrealf(x));
      return {m * cos(cuCimagf(x)), m * sinh(cuCimagf(x))};
    } else {
      return exp(x);
    }
  }
};

struct Expm1 {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, __half>) {
      return expm1(__half2float(x));
    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
      return expm1(__bfloat162float(x));
    } else {
      return expm1(x);
    }
  }
};

struct Floor {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_integral_v<T>) {
      return x;
    } else {
      return floor(x);
    }
  }
};

struct Imag {
  __device__ float operator()(cuComplex x) {
    return cuCimagf(x);
  }
};

struct Log {
  template <typename T>
  __device__ T operator()(T x) {
    return log(x);
  }
};

struct Log2 {
  template <typename T>
  __device__ T operator()(T x) {
    return log2(x);
  }
};

struct Log10 {
  template <typename T>
  __device__ T operator()(T x) {
    return log10(x);
  }
};

struct Log1p {
  template <typename T>
  __device__ T operator()(T x) {
    return log1p(x);
  }
};

struct LogicalNot {
  __device__ bool operator()(bool x) {
    return !x;
  }
};

struct Negative {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      return 0 - x;
    } else {
      return -x;
    }
  }
};

struct Real {
  __device__ float operator()(cuComplex x) {
    return cuCrealf(x);
  }
};

struct Round {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      return {rint(cuCrealf(x)), rint(cuCimagf(x))};
    } else {
      return rint(x);
    }
  }
};

struct Rsqrt {
  template <typename T>
  __device__ T operator()(T x) {
    return rsqrt(x);
  }
};

struct Sigmoid {
  template <typename T>
  __device__ T operator()(T x) {
    T y = 1 / (1 + exp(-abs(x)));
    return (x < 0) ? 1 - y : y;
  }
};

struct Sign {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_unsigned_v<T>) {
      return x != 0;
    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      if (cuCrealf(x) == 0 && cuCimagf(x) == 0) {
        return x;
      } else {
        return x / Abs()(x);
      }
    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
      return static_cast<float>((x > T(0.f)) - (x < T(0.f)));
    } else {
      return (x > T(0)) - (x < T(0));
    }
  }
};

struct Sin {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      return {
          sin(cuCrealf(x)) * cosh(cuCimagf(x)),
          cos(cuCrealf(x)) * sinh(cuCimagf(x))};
    } else {
      return sin(x);
    }
  }
};

struct Sinh {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      return {
          sinh(cuCrealf(x)) * cos(cuCimagf(x)),
          cosh(cuCrealf(x)) * sin(cuCimagf(x))};
    } else {
      return sinh(x);
    }
  }
};

struct Square {
  template <typename T>
  __device__ T operator()(T x) {
    return x * x;
  }
};

struct Sqrt {
  template <typename T>
  __device__ T operator()(T x) {
    return sqrt(x);
  }
};

struct Tan {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      float tan_a = tan(cuCrealf(x));
      float tanh_b = tanh(cuCimagf(x));
      float t1 = tan_a * tanh_b;
      float denom = 1. + t1 * t1;
      return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
    } else {
      return tan(x);
    }
  }
};

struct Tanh {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      float tanh_a = tanh(cuCrealf(x));
      float tan_b = tan(cuCimagf(x));
      float t1 = tanh_a * tan_b;
      float denom = 1. + t1 * t1;
      return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
    } else {
      return tanh(x);
    }
  }
};

} // namespace mlx::core::cu
mlx/backend/cuda/kernels/utils.cuh (new file, 104 lines)
@@ -0,0 +1,104 @@
// Copyright © 2025 Apple Inc.

// This file must not include any host-only code; utilities that work under
// both host and device can be put here.
//
// See more about the requirements at:
// https://docs.nvidia.com/cuda/nvrtc/#language

#pragma once

#include <cuComplex.h>
#include <cuda/std/array>
#include <cuda/std/limits>
#include <cuda/std/tuple>

namespace mlx::core::cu {

///////////////////////////////////////////////////////////////////////////////
// CUDA kernel utils
///////////////////////////////////////////////////////////////////////////////

// To pass shape/strides to kernels via constant memory, their size must be
// known at compile time.
#define MAX_NDIM 8

using Shape = cuda::std::array<int32_t, MAX_NDIM>;
using Strides = cuda::std::array<int64_t, MAX_NDIM>;

///////////////////////////////////////////////////////////////////////////////
// Indexing utils
///////////////////////////////////////////////////////////////////////////////

template <typename IdxT = int64_t>
inline __host__ __device__ IdxT
elem_to_loc(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
  IdxT loc = 0;
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * IdxT(strides[i]);
    elem /= shape[i];
  }
  return loc;
}

// Optimize when the ndim is known at compile time.
template <int NDIM, typename IdxT = int64_t>
inline __host__ __device__ IdxT
elem_to_loc_nd(IdxT elem, const int* shape, const int64_t* strides) {
  IdxT loc = 0;
#pragma unroll
  for (int i = NDIM - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * IdxT(strides[i]);
    elem /= shape[i];
  }
  return loc;
}

template <int NDIM, typename IdxT = int64_t>
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
    IdxT elem,
    const int* shape,
    const int64_t* a_strides,
    const int64_t* b_strides) {
  IdxT a_loc = 0;
  IdxT b_loc = 0;
#pragma unroll
  for (int i = NDIM - 1; i >= 0; --i) {
    int dim_idx = elem % shape[i];
    a_loc += dim_idx * a_strides[i];
    b_loc += dim_idx * b_strides[i];
    elem /= shape[i];
  }
  return cuda::std::make_tuple(a_loc, b_loc);
}

// Optimized version when ndim is larger than 4.
template <typename IdxT = int64_t>
inline __host__ __device__ IdxT
elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
  IdxT loc = elem_to_loc_nd<3>(elem, shape, strides);
  for (int i = ndim - 1; i >= 3; --i) {
    loc += (elem % shape[i]) * IdxT(strides[i]);
    elem /= shape[i];
  }
  return loc;
}

template <typename IdxT = int64_t>
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
    IdxT elem,
    const int* shape,
    const int64_t* a_strides,
    const int64_t* b_strides,
    int ndim) {
  auto [a_loc, b_loc] = elem_to_loc_nd<3>(elem, shape, a_strides, b_strides);
  for (int i = ndim - 1; i >= 3; --i) {
    int dim_idx = elem % shape[i];
    a_loc += dim_idx * a_strides[i];
    b_loc += dim_idx * b_strides[i];
    elem /= shape[i];
  }
  return cuda::std::make_tuple(a_loc, b_loc);
}

} // namespace mlx::core::cu
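To make the indexing helpers above concrete, here is a small host-side sketch (not part of the diff) of how elem_to_loc maps a flattened element index to a storage offset for a non-contiguous array. The shape and stride values are made-up example data: for a 2x3 array stored with strides {1, 2} (a column-major-like layout), flattened element 4 corresponds to row 1, column 1 and lands at offset 1*1 + 1*2 = 3.

// Illustrative host-side check of elem_to_loc (example values only).
#include "mlx/backend/cuda/kernels/utils.cuh"

#include <cassert>
#include <cstdint>

int main() {
  int shape[] = {2, 3};
  int64_t strides[] = {1, 2}; // column-major-like layout
  // Flattened index 4 is (row=1, col=1) in row-major traversal order.
  auto loc = mlx::core::cu::elem_to_loc<int64_t>(4, shape, strides, 2);
  assert(loc == 1 * 1 + 1 * 2);
  return 0;
}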
mlx/backend/cuda/matmul.cpp (new file, 474 lines)
@@ -0,0 +1,474 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/matmul.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <cublasLt.h>
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>

#include <numeric>

namespace mlx::core {

namespace cu {

#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))

void check_cublas_error(const char* name, cublasStatus_t err) {
  if (err != CUBLAS_STATUS_SUCCESS) {
    // TODO: Use cublasGetStatusString when it is widely available.
    throw std::runtime_error(
        fmt::format("{} failed with code: {}.", name, static_cast<int>(err)));
  }
}

class MatMul {
 public:
  MatMul(
      Device& device,
      Dtype dtype,
      bool a_transposed,
      uint64_t a_rows,
      uint64_t a_cols,
      int64_t lda,
      bool b_transposed,
      uint64_t b_rows,
      uint64_t b_cols,
      int64_t ldb,
      int32_t batch_count,
      int64_t a_batch_stride,
      int64_t b_batch_stride) {
    heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;

    auto type = dtype_to_cuda_type(dtype);
    CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
        &matmul_desc_, dtype_to_compute_type(dtype), type));
    int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
        matmul_desc_,
        CUBLASLT_MATMUL_DESC_POINTER_MODE,
        &pointer_mode,
        sizeof(int32_t)));
    cublasOperation_t op = CUBLAS_OP_N;
    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
        matmul_desc_,
        CUBLASLT_MATMUL_DESC_TRANSA,
        &op,
        sizeof(cublasOperation_t)));
    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
        matmul_desc_,
        CUBLASLT_MATMUL_DESC_TRANSB,
        &op,
        sizeof(cublasOperation_t)));

    a_desc_ = create_matrix_layout(
        type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
    b_desc_ = create_matrix_layout(
        type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
    out_desc_ = create_matrix_layout(
        type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);

    // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
    // for Hopper+:
    // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
    uint64_t MiB = 1024 * 1024;
    uint64_t workspace_size =
        device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;

    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
        pref_,
        CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
        &workspace_size,
        sizeof(uint64_t)));
  }

  MatMul(
      Device& device,
      Dtype dtype,
      bool a_transposed,
      uint64_t a_rows,
      uint64_t a_cols,
      int64_t lda,
      bool b_transposed,
      uint64_t b_rows,
      uint64_t b_cols,
      int64_t ldb,
      bool c_transposed,
      int64_t ldc,
      int32_t batch_count,
      int64_t a_batch_stride,
      int64_t b_batch_stride,
      int64_t c_batch_stride)
      : MatMul(
            device,
            dtype,
            a_transposed,
            a_rows,
            a_cols,
            lda,
            b_transposed,
            b_rows,
            b_cols,
            ldb,
            batch_count,
            a_batch_stride,
            b_batch_stride) {
    auto type = dtype_to_cuda_type(dtype);
    c_desc_ = create_matrix_layout(
        type, a_rows, b_cols, c_transposed, ldc, batch_count, c_batch_stride);
  }

  ~MatMul() {
    cublasLtMatrixLayoutDestroy(a_desc_);
    cublasLtMatrixLayoutDestroy(b_desc_);
    cublasLtMatrixLayoutDestroy(c_desc_);
    cublasLtMatrixLayoutDestroy(out_desc_);
    cublasLtMatmulDescDestroy(matmul_desc_);
  }

  void run(
      cu::CommandEncoder& encoder,
      void* out,
      void* a,
      void* b,
      void* c = nullptr,
      float alpha = 1,
      float beta = 0) {
    if (heuristic_.state != CUBLAS_STATUS_SUCCESS) {
      int ret = 0;
      CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(
          encoder.device().lt_handle(),
          matmul_desc_,
          a_desc_,
          b_desc_,
          out_desc_,
          out_desc_,
          pref_,
          1,
          &heuristic_,
          &ret));
      if (ret == 0) {
        throw std::runtime_error("Can not find algorithm for matmul.");
      }
    }

    array workspace(
        allocator::malloc(heuristic_.workspaceSize),
        {static_cast<int>(heuristic_.workspaceSize)},
        int8);
    encoder.add_temporary(workspace);

    encoder.launch_kernel([&](cudaStream_t stream) {
      CHECK_CUBLAS_ERROR(cublasLtMatmul(
          encoder.device().lt_handle(),
          matmul_desc_,
          &alpha,
          a,
          a_desc_,
          b,
          b_desc_,
          &beta,
          c ? c : out,
          c ? c_desc_ : out_desc_,
          out,
          out_desc_,
          &heuristic_.algo,
          workspace.data<void>(),
          workspace.nbytes(),
          stream));
    });
  }

 private:
  cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
    switch (dtype) {
      case uint8:
      case uint16:
      case int8:
      case int16:
      case int32:
        return CUBLAS_COMPUTE_32I;
      case float16:
      case bfloat16:
        return CUBLAS_COMPUTE_16F;
      case float32:
        return CUBLAS_COMPUTE_32F;
      case float64:
      case complex64:
        return CUBLAS_COMPUTE_64F;
      default:
        throw std::runtime_error(fmt::format(
            "Unsupported dtype in MatMul: {}.", dtype_to_string(dtype)));
    }
  }

  cudaDataType_t dtype_to_cuda_type(Dtype dtype) {
    switch (dtype) {
      case uint8:
        return CUDA_R_8U;
      case uint16:
        return CUDA_R_16U;
      case int8:
        return CUDA_R_8I;
      case int16:
        return CUDA_R_16I;
      case int32:
        return CUDA_R_32I;
      case float16:
        return CUDA_R_16F;
      case bfloat16:
        return CUDA_R_16BF;
      case float32:
        return CUDA_R_32F;
      case float64:
        return CUDA_R_64F;
      case complex64:
        return CUDA_C_32F;
      default:
        throw std::runtime_error(fmt::format(
            "Unsupported dtype in MatMul: {}.", dtype_to_string(dtype)));
    }
  }

  cublasLtMatrixLayout_t create_matrix_layout(
      cudaDataType_t type,
      uint64_t rows,
      uint64_t cols,
      bool transposed,
      int64_t ld,
      int32_t batch_count,
      int64_t batch_stride) {
    cublasLtMatrixLayout_t desc;
    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
    cublasLtOrder_t order =
        transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
        desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
    if (batch_count > 1) {
      CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
          desc,
          CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
          &batch_count,
          sizeof(int32_t)));
      CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
          desc,
          CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
          &batch_stride,
          sizeof(int64_t)));
    }
    return desc;
  }

  cublasLtMatmulDesc_t matmul_desc_{nullptr};
  cublasLtMatmulPreference_t pref_{nullptr};
  cublasLtMatrixLayout_t a_desc_{nullptr};
  cublasLtMatrixLayout_t b_desc_{nullptr};
  cublasLtMatrixLayout_t c_desc_{nullptr};
  cublasLtMatrixLayout_t out_desc_{nullptr};
  cublasLtMatmulHeuristicResult_t heuristic_;
};

} // namespace cu

namespace {

std::tuple<bool, int64_t, array>
check_transpose(std::vector<array>& copies, const Stream& s, const array& arr) {
  auto stx = arr.strides()[arr.ndim() - 2];
  auto sty = arr.strides()[arr.ndim() - 1];
  if (sty == 1 && stx == arr.shape(-1)) {
    return std::make_tuple(false, stx, arr);
  } else if (stx == 1 && sty == arr.shape(-2)) {
    return std::make_tuple(true, sty, arr);
  } else {
    array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
    copy_gpu(arr, arr_copy, CopyType::General, s);
    copies.push_back(arr_copy);
    return std::make_tuple(false, arr.shape(-1), arr_copy);
  }
}

} // namespace

void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Matmul::eval_gpu");
  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);

  assert(inputs.size() == 2);
  auto& a_pre = inputs[0];
  auto& b_pre = inputs[1];
  // Return 0s if either input is empty.
  if (a_pre.size() == 0 || b_pre.size() == 0) {
    array zero(0, a_pre.dtype());
    encoder.add_temporary(zero);
    fill_gpu(zero, out, s);
    return;
  }

  out.set_data(allocator::malloc(out.nbytes()));

  /////////////////////////////////////////////////////////////////////////////
  // Init checks and prep

  int M = a_pre.shape(-2);
  int N = b_pre.shape(-1);
  int K = a_pre.shape(-1);

  // Keep a vector with copies to be cleared in the completed buffer to release
  // the arrays
  std::vector<array> copies;
  auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre);
  auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre);

  for (auto& temp : copies) {
    encoder.add_temporary(temp);
  }

  /////////////////////////////////////////////////////////////////////////////
  // Check and collapse batch dimensions

  auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);

  auto batch_count = out.size() / (M * N);

  // Collapse batches into M if needed
  if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
      a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
      b_batch_strides.back() == 0) {
    M *= batch_shape.back();
    batch_count = 1;

    a_batch_strides = {0};
    b_batch_strides = {0};
    batch_shape = {1};
  }

  /////////////////////////////////////////////////////////////////////////////
  // Invoke cublasLt

  cu::MatMul matmul(
      encoder.device(),
      a.dtype(),
      a_transposed,
      M,
      K,
      lda,
      b_transposed,
      K,
      N,
      ldb,
      batch_shape.back(),
      a_batch_strides.back(),
      b_batch_strides.back());

  ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
  ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
  for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
    matmul.run(
        encoder,
        out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
        a.data<int8_t>() + a.itemsize() * a_it.loc,
        b.data<int8_t>() + b.itemsize() * b_it.loc);
    a_it.step();
    b_it.step();
  }
}

void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("AddMM::eval_gpu");
  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);

  assert(inputs.size() == 3);
  auto& a_pre = inputs[0];
  auto& b_pre = inputs[1];
  auto& c_pre = inputs[2];

  out.set_data(allocator::malloc(out.nbytes()));

  /////////////////////////////////////////////////////////////////////////////
  // Init checks and prep

  int M = a_pre.shape(-2);
  int N = b_pre.shape(-1);
  int K = a_pre.shape(-1);

  // Keep a vector with copies to be cleared in the completed buffer to release
  // the arrays
  std::vector<array> copies;
  auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre);
  auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre);
  auto [c_transposed, ldc, c] = check_transpose(copies, s, c_pre);

  for (auto& temp : copies) {
    encoder.add_temporary(temp);
  }

  /////////////////////////////////////////////////////////////////////////////
  // Check and collapse batch dimensions

  auto [batch_shape, a_batch_strides, b_batch_strides, c_batch_strides] =
      collapse_batches(a, b, c);

  auto batch_count = out.size() / (M * N);

  // Collapse batches into M if needed
  if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
      a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
      c_batch_strides.back() == M * c.strides()[c.ndim() - 2] &&
      b_batch_strides.back() == 0) {
    M *= batch_shape.back();
    batch_count = 1;

    a_batch_strides = {0};
    b_batch_strides = {0};
    c_batch_strides = {0};
    batch_shape = {1};
  }

  /////////////////////////////////////////////////////////////////////////////
  // Invoke cublasLt

  cu::MatMul matmul(
      encoder.device(),
      a.dtype(),
      a_transposed,
      M,
      K,
      lda,
      b_transposed,
      K,
      N,
      ldb,
      c_transposed,
      ldc,
      batch_shape.back(),
      a_batch_strides.back(),
      b_batch_strides.back(),
      c_batch_strides.back());

  ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
  ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
  ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
  for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
    matmul.run(
        encoder,
        out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
        a.data<int8_t>() + a.itemsize() * a_it.loc,
        b.data<int8_t>() + b.itemsize() * b_it.loc,
        c.data<int8_t>() + c.itemsize() * c_it.loc,
        alpha_,
        beta_);
    a_it.step();
    b_it.step();
    c_it.step();
  }
}

} // namespace mlx::core
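For reference, the alpha/beta epilogue that AddMM passes to MatMul::run follows the standard GEMM convention, out = alpha * (a @ b) + beta * c. A tiny CPU reference of that same contract, a sketch with assumed dense row-major inputs (not part of the diff):

// Naive row-major reference for out = alpha * (a @ b) + beta * c,
// the contract the cublasLtMatmul call above realizes on the GPU.
#include <vector>

void addmm_reference(
    const std::vector<float>& a, // M x K
    const std::vector<float>& b, // K x N
    const std::vector<float>& c, // M x N
    std::vector<float>& out,     // M x N
    int M, int N, int K,
    float alpha, float beta) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += a[m * K + k] * b[k * N + n];
      }
      out[m * N + n] = alpha * acc + beta * c[m * N + n];
    }
  }
}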
mlx/backend/cuda/no_cuda.cpp (new file, 11 lines)
@@ -0,0 +1,11 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/cuda.h"

namespace mlx::core::cu {

bool is_available() {
  return false;
}

} // namespace mlx::core::cu
@@ -1,7 +1,7 @@
 // Copyright © 2025 Apple Inc.
 
 #include "mlx/backend/cuda/device.h"
-#include "mlx/backend/cuda/dtype_utils.cuh"
+#include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/backend/cuda/kernels/arange.cuh"
 #include "mlx/backend/cuda/kernels/fp16_math.cuh"
 #include "mlx/distributed/primitives.h"
@@ -43,111 +43,73 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
   });
 }
 
+bool fast::ScaledDotProductAttention::use_fallback(
+    const array& q,
+    const array& k,
+    const array& v,
+    bool has_mask,
+    bool has_arr_mask,
+    bool do_causal,
+    Stream s) {
+  return true;
+}
+
 #define NO_GPU_MULTI(func)                                             \
   void func::eval_gpu(                                                 \
       const std::vector<array>& inputs, std::vector<array>& outputs) { \
     throw std::runtime_error(#func " has no CUDA implementation.");    \
   }
 
+#define NO_GPU_USE_FALLBACK(func)     \
+  bool func::use_fallback(Stream s) { \
+    return true;                      \
+  }                                   \
+  NO_GPU_MULTI(func)
+
 #define NO_GPU(func)                                                  \
   void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
     throw std::runtime_error(#func " has no CUDA implementation.");   \
   }
 
-NO_GPU(Abs)
-NO_GPU(Add)
-NO_GPU(AddMM)
-NO_GPU(ArcCos)
-NO_GPU(ArcCosh)
-NO_GPU(ArcSin)
-NO_GPU(ArcSinh)
-NO_GPU(ArcTan)
-NO_GPU(ArcTan2)
-NO_GPU(ArcTanh)
 NO_GPU(ArgPartition)
 NO_GPU(ArgReduce)
-NO_GPU(ArgSort)
-NO_GPU(BitwiseBinary)
-NO_GPU(BitwiseInvert)
 NO_GPU(BlockMaskedMM)
-NO_GPU(Ceil)
 NO_GPU_MULTI(Compiled)
-NO_GPU(Conjugate)
 NO_GPU(Convolution)
-NO_GPU(Cos)
-NO_GPU(Cosh)
-NO_GPU(Divide)
 NO_GPU_MULTI(DivMod)
 NO_GPU(DynamicSlice)
 NO_GPU(DynamicSliceUpdate)
-NO_GPU(Remainder)
-NO_GPU(Equal)
-NO_GPU(Erf)
-NO_GPU(ErfInv)
-NO_GPU(Exp)
-NO_GPU(Expm1)
 NO_GPU(FFT)
-NO_GPU(Floor)
 NO_GPU(Gather)
 NO_GPU(GatherAxis)
 NO_GPU(GatherMM)
 NO_GPU(GatherQMM)
-NO_GPU(Greater)
-NO_GPU(GreaterEqual)
 NO_GPU(Hadamard)
-NO_GPU(Imag)
-NO_GPU(Less)
-NO_GPU(LessEqual)
 NO_GPU(Load)
-NO_GPU(Log)
-NO_GPU(Log1p)
-NO_GPU(LogicalNot)
-NO_GPU(LogicalAnd)
-NO_GPU(LogicalOr)
-NO_GPU(LogAddExp)
 NO_GPU(LogSumExp)
 NO_GPU_MULTI(LUF)
-NO_GPU(Matmul)
-NO_GPU(Maximum)
-NO_GPU(Minimum)
-NO_GPU(Multiply)
-NO_GPU(Negative)
-NO_GPU(NotEqual)
 NO_GPU(Partition)
-NO_GPU(Power)
 NO_GPU_MULTI(QRF)
 NO_GPU(QuantizedMatmul)
-NO_GPU(RandomBits)
-NO_GPU(Real)
 NO_GPU(Reduce)
-NO_GPU(Round)
 NO_GPU(Scan)
 NO_GPU(Scatter)
 NO_GPU(ScatterAxis)
 NO_GPU(Select)
-NO_GPU(Sigmoid)
-NO_GPU(Sign)
-NO_GPU(Sin)
-NO_GPU(Sinh)
 NO_GPU(SliceUpdate)
 NO_GPU(Softmax)
-NO_GPU(Sort)
-NO_GPU(Square)
-NO_GPU(Sqrt)
-NO_GPU(Subtract)
 NO_GPU_MULTI(SVD)
-NO_GPU(Tan)
-NO_GPU(Tanh)
 NO_GPU(Inverse)
 NO_GPU(Cholesky)
+NO_GPU_MULTI(Eig)
 NO_GPU_MULTI(Eigh)
 
 namespace fast {
-NO_GPU_MULTI(LayerNorm)
+NO_GPU_USE_FALLBACK(LayerNorm)
 NO_GPU_MULTI(LayerNormVJP)
-NO_GPU_MULTI(RMSNorm)
+NO_GPU_USE_FALLBACK(RMSNorm)
 NO_GPU_MULTI(RMSNormVJP)
-NO_GPU_MULTI(RoPE)
+NO_GPU_USE_FALLBACK(RoPE)
 NO_GPU(ScaledDotProductAttention)
 NO_GPU_MULTI(AffineQuantize)
 NO_GPU_MULTI(CustomKernel)
181
mlx/backend/cuda/random.cu
Normal file
181
mlx/backend/cuda/random.cu
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
__constant__ constexpr uint32_t rotations[2][4] = {
|
||||||
|
{13, 15, 26, 6},
|
||||||
|
{17, 29, 16, 24}};
|
||||||
|
|
||||||
|
union rbits {
|
||||||
|
uint2 val;
|
||||||
|
uint8_t bytes[2][4];
|
||||||
|
};
|
||||||
|
|
||||||
|
__device__ rbits threefry2x32_hash(uint2 key, uint2 count) {
|
||||||
|
uint32_t ks[] = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA};
|
||||||
|
|
||||||
|
rbits v;
|
||||||
|
v.val.x = count.x + ks[0];
|
||||||
|
v.val.y = count.y + ks[1];
|
||||||
|
|
||||||
|
for (int i = 0; i < 5; ++i) {
|
||||||
|
for (auto r : rotations[i % 2]) {
|
||||||
|
v.val.x += v.val.y;
|
||||||
|
v.val.y = (v.val.y << r) | (v.val.y >> (32 - r));
|
||||||
|
v.val.y ^= v.val.x;
|
||||||
|
}
|
||||||
|
v.val.x += ks[(i + 1) % 3];
|
||||||
|
v.val.y += ks[(i + 2) % 3] + i + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return v;
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void rbitsc(
|
||||||
|
const uint32_t* keys,
|
||||||
|
uint8_t* out,
|
||||||
|
dim3 grid_dims,
|
||||||
|
bool odd,
|
||||||
|
uint32_t bytes_per_key) {
|
||||||
|
uint2 index{
|
||||||
|
blockIdx.x * blockDim.x + threadIdx.x,
|
||||||
|
blockIdx.y * blockDim.y + threadIdx.y};
|
||||||
|
if (index.x >= grid_dims.x || index.y >= grid_dims.y) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto kidx = 2 * index.x;
|
||||||
|
auto key = uint2{keys[kidx], keys[kidx + 1]};
|
||||||
|
auto half_size = grid_dims.y - odd;
|
||||||
|
out += index.x * bytes_per_key;
|
||||||
|
bool drop_last = odd && (index.y == half_size);
|
||||||
|
auto bits = threefry2x32_hash(
|
||||||
|
key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y});
|
||||||
|
size_t idx = size_t(index.y) << 2;
|
||||||
|
for (int i = 0; i < 4; ++i) {
|
||||||
|
out[idx + i] = bits.bytes[0][i];
|
||||||
|
}
|
||||||
|
if (!drop_last) {
|
||||||
|
idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2;
|
||||||
|
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
|
||||||
|
      int edge_bytes = (bytes_per_key % 4);
      for (int i = 0; i < edge_bytes; ++i) {
        out[idx + i] = bits.bytes[1][i];
      }
    } else {
      for (int i = 0; i < 4; ++i) {
        out[idx + i] = bits.bytes[1][i];
      }
    }
  }
}

__global__ void rbits(
    const uint32_t* keys,
    uint8_t* out,
    dim3 grid_dims,
    bool odd,
    uint32_t bytes_per_key,
    int32_t ndim,
    const __grid_constant__ Shape key_shape,
    const __grid_constant__ Strides key_strides) {
  uint2 index{
      blockIdx.x * blockDim.x + threadIdx.x,
      blockIdx.y * blockDim.y + threadIdx.y};
  if (index.x >= grid_dims.x || index.y >= grid_dims.y) {
    return;
  }

  auto kidx = 2 * index.x;
  auto k1_elem = elem_to_loc(kidx, key_shape.data(), key_strides.data(), ndim);
  auto k2_elem =
      elem_to_loc(kidx + 1, key_shape.data(), key_strides.data(), ndim);
  auto key = uint2{keys[k1_elem], keys[k2_elem]};
  auto half_size = grid_dims.y - odd;
  out += size_t(index.x) * bytes_per_key;
  bool drop_last = odd && (index.y == half_size);
  auto bits = threefry2x32_hash(
      key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y});
  size_t idx = size_t(index.y) << 2;
  for (int i = 0; i < 4; ++i) {
    out[idx + i] = bits.bytes[0][i];
  }
  if (!drop_last) {
    idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2;
    if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
      int edge_bytes = (bytes_per_key % 4);
      for (int i = 0; i < edge_bytes; ++i) {
        out[idx + i] = bits.bytes[1][i];
      }
    } else {
      for (int i = 0; i < 4; ++i) {
        out[idx + i] = bits.bytes[1][i];
      }
    }
  }
}

} // namespace cu

void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("RandomBits::eval_gpu");
  assert(inputs.size() == 1);

  // keys has shape (N1, ..., NK, 2)
  // out has shape (N1, ..., NK, M1, M2, ...)
  auto& keys = inputs[0];
  uint32_t num_keys = keys.size() / 2;

  uint32_t elems_per_key = out.size() / num_keys;
  uint32_t bytes_per_key = out.itemsize() * elems_per_key;
  out.set_data(allocator::malloc(out.nbytes()));
  if (out.size() == 0) {
    return;
  }

  uint32_t out_per_key = (bytes_per_key + 4 - 1) / 4;
  uint32_t half_size = out_per_key / 2;
  bool odd = out_per_key % 2;

  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(keys);
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    dim3 grid_dims{num_keys, half_size + odd};
    dim3 block_dims = get_block_dims(grid_dims.x, grid_dims.y, 1);
    dim3 num_blocks{
        cuda::ceil_div(grid_dims.x, block_dims.x),
        cuda::ceil_div(grid_dims.y, block_dims.y)};
    if (keys.flags().row_contiguous) {
      cu::rbitsc<<<num_blocks, block_dims, 0, stream>>>(
          keys.data<uint32_t>(),
          out.data<uint8_t>(),
          grid_dims,
          odd,
          bytes_per_key);
    } else {
      cu::rbits<<<num_blocks, block_dims, 0, stream>>>(
          keys.data<uint32_t>(),
          out.data<uint8_t>(),
          grid_dims,
          odd,
          bytes_per_key,
          keys.ndim(),
          const_param(keys.shape()),
          const_param(keys.strides()));
    }
  });
}

} // namespace mlx::core
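A note on the counter scheme in rbitsc/rbits above: each thread row r for a key hashes the counter pair {r, r + grid_dims.y} with threefry2x32 and fills two 4-byte blocks of that key's output, dropping the second block when the block count is odd and trimming the last block to bytes_per_key % 4 bytes. The host-side sketch below only illustrates which row writes which block; the names and printout are illustrative, not MLX code.

#include <cstdint>
#include <cstdio>

int main() {
  uint32_t bytes_per_key = 10;               // e.g. five 16-bit values per key
  uint32_t blocks = (bytes_per_key + 3) / 4; // 4-byte blocks per key (here 3)
  uint32_t rows = blocks / 2 + blocks % 2;   // thread rows launched per key
  bool odd = blocks % 2;
  for (uint32_t r = 0; r < rows; ++r) {
    bool drop_last = odd && (r == rows - 1);
    std::printf("row %u writes block %u\n", r, r);
    if (!drop_last) {
      std::printf("row %u writes block %u\n", r, r + rows);
    }
  }
  return 0;
}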
@@ -1,7 +1,11 @@
 // Copyright © 2025 Apple Inc.
 
+#include "mlx/backend/common/slicing.h"
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/gpu/slicing.h"
 
+#include <numeric>
+
 namespace mlx::core {
 
 void concatenate_gpu(
@@ -9,7 +13,29 @@ void concatenate_gpu(
     array& out,
     int axis,
     const Stream& s) {
-  throw std::runtime_error("concatenate_gpu not implemented in CUDA backend.");
+  std::vector<int> sizes;
+  sizes.push_back(0);
+  for (auto& p : inputs) {
+    sizes.push_back(p.shape(axis));
+  }
+  std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
+
+  out.set_data(allocator::malloc(out.nbytes()));
+
+  auto strides = out.strides();
+  auto flags = out.flags();
+  flags.row_contiguous = false;
+  flags.col_contiguous = false;
+  flags.contiguous = false;
+  // TODO: Handle concurrent outputs:
+  // https://github.com/ml-explore/mlx/pull/2145#discussion_r2070753816
+  for (int i = 0; i < inputs.size(); i++) {
+    array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
+    size_t data_offset = strides[axis] * sizes[i];
+    out_slice.copy_shared_buffer(
+        out, strides, flags, out_slice.size(), data_offset);
+    copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, s);
+  }
 }
 
 } // namespace mlx::core
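The offset bookkeeping in concatenate_gpu above is a running prefix sum of the per-input sizes along the concatenation axis. A standalone illustration (the values here are only for the example):

#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  // Leading 0 followed by each input's extent along the concat axis.
  std::vector<int> sizes = {0, 3, 5, 2};
  std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
  // sizes is now {0, 3, 8, 10}: input i starts at offset sizes[i] along axis.
  for (size_t i = 0; i + 1 < sizes.size(); ++i) {
    std::printf("input %zu -> [%d, %d)\n", i, sizes[i], sizes[i + 1]);
  }
  return 0;
}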
mlx/backend/cuda/sort.cu (new file, 180 lines)
@@ -0,0 +1,180 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
#include <cub/device/device_segmented_sort.cuh>

#include <cassert>
#include <numeric>

namespace mlx::core {

namespace {

template <typename T>
struct ModOp {
  T divisor;
  __device__ T operator()(T x) {
    return x % divisor;
  }
};

// We can not use any op in eval, make an utility.
array swapaxes_in_eval(const array& in, int axis1, int axis2) {
  std::vector<int> axes(in.ndim());
  std::iota(axes.begin(), axes.end(), 0);
  std::swap(axes[axis1], axes[axis2]);
  // TODO: Share the code with Transpose::eval.
  Shape shape(axes.size());
  Strides strides(in.ndim());
  for (size_t ax = 0; ax < axes.size(); ++ax) {
    shape[ax] = in.shape()[axes[ax]];
    strides[ax] = in.strides()[axes[ax]];
  }
  auto flags = in.flags();
  if (flags.contiguous) {
    auto [_, row_contiguous, col_contiguous] = check_contiguity(shape, strides);
    flags.row_contiguous = row_contiguous;
    flags.col_contiguous = col_contiguous;
  }
  array out(shape, in.dtype(), nullptr, {});
  out.copy_shared_buffer(in, strides, flags, in.data_size());
  return out;
}

template <typename... Args>
void segmented_sort_pairs(cu::CommandEncoder& encoder, Args&&... args) {
  // Allocate temporary storage.
  size_t size;
  CHECK_CUDA_ERROR(
      cub::DeviceSegmentedSort::StableSortPairs(nullptr, size, args...));
  array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
  encoder.add_temporary(temp);
  // Run op.
  CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
      temp.data<void>(), size, args...));
}

template <typename... Args>
void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
  // Allocate temporary storage.
  size_t size;
  CHECK_CUDA_ERROR(
      cub::DeviceSegmentedSort::StableSortKeys(nullptr, size, args...));
  array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
  encoder.add_temporary(temp);
  // Run op.
  CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
      temp.data<void>(), size, args...));
}

void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
  array out = out_;
  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);

  if (axis < 0) {
    axis += in.ndim();
  }
  int nsort = in.shape(axis);
  int nsegments = in.data_size() / nsort;
  int last_dim = in.ndim() - 1;

  // If we are not sorting the innermost dimension of a contiguous array,
  // transpose and make a copy.
  bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1;
  if (!is_segmented_sort) {
    array trans = swapaxes_in_eval(in, axis, last_dim);
    in = array(trans.shape(), trans.dtype(), nullptr, {});
    copy_gpu(trans, in, CopyType::General, s);
    encoder.add_temporary(in);
    out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
    encoder.add_temporary(out);
  } else {
    out.set_data(allocator::malloc(out.nbytes()));
  }

  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
      if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
        using Type = cuda_type_t<CTYPE>;
        auto offsets = thrust::make_transform_iterator(
            thrust::make_counting_iterator(0),
            [nsort] __device__(int i) { return i * nsort; });
        if (argsort) {
          // Indices in the sorted dimension.
          array indices(
              allocator::malloc(out.nbytes()), in.shape(), out.dtype());
          encoder.add_temporary(indices);
          thrust::transform(
              cu::thrust_policy(stream),
              thrust::counting_iterator<uint32_t>(0),
              thrust::counting_iterator<uint32_t>(indices.data_size()),
              thrust::device_pointer_cast(indices.data<uint32_t>()),
              ModOp<uint32_t>{static_cast<uint32_t>(nsort)});

          // In argsort though we don't need the result of sorted values, the
          // API requires us to provide an array to store it.
          array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype());
          encoder.add_temporary(discard);

          segmented_sort_pairs(
              encoder,
              in.data<Type>(),
              discard.data<Type>(),
              indices.data<uint32_t>(),
              out.data<uint32_t>(),
              in.data_size(),
              nsegments,
              offsets,
              offsets + 1,
              stream);
        } else {
          segmented_sort(
              encoder,
              in.data<Type>(),
              out.data<Type>(),
              in.data_size(),
              nsegments,
              offsets,
              offsets + 1,
              stream);
        }
      } else {
        throw std::runtime_error(
            "CUDA backend does not support sorting complex numbers");
      }
    });
  });

  if (!is_segmented_sort) {
    // Swap the sorted axis back.
    // TODO: Do in-place transpose instead of using a temporary out array.
    copy_gpu(swapaxes_in_eval(out, axis, last_dim), out_, CopyType::General, s);
  }
}

} // namespace

void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("ArgSort::eval_gpu");
  assert(inputs.size() == 1);
  gpu_sort(stream(), inputs[0], out, axis_, true);
}

void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Sort::eval_gpu");
  assert(inputs.size() == 1);
  gpu_sort(stream(), inputs[0], out, axis_, false);
}

} // namespace mlx::core
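segmented_sort and segmented_sort_pairs above rely on CUB's standard two-phase calling convention: the first call with a null workspace pointer only reports the required temporary-storage size, and the second call performs the sort. A minimal sketch of that pattern outside MLX's encoder machinery (error checking omitted; the function and variable names are illustrative):

#include <cuda_runtime.h>
#include <cub/device/device_segmented_sort.cuh>

void sort_segments(
    const float* keys_in,
    float* keys_out,
    int num_items,
    int num_segments,
    const int* seg_begin,
    const int* seg_end,
    cudaStream_t stream) {
  size_t temp_bytes = 0;
  // Phase 1: query the workspace size.
  cub::DeviceSegmentedSort::StableSortKeys(
      nullptr, temp_bytes, keys_in, keys_out, num_items, num_segments,
      seg_begin, seg_end, stream);
  void* temp = nullptr;
  cudaMalloc(&temp, temp_bytes);
  // Phase 2: run the sort with the allocated workspace.
  cub::DeviceSegmentedSort::StableSortKeys(
      temp, temp_bytes, keys_in, keys_out, num_items, num_segments,
      seg_begin, seg_end, stream);
  cudaFree(temp);
}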
mlx/backend/cuda/unary.cu (new file, 196 lines)
@@ -0,0 +1,196 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/unary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/iterators/general_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
#include "mlx/backend/cuda/kernels/unary_ops.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>

namespace mlx::core {

namespace cu {

template <typename Op, typename In, typename Out>
constexpr bool supports_unary_op() {
  if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> ||
      std::is_same_v<Op, Sign>) {
    return std::is_same_v<In, Out>;
  }
  if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcCosh> ||
      std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> ||
      std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
      std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
      std::is_same_v<Op, Expm1> || std::is_same_v<Op, Log1p> ||
      std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
      std::is_same_v<Op, Log10> || std::is_same_v<Op, Sigmoid> ||
      std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
    return std::is_same_v<In, Out> && is_floating_v<In>;
  }
  if (std::is_same_v<Op, BitwiseInvert>) {
    return std::is_same_v<In, Out> && std::is_integral_v<In> &&
        !std::is_same_v<In, bool>;
  }
  if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor> ||
      std::is_same_v<Op, Square>) {
    return std::is_same_v<In, Out> && !std::is_same_v<In, complex64_t>;
  }
  if (std::is_same_v<Op, Conjugate>) {
    return std::is_same_v<In, Out> && std::is_same_v<In, complex64_t>;
  }
  if (std::is_same_v<Op, Cos> || std::is_same_v<Op, Cosh> ||
      std::is_same_v<Op, Exp> || std::is_same_v<Op, Round> ||
      std::is_same_v<Op, Sin> || std::is_same_v<Op, Sinh> ||
      std::is_same_v<Op, Tan> || std::is_same_v<Op, Tanh>) {
    return std::is_same_v<In, Out> &&
        (is_floating_v<In> || std::is_same_v<In, complex64_t>);
  }
  if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
    return std::is_same_v<In, complex64_t> && std::is_same_v<Out, float>;
  }
  if (std::is_same_v<Op, LogicalNot>) {
    return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
  }
  return false;
}

} // namespace cu

template <typename Op>
void unary_op_gpu_inplace(
    const std::vector<array>& inputs,
    array& out,
    const std::string& op,
    const Stream& s) {
  auto& in = inputs[0];
  if (in.size() == 0) {
    return;
  }

  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, {
      MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
        if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
          using InType = cuda_type_t<CTYPE_IN>;
          using OutType = cuda_type_t<CTYPE_OUT>;
          auto policy = cu::thrust_policy(stream);
          auto in_ptr = thrust::device_pointer_cast(in.data<InType>());
          auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
          if (in.flags().contiguous) {
            thrust::transform(
                policy, in_ptr, in_ptr + in.data_size(), out_ptr, Op());
          } else {
            auto [shape, strides] = collapse_contiguous_dims(in);
            auto [in_begin, in_end] = cu::make_general_iterators<int64_t>(
                in_ptr, in.data_size(), shape, strides);
            thrust::transform(policy, in_begin, in_end, out_ptr, Op());
          }
        } else {
          throw std::runtime_error(fmt::format(
              "Can not do unary op {} on input of {} with output of {}.",
              op,
              dtype_to_string(in.dtype()),
              dtype_to_string(out.dtype())));
        }
      });
    });
  });
}

template <typename Op>
void unary_op_gpu(
    const std::vector<array>& inputs,
    array& out,
    const std::string& op,
    const Stream& s) {
  set_unary_output_data(inputs[0], out);
  unary_op_gpu_inplace<Op>(inputs, out, op, s);
}

#define UNARY_GPU(func)                                                  \
  void func::eval_gpu(const std::vector<array>& inputs, array& out) {   \
    nvtx3::scoped_range r(#func "::eval_gpu");                          \
    auto& s = out.primitive().stream();                                 \
    unary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
  }

UNARY_GPU(Abs)
UNARY_GPU(ArcCos)
UNARY_GPU(ArcCosh)
UNARY_GPU(ArcSin)
UNARY_GPU(ArcSinh)
UNARY_GPU(ArcTan)
UNARY_GPU(ArcTanh)
UNARY_GPU(BitwiseInvert)
UNARY_GPU(Ceil)
UNARY_GPU(Conjugate)
UNARY_GPU(Cos)
UNARY_GPU(Cosh)
UNARY_GPU(Erf)
UNARY_GPU(ErfInv)
UNARY_GPU(Exp)
UNARY_GPU(Expm1)
UNARY_GPU(Floor)
UNARY_GPU(Imag)
UNARY_GPU(Log1p)
UNARY_GPU(LogicalNot)
UNARY_GPU(Negative)
UNARY_GPU(Real)
UNARY_GPU(Sigmoid)
UNARY_GPU(Sign)
UNARY_GPU(Sin)
UNARY_GPU(Sinh)
UNARY_GPU(Square)
UNARY_GPU(Tan)
UNARY_GPU(Tanh)

void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Log::eval_gpu");
  auto& s = out.primitive().stream();
  auto op = get_primitive_string(this);
  switch (base_) {
    case Base::e:
      unary_op_gpu<cu::Log>(inputs, out, op, s);
      break;
    case Base::two:
      unary_op_gpu<cu::Log2>(inputs, out, op, s);
      break;
    case Base::ten:
      unary_op_gpu<cu::Log10>(inputs, out, op, s);
      break;
  }
}

void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Round::eval_gpu");
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  auto& s = out.primitive().stream();
  if (issubdtype(in.dtype(), inexact)) {
    unary_op_gpu<cu::Round>(inputs, out, get_primitive_string(this), s);
  } else {
    // No-op integer types
    out.copy_shared_buffer(in);
  }
}

void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Sqrt::eval_gpu");
  auto& s = out.primitive().stream();
  if (recip_) {
    unary_op_gpu<cu::Rsqrt>(inputs, out, "Rsqrt", s);
  } else {
    unary_op_gpu<cu::Sqrt>(inputs, out, "Sqrt", s);
  }
}

} // namespace mlx::core
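The supports_unary_op<Op, In, Out>() gate above is a compile-time pruning trick: the if constexpr branch keeps thrust::transform from being instantiated for type combinations an op cannot handle, while everything else falls through to a runtime error. A toy standalone version of the same idea (the op and predicate here are made up for illustration):

#include <stdexcept>
#include <type_traits>

struct Negate {
  template <typename T>
  T operator()(T x) const {
    return -x;
  }
};

template <typename Op, typename In, typename Out>
constexpr bool supports_op() {
  // Only same-type arithmetic combinations are considered valid here.
  return std::is_same_v<In, Out> && std::is_arithmetic_v<In>;
}

template <typename Op, typename In, typename Out>
Out apply_one(In x) {
  if constexpr (supports_op<Op, In, Out>()) {
    return Op{}(x); // instantiated only for supported combinations
  } else {
    throw std::runtime_error("unsupported op/type combination");
  }
}

// Usage: apply_one<Negate, float, float>(2.0f) yields -2.0f.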
@@ -1,5 +1,7 @@
 // Copyright © 2025 Apple Inc.
 
+// This file include utilies that are used by C++ code (i.e. .cpp files).
+
 #pragma once
 
 #include <cuda_runtime.h>
@@ -5,9 +5,17 @@
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/gpu/slicing.h"
 
+#if defined(MLX_USE_CUDA)
+#include <nvtx3/nvtx3.hpp>
+#endif
+
 #include <cassert>
 
+#if defined(MLX_USE_CUDA)
+#define MLX_PROFILER_RANGE(message) nvtx3::scoped_range r(message)
+#else
 #define MLX_PROFILER_RANGE(message)
+#endif
 
 namespace mlx::core {
 
@@ -30,141 +30,18 @@ void* Buffer::raw_ptr() {
 
 namespace metal {
 
-namespace {
-
-BufferCache::BufferCache(ResidencySet& residency_set)
-    : head_(nullptr),
-      tail_(nullptr),
-      pool_size_(0),
-      residency_set_(residency_set) {}
-
-BufferCache::~BufferCache() {
-  auto pool = metal::new_scoped_memory_pool();
-  clear();
-}
-
-int BufferCache::clear() {
-  int n_release = 0;
-  for (auto& [size, holder] : buffer_pool_) {
-    if (holder->buf) {
-      if (!holder->buf->heap()) {
-        residency_set_.erase(holder->buf);
-      }
-      holder->buf->release();
-      n_release++;
-    }
-    delete holder;
-  }
-  buffer_pool_.clear();
-  pool_size_ = 0;
-  head_ = nullptr;
-  tail_ = nullptr;
-  return n_release;
-}
-
-MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
-  // Find the closest buffer in pool
-  MTL::Buffer* pbuf = nullptr;
-
-  auto it = buffer_pool_.lower_bound(size);
-
-  // Make sure we use most of the available memory
-  while (!pbuf && it != buffer_pool_.end() &&
-         it->first < std::min(2 * size, size + 2 * vm_page_size)) {
-    // Collect from the cache
-    pbuf = it->second->buf;
-
-    // Remove from cache
-    remove_from_list(it->second);
-    delete it->second;
-    it = buffer_pool_.erase(it);
-  }
-
-  if (pbuf) {
-    pool_size_ -= pbuf->length();
-  }
-
-  return pbuf;
-}
-
-void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
-  // Add to cache
-  if (buf) {
-    BufferHolder* bh = new BufferHolder(buf);
-    add_at_head(bh);
-    pool_size_ += buf->length();
-    buffer_pool_.insert({buf->length(), bh});
-  }
-}
-
-int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
-  if (min_bytes_to_free >= 0.9 * pool_size_) {
-    return clear();
-  } else {
-    int n_release = 0;
-    size_t total_bytes_freed = 0;
-
-    while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
-      if (tail_->buf) {
-        total_bytes_freed += tail_->buf->length();
-        if (!tail_->buf->heap()) {
-          residency_set_.erase(tail_->buf);
-        }
-        tail_->buf->release();
-        tail_->buf = nullptr;
-        n_release++;
-      }
-      remove_from_list(tail_);
-    }
-    pool_size_ -= total_bytes_freed;
-    return n_release;
-  }
-}
-
-void BufferCache::add_at_head(BufferCache::BufferHolder* to_add) {
-  if (!to_add)
-    return;
-
-  if (!head_) {
-    head_ = to_add;
-    tail_ = to_add;
-  } else {
-    head_->prev = to_add;
-    to_add->next = head_;
-    head_ = to_add;
-  }
-}
-
-void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
-  if (!to_remove) {
-    return;
-  }
-
-  // If in the middle
-  if (to_remove->prev && to_remove->next) {
-    to_remove->prev->next = to_remove->next;
-    to_remove->next->prev = to_remove->prev;
-  } else if (to_remove->prev && to_remove == tail_) { // If tail
-    tail_ = to_remove->prev;
-    tail_->next = nullptr;
-  } else if (to_remove == head_ && to_remove->next) { // If head
-    head_ = to_remove->next;
-    head_->prev = nullptr;
-  } else if (to_remove == head_ && to_remove == tail_) { // If only element
-    head_ = nullptr;
-    tail_ = nullptr;
-  }
-
-  to_remove->prev = nullptr;
-  to_remove->next = nullptr;
-}
-
-} // namespace
-
 MetalAllocator::MetalAllocator()
     : device_(device(mlx::core::Device::gpu).mtl_device()),
       residency_set_(device_),
-      buffer_cache_(residency_set_) {
+      buffer_cache_(
+          vm_page_size,
+          [](MTL::Buffer* buf) { return buf->length(); },
+          [this](MTL::Buffer* buf) {
+            if (!buf->heap()) {
+              residency_set_.erase(buf);
+            }
+            buf->release();
+          }) {
   auto pool = metal::new_scoped_memory_pool();
   auto memsize = std::get<size_t>(device_info().at("memory_size"));
   auto max_rec_size =
@@ -193,6 +70,7 @@ MetalAllocator::~MetalAllocator() {
   if (heap_) {
     heap_->release();
   }
+  buffer_cache_.clear();
 }
 
 size_t MetalAllocator::set_cache_limit(size_t limit) {
@@ -7,6 +7,7 @@
 #include <vector>
 
 #include "mlx/allocator.h"
+#include "mlx/backend/common/buffer_cache.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/resident.h"
 
@@ -14,43 +15,6 @@ namespace mlx::core::metal {
 
 using allocator::Buffer;
 
-namespace {
-
-class BufferCache {
- public:
-  BufferCache(ResidencySet& residency_set);
-  ~BufferCache();
-
-  MTL::Buffer* reuse_from_cache(size_t size);
-  void recycle_to_cache(MTL::Buffer* buf);
-  int release_cached_buffers(size_t min_bytes_to_free);
-  size_t cache_size() {
-    return pool_size_;
-  }
-  int clear();
-
- private:
-  struct BufferHolder {
-   public:
-    BufferHolder(MTL::Buffer* buf_) : buf(buf_), prev(nullptr), next(nullptr) {}
-
-    BufferHolder* prev;
-    BufferHolder* next;
-    MTL::Buffer* buf;
-  };
-
-  void add_at_head(BufferHolder* to_add);
-  void remove_from_list(BufferHolder* to_remove);
-
-  std::multimap<size_t, BufferHolder*> buffer_pool_;
-  BufferHolder* head_;
-  BufferHolder* tail_;
-  size_t pool_size_;
-  ResidencySet& residency_set_;
-};
-
-} // namespace
-
 class MetalAllocator : public allocator::Allocator {
   /** Allocator for Metal GPUs. */
  public:
@@ -90,7 +54,7 @@ class MetalAllocator : public allocator::Allocator {
   friend MetalAllocator& allocator();
 
   // Caching allocator
-  BufferCache buffer_cache_;
+  BufferCache<MTL::Buffer> buffer_cache_;
 
   ResidencySet residency_set_;
 
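The two hunks above imply the Metal-specific cache was replaced by a generic BufferCache<T> from mlx/backend/common/buffer_cache.h, configured with a page size, a size-query callback, and a release callback. That header is not shown here; the following is only a rough sketch of the shape such a template could take under that assumption, not the actual MLX implementation.

#include <algorithm>
#include <cstddef>
#include <functional>
#include <map>
#include <utility>

template <typename T>
class CacheSketch {
 public:
  CacheSketch(
      size_t page_size,
      std::function<size_t(T*)> size_of,
      std::function<void(T*)> release)
      : page_size_(page_size),
        size_of_(std::move(size_of)),
        release_(std::move(release)) {}

  // Return a cached buffer close enough in size, or nullptr.
  T* reuse_from_cache(size_t size) {
    auto it = pool_.lower_bound(size);
    if (it == pool_.end() ||
        it->first >= std::min(2 * size, size + 2 * page_size_)) {
      return nullptr;
    }
    T* buf = it->second;
    pool_size_ -= it->first;
    pool_.erase(it);
    return buf;
  }

  void recycle_to_cache(T* buf) {
    pool_size_ += size_of_(buf);
    pool_.insert({size_of_(buf), buf});
  }

  int clear() {
    int n = 0;
    for (auto& [sz, buf] : pool_) {
      release_(buf);
      n++;
    }
    pool_.clear();
    pool_size_ = 0;
    return n;
  }

 private:
  size_t page_size_;
  std::function<size_t(T*)> size_of_;
  std::function<void(T*)> release_;
  std::multimap<size_t, T*> pool_;
  size_t pool_size_{0};
};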
@@ -31,13 +31,13 @@ std::string get_kernel_name(
       kname = "ss";
       break;
     case BinaryOpType::ScalarVector:
-      kname = (large ? "sv2" : "sv");
+      kname = "sv";
       break;
     case BinaryOpType::VectorScalar:
-      kname = (large ? "vs2" : "vs");
+      kname = "vs";
       break;
     case BinaryOpType::VectorVector:
-      kname = (large ? "vv2" : "vv");
+      kname = "vv";
       break;
     case BinaryOpType::General:
       kname = "g";
@@ -51,6 +51,13 @@ std::string get_kernel_name(
       }
       break;
   }
+  if (bopt != BinaryOpType::General && bopt != BinaryOpType::ScalarScalar) {
+    if (large) {
+      kname += "2";
+    } else if (work_per_thread > 1) {
+      kname += "n";
+    }
+  }
   concatenate(kname, "_", op, type_to_name(a));
   return kname;
 }
@@ -90,7 +97,7 @@ void binary_op_gpu_inplace(
     work_per_thread = large ? 4 : 2;
   } else {
     large = out.data_size() > UINT32_MAX;
-    work_per_thread = get_work_per_thread(a.dtype());
+    work_per_thread = get_work_per_thread(a.dtype(), out.data_size());
   }
   std::string kernel_name =
       get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);
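The new suffix logic in get_kernel_name composes the base name with "2" when large is set (data_size > UINT32_MAX) and with "n" when work_per_thread > 1; only one of the two is appended. A tiny standalone restatement of that rule (the function name is illustrative):

#include <string>

std::string suffix_kernel_name(std::string kname, bool large, int work_per_thread) {
  if (large) {
    kname += "2"; // variant selected when data_size exceeds UINT32_MAX
  } else if (work_per_thread > 1) {
    kname += "n"; // variant that processes several elements per thread
  }
  return kname;
}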
@@ -11,8 +11,6 @@
 #include "mlx/primitives.h"
 #include "mlx/utils.h"
 
-using namespace fmt::literals;
-
 namespace mlx::core {
 
 inline void build_kernel(
@@ -21,21 +19,12 @@ inline void build_kernel(
     const std::vector<array>& inputs,
     const std::vector<array>& outputs,
     const std::vector<array>& tape,
-    const std::unordered_set<uintptr_t>& constant_ids,
+    const std::function<bool(size_t)>& is_constant,
     bool contiguous,
     int ndim,
     bool dynamic_dims,
     bool use_big_index = false,
     int work_per_thread = 1) {
-  // All outputs should have the exact same shape and will be row contiguous
-  auto output_shape = outputs[0].shape();
-  auto output_strides = outputs[0].strides();
-
-  // Constants are scalars that are captured by value and cannot change
-  auto is_constant = [&constant_ids](const array& x) {
-    return constant_ids.find(x.id()) != constant_ids.end();
-  };
-
   NodeNamer namer;
   bool add_indices = false;
   int cnt = 0;
@@ -45,14 +34,15 @@ inline void build_kernel(
       "[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);
 
   // Add the input arguments
-  for (auto& x : inputs) {
-    auto& xname = namer.get_name(x);
-
+  for (size_t i = 0; i < inputs.size(); ++i) {
     // Skip constants from the input list
-    if (is_constant(x)) {
+    if (is_constant(i)) {
       continue;
     }
 
+    const auto& x = inputs[i];
+    auto& xname = namer.get_name(x);
+
     // Scalars and contiguous need no strides
     if (!is_scalar(x) && !contiguous) {
       add_indices = true;
@@ -80,8 +70,6 @@ inline void build_kernel(
   }
   // Add output strides and shape to extract the indices.
   if (!contiguous) {
-    os += fmt::format(
-        " constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
     os += fmt::format(
         " constant const int* output_shape [[buffer({0})]],\n", cnt++);
   } else {
@@ -125,7 +113,7 @@ inline void build_kernel(
     auto& x = inputs[i];
     auto& xname = namer.get_name(x);
 
-    if (is_constant(x)) {
+    if (is_constant(i)) {
       auto type_str = get_type_string(x.dtype());
       std::ostringstream ss;
       print_constant(ss, x);
@@ -271,11 +259,6 @@
 void Compiled::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  // Make the name for the kernel library
-  if (kernel_lib_.empty()) {
-    kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
-  }
-
   // Get the kernel if someone else built it already
   auto& s = stream();
   auto& d = metal::device(s.device);
@@ -290,19 +273,33 @@ void Compiled::eval_gpu(
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         /* contiguous = */ true,
         /* ndim = */ 0,
         /* dynamic_dims = */ false,
         /* use_big_index = */ false,
-        /* work_per_thread = */ work_per_thread);
+        /* work_per_thread = */ 1);
+    if (work_per_thread > 1) {
+      build_kernel(
+          kernel,
+          kernel_lib_ + "_contiguous_n",
+          inputs_,
+          outputs_,
+          tape_,
+          is_constant_,
+          /* contiguous = */ true,
+          /* ndim = */ 0,
+          /* dynamic_dims = */ false,
+          /* use_big_index = */ false,
+          /* work_per_thread = */ work_per_thread);
+    }
     build_kernel(
         kernel,
         kernel_lib_ + "_contiguous_large",
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
        /* contiguous = */ true,
        /* ndim = */ 0,
        /* dynamic_dims = */ false,
@@ -315,7 +312,7 @@ void Compiled::eval_gpu(
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         /* contiguous = */ false,
         /* ndim = */ i,
         /* dynamic_dims = */ false,
@@ -328,7 +325,7 @@ void Compiled::eval_gpu(
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         /* contiguous = */ false,
         /* ndim = */ i,
         /* dynamic_dims = */ false,
@@ -342,7 +339,7 @@ void Compiled::eval_gpu(
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         /* contiguous = */ false,
         /* ndim = */ 0,
         /* dynamic_dims = */ true,
@@ -354,7 +351,7 @@ void Compiled::eval_gpu(
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         /* contiguous = */ false,
         /* ndim = */ 0,
         /* dynamic_dims = */ true,
@@ -363,81 +360,32 @@ void Compiled::eval_gpu(
     return kernel;
   });
 
-  // Figure out which kernel we are using
-  auto& output_shape = outputs[0].shape();
-  auto contiguous = compiled_check_contiguity(inputs, output_shape);
-
   // Collapse contiguous dims to route to a faster kernel if possible. Also
   // handle all broadcasting.
-  std::vector<Strides> initial_strides;
-  initial_strides.push_back(outputs[0].strides());
-  Shape shape;
-  std::vector<Strides> strides;
-  if (!contiguous) {
-    for (int i = 0; i < inputs.size(); i++) {
-      // Skip constants.
-      if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
-        continue;
-      }
-      auto& x = inputs[i];
-
-      // Skip scalar inputs.
-      if (is_scalar(x)) {
-        continue;
-      }
-
-      // Broadcast the inputs to the output shape.
-      Strides xstrides;
-      int j = 0;
-      for (; j < output_shape.size() - x.ndim(); j++) {
-        if (output_shape[j] == 1) {
-          xstrides.push_back(outputs[0].strides()[j]);
-        } else {
-          xstrides.push_back(0);
-        }
-      }
-      for (int i = 0; i < x.ndim(); i++, j++) {
-        if (x.shape(i) == 1) {
-          if (output_shape[j] == 1) {
-            xstrides.push_back(outputs[0].strides()[j]);
-          } else {
-            xstrides.push_back(0);
-          }
-        } else {
-          xstrides.push_back(x.strides()[i]);
-        }
-      }
-      initial_strides.push_back(std::move(xstrides));
-    }
-    std::tie(shape, strides) =
-        collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX);
-  }
-
-  bool large;
-  if (contiguous) {
-    size_t max_size = 0;
-    for (auto& in : inputs) {
-      max_size = std::max(max_size, in.data_size());
-    }
-    large = (max_size > UINT32_MAX);
-  } else {
-    size_t max_size = 0;
-    for (auto& o : outputs) {
-      max_size = std::max(max_size, o.size());
-    }
-    large = (max_size > UINT32_MAX);
-  }
+  auto [contiguous, shape, strides] =
+      compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
+
+  // Whether to use large index.
+  bool large = compiled_use_large_index(inputs, outputs, contiguous);
 
   // Get the kernel from the lib
   int ndim = shape.size();
   bool dynamic = ndim >= 8;
   auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
+  int work_per_thread = 1;
   if (!contiguous) {
     if (dynamic) {
       kernel_name += "dynamic";
     } else {
       kernel_name += std::to_string(shape.size());
     }
+    work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
+  } else {
+    work_per_thread =
+        get_work_per_thread(outputs[0].dtype(), outputs[0].data_size());
+    if (work_per_thread > 1 && !large) {
+      kernel_name += "_n";
+    }
   }
   if (large) {
     kernel_name += "_large";
@@ -451,7 +399,7 @@ void Compiled::eval_gpu(
   int stride_idx = 1; // idx 0 is the output strides
   Strides in_strides;
   for (int i = 0; i < inputs.size(); i++) {
-    if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
+    if (is_constant_(i)) {
       continue;
     }
     auto& x = inputs[i];
@@ -468,8 +416,7 @@ void Compiled::eval_gpu(
     compute_encoder.set_vector_bytes(in_strides, cnt++);
   }
 
-  compiled_allocate_outputs(
-      inputs, outputs, inputs_, constant_ids_, contiguous);
+  compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
 
   // Put the outputs in
   for (auto& x : outputs) {
@@ -478,7 +425,6 @@ void Compiled::eval_gpu(
 
   // Put the output shape and strides in
   if (!contiguous) {
-    compute_encoder.set_vector_bytes(strides[0], cnt++);
     compute_encoder.set_vector_bytes(shape, cnt++);
   } else {
     auto size = outputs[0].data_size();
@@ -496,7 +442,6 @@ void Compiled::eval_gpu(
 
   // Launch the kernel
   if (contiguous) {
-    int work_per_thread = get_work_per_thread(outputs[0].dtype());
     size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread);
     MTL::Size group_dims(
         std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
@@ -509,7 +454,6 @@ void Compiled::eval_gpu(
     size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
     size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
     size_t rest = outputs[0].size() / (dim0 * dim1);
-    int work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
    dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    int pow2;
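The Compiled::eval_gpu changes above replace the captured set of constant array ids with an is_constant_ predicate over input positions. A hedged sketch of how such a predicate could be built from the old representation (this helper is not in MLX; it only restates the mapping):

#include <cstdint>
#include <functional>
#include <unordered_set>
#include <vector>

std::function<bool(size_t)> make_is_constant(
    std::vector<uintptr_t> input_ids,
    std::unordered_set<uintptr_t> constant_ids) {
  // Capture by value so the predicate stays valid after the caller returns.
  return [input_ids = std::move(input_ids),
          constant_ids = std::move(constant_ids)](size_t i) {
    return constant_ids.count(input_ids[i]) > 0;
  };
}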
@@ -1,5 +1,4 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
@@ -178,83 +177,6 @@ void explicit_gemm_conv_group_ND_gpu(
|
|||||||
/*copies = */ copies);
|
/*copies = */ copies);
|
||||||
}
|
}
|
||||||
|
|
||||||
void conv_1D_gpu(
|
|
||||||
const Stream& s,
|
|
||||||
metal::Device& d,
|
|
||||||
const array& in,
|
|
||||||
const array& wt,
|
|
||||||
array out,
|
|
||||||
const std::vector<int>& padding,
|
|
||||||
const std::vector<int>& wt_strides,
|
|
||||||
const std::vector<int>& wt_dilation,
|
|
||||||
const std::vector<int>& in_dilation,
|
|
||||||
int groups,
|
|
||||||
bool flip) {
|
|
||||||
// Make conv params
|
|
||||||
MLXConvParams<1> conv_params{
|
|
||||||
/* const int N = */ static_cast<int>(in.shape(0)),
|
|
||||||
/* const int C = */ static_cast<int>(in.shape(2)),
|
|
||||||
/* const int O = */ static_cast<int>(wt.shape(0)),
|
|
||||||
/* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
|
|
||||||
/* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
|
|
||||||
/* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
|
|
||||||
/* const int str[NDIM] = */ {wt_strides[0]},
|
|
||||||
/* const int pad[NDIM] = */ {padding[0]},
|
|
||||||
/* const int kdil[NDIM] = */ {wt_dilation[0]},
|
|
||||||
/* const int idil[NDIM] = */ {in_dilation[0]},
|
|
||||||
/* const size_t in_strides[NDIM + 2] = */
|
|
||||||
{in.strides()[0], in.strides()[1], in.strides()[2]},
|
|
||||||
/* const size_t wt_strides[NDIM + 2] = */
|
|
||||||
{wt.strides()[0], wt.strides()[1], wt.strides()[2]},
|
|
||||||
/* const size_t out_strides[NDIM + 2] = */
|
|
||||||
{out.strides()[0], out.strides()[1], out.strides()[2]},
|
|
||||||
/* const int groups = */ groups,
|
|
||||||
/* const bool flip = */ flip};
|
|
||||||
|
|
||||||
// Direct to explicit gemm conv
|
|
||||||
if (groups > 1) {
|
|
||||||
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
|
|
||||||
} else {
|
|
||||||
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void slow_conv_2D_gpu(
|
|
||||||
const Stream& s,
|
|
||||||
metal::Device& d,
|
|
||||||
const array& in,
|
|
||||||
const array& wt,
|
|
||||||
array out,
|
|
||||||
const MLXConvParams<2>& conv_params) {
|
|
||||||
int bm = 16, bn = 8;
|
|
||||||
int tm = 4, tn = 4;
|
|
||||||
|
|
||||||
std::ostringstream kname;
|
|
||||||
kname << "naive_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn" << bn
|
|
||||||
<< "_tm" << tm << "_tn" << tn;
|
|
||||||
|
|
||||||
// Encode and dispatch kernel
|
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
|
||||||
auto kernel = d.get_kernel(kname.str());
|
|
||||||
compute_encoder.set_compute_pipeline_state(kernel);
|
|
||||||
|
|
||||||
size_t n_pixels = conv_params.oS[0] * conv_params.oS[1];
|
|
||||||
|
|
||||||
size_t grid_dim_x = (n_pixels + (tm * bm) - 1) / (tm * bm);
|
|
||||||
size_t grid_dim_y = (conv_params.O + (tn * bn) - 1) / (tn * bn);
|
|
||||||
size_t grid_dim_z = conv_params.N;
|
|
||||||
|
|
||||||
MTL::Size group_dims = MTL::Size(bm, bn, 1);
|
|
||||||
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
|
|
||||||
|
|
||||||
compute_encoder.set_input_array(in, 0);
|
|
||||||
compute_encoder.set_input_array(wt, 1);
|
|
||||||
compute_encoder.set_output_array(out, 2);
|
|
||||||
|
|
||||||
compute_encoder.set_bytes(conv_params, 3);
|
|
||||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
|
||||||
}
|
|
||||||
|
|
||||||
void implicit_gemm_conv_2D_gpu(
|
void implicit_gemm_conv_2D_gpu(
|
||||||
const Stream& s,
|
const Stream& s,
|
||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
@@ -469,6 +391,7 @@ void implicit_gemm_conv_2D_general_gpu(
|
|||||||
// Get channel iteration info
|
// Get channel iteration info
|
||||||
int channel_k_iters = ((conv_params.C + bk - 1) / bk);
|
int channel_k_iters = ((conv_params.C + bk - 1) / bk);
|
||||||
int gemm_k_iters = channel_k_iters;
|
int gemm_k_iters = channel_k_iters;
|
||||||
|
bool align_C = conv_params.C % bk == 0;
|
||||||
|
|
||||||
// Fix host side helper params
|
// Fix host side helper params
|
||||||
int sign = (conv_params.flip ? -1 : 1);
|
int sign = (conv_params.flip ? -1 : 1);
|
||||||
@@ -497,14 +420,33 @@ void implicit_gemm_conv_2D_general_gpu(
|
|||||||
/* const int swizzle_log = */ swizzle_log};
|
/* const int swizzle_log = */ swizzle_log};
|
||||||
|
|
||||||
// Determine kernel
|
// Determine kernel
|
||||||
std::ostringstream kname;
|
std::string kname;
|
||||||
kname << "implicit_gemm_conv_2d_general_" << type_to_name(out) << "_bm" << bm
|
kname.reserve(64);
|
||||||
<< "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
|
concatenate(
|
||||||
|
kname,
|
||||||
|
"implicit_gemm_conv_2d_general_",
|
||||||
|
type_to_name(out),
|
||||||
|
"_bm",
|
||||||
|
bm,
|
||||||
|
"_bn",
|
||||||
|
bn,
|
||||||
|
"_bk",
|
||||||
|
bk,
|
||||||
|
"_wm",
|
||||||
|
wm,
|
||||||
|
"_wn",
|
||||||
|
wn);
|
||||||
|
std::string hash_name;
|
||||||
|
hash_name.reserve(64);
|
||||||
|
concatenate(hash_name, kname, "_alC_", align_C);
|
||||||
|
metal::MTLFCList func_consts = {
|
||||||
|
{&align_C, MTL::DataType::DataTypeBool, 200},
|
||||||
|
};
|
||||||
|
|
||||||
// Encode and dispatch kernel
|
// Encode and dispatch kernel
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
auto kernel =
|
auto kernel = get_steel_conv_general_kernel(
|
||||||
get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
|
d, kname, hash_name, func_consts, out, bm, bn, bk, wm, wn);
|
||||||
compute_encoder.set_compute_pipeline_state(kernel);
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
|
|
||||||
// Deduce grid launch dimensions
|
// Deduce grid launch dimensions
|
||||||
@@ -755,7 +697,7 @@ void depthwise_conv_2D_gpu(
|
|||||||
std::string hash_name = kname.str();
|
std::string hash_name = kname.str();
|
||||||
|
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
|
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
|
||||||
compute_encoder.set_compute_pipeline_state(kernel);
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
|
|
||||||
compute_encoder.set_input_array(in, 0);
|
compute_encoder.set_input_array(in, 0);
|
||||||
@@ -771,6 +713,143 @@ void depthwise_conv_2D_gpu(
|
|||||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void dispatch_conv_2D_gpu(
|
||||||
|
const Stream& s,
|
||||||
|
metal::Device& d,
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array out,
|
||||||
|
const MLXConvParams<2>& conv_params,
|
||||||
|
std::vector<array>& copies) {
|
||||||
|
bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1;
|
||||||
|
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
|
||||||
|
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
|
||||||
|
|
||||||
|
if (is_idil_one && conv_params.groups > 1) {
|
||||||
|
const int C_per_group = conv_params.C / conv_params.groups;
|
||||||
|
const int O_per_group = conv_params.O / conv_params.groups;
|
||||||
|
|
||||||
|
if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
|
||||||
|
conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
|
||||||
|
conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
|
||||||
|
conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
|
||||||
|
conv_params.wt_strides[1] == conv_params.wS[1] &&
|
||||||
|
conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
|
||||||
|
return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
|
||||||
|
(O_per_group <= 16 || O_per_group % 16 == 0)) {
|
||||||
|
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||||
|
} else {
|
||||||
|
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Direct to winograd conv
|
||||||
|
bool inp_large =
|
||||||
|
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 4096;
|
||||||
|
bool channels_large = (conv_params.C + conv_params.O) >= 256;
|
||||||
|
bool out_large =
|
||||||
|
(conv_params.N * conv_params.oS[0] * conv_params.oS[1]) >= 256;
|
||||||
|
if (!conv_params.flip && is_stride_one && is_kdil_one && is_idil_one &&
|
||||||
|
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
|
||||||
|
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
|
||||||
|
channels_large) {
|
||||||
|
return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Direct to implicit gemm conv
|
||||||
|
if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) &&
|
||||||
|
(conv_params.O <= 16 || conv_params.O % 16 == 0)) {
|
||||||
|
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||||
|
}
|
||||||
|
|
||||||
|
else if ((conv_params.C % 16 == 0 && conv_params.O % 16 == 0) || out_large) {
|
||||||
|
return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Direct to explicit gemm conv
|
||||||
|
else {
|
||||||
|
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void conv_1D_gpu(
|
||||||
|
const Stream& s,
|
||||||
|
metal::Device& d,
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array out,
|
||||||
|
const std::vector<int>& padding,
|
||||||
|
const std::vector<int>& wt_strides,
|
||||||
|
const std::vector<int>& wt_dilation,
|
||||||
|
const std::vector<int>& in_dilation,
|
||||||
|
int groups,
|
||||||
|
bool flip,
|
||||||
|
std::vector<array>& copies) {
|
||||||
|
bool is_idil_one = in_dilation[0] == 1;
|
||||||
|
int C = in.shape(2);
|
||||||
|
int O = wt.shape(0);
|
||||||
|
const int C_per_group = in.shape(2) / groups;
|
||||||
|
const int O_per_group = wt.shape(0) / groups;
|
||||||
|
|
||||||
|
// Direct to implicit gemm conv
|
||||||
|
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
|
||||||
|
(O_per_group <= 16 || O_per_group % 16 == 0)) {
|
||||||
|
MLXConvParams<2> conv_params{
|
||||||
|
/* const int N = */ static_cast<int>(in.shape(0)),
|
||||||
|
/* const int C = */ C,
|
||||||
|
/* const int O = */ O,
|
||||||
|
/* const int iS[NDIM] = */ {static_cast<int>(in.shape(1)), 1},
|
||||||
|
/* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1)), 1},
|
||||||
|
/* const int oS[NDIM] = */ {static_cast<int>(out.shape(1)), 1},
|
||||||
|
/* const int str[NDIM] = */ {wt_strides[0], 1},
|
||||||
|
/* const int pad[NDIM] = */ {padding[0], 0},
|
||||||
|
/* const int kdil[NDIM] = */ {wt_dilation[0], 1},
|
||||||
|
/* const int idil[NDIM] = */ {in_dilation[0], 1},
|
||||||
|
/* const size_t in_strides[NDIM + 2] = */
|
||||||
|
{in.strides()[0], in.strides()[1], 0, in.strides()[2]},
|
||||||
|
/* const size_t wt_strides[NDIM + 2] = */
|
||||||
|
{wt.strides()[0], wt.strides()[1], 0, wt.strides()[2]},
|
||||||
|
/* const size_t out_strides[NDIM + 2] = */
|
||||||
|
{out.strides()[0], out.strides()[1], 0, out.strides()[2]},
|
||||||
|
/* const int groups = */ groups,
|
||||||
|
/* const bool flip = */ flip};
|
||||||
|
|
||||||
|
dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make conv params
|
||||||
|
MLXConvParams<1> conv_params{
|
||||||
|
/* const int N = */ static_cast<int>(in.shape(0)),
|
||||||
|
/* const int C = */ static_cast<int>(in.shape(2)),
|
||||||
|
/* const int O = */ static_cast<int>(wt.shape(0)),
|
||||||
|
/* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
|
||||||
|
/* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
|
||||||
|
/* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
|
||||||
|
/* const int str[NDIM] = */ {wt_strides[0]},
|
||||||
|
/* const int pad[NDIM] = */ {padding[0]},
|
||||||
|
/* const int kdil[NDIM] = */ {wt_dilation[0]},
|
||||||
|
/* const int idil[NDIM] = */ {in_dilation[0]},
|
||||||
|
/* const size_t in_strides[NDIM + 2] = */
|
||||||
|
{in.strides()[0], in.strides()[1], in.strides()[2]},
|
||||||
|
/* const size_t wt_strides[NDIM + 2] = */
|
||||||
|
{wt.strides()[0], wt.strides()[1], wt.strides()[2]},
|
||||||
|
/* const size_t out_strides[NDIM + 2] = */
|
||||||
|
{out.strides()[0], out.strides()[1], out.strides()[2]},
|
||||||
|
/* const int groups = */ groups,
|
||||||
|
/* const bool flip = */ flip};
|
||||||
|
|
||||||
|
// Direct to explicit gemm conv
|
||||||
|
if (groups > 1) {
|
||||||
|
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
|
||||||
|
} else {
|
||||||
|
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void conv_2D_gpu(
|
void conv_2D_gpu(
|
||||||
const Stream& s,
|
const Stream& s,
|
||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
@@ -808,57 +887,7 @@ void conv_2D_gpu(
       /* const int groups = */ groups,
       /* const bool flip = */ flip,
   };
 
-  bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1;
-  bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
-  bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
-
-  if (is_idil_one && groups > 1) {
-    const int C_per_group = conv_params.C / groups;
-    const int O_per_group = conv_params.O / groups;
-
-    if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
-        conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
-        conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
-        conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
-        conv_params.wt_strides[1] == conv_params.wS[1] &&
-        conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
-      return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
-    }
-
-    if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
-        (O_per_group <= 16 || O_per_group % 16 == 0)) {
-      return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
-    } else {
-      return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
-    }
-  }
-
-  // Direct to winograd conv
-  bool inp_large =
-      (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
-  bool channels_large = (conv_params.C + conv_params.O) >= 256;
-  if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
-      conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
-      conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
-      channels_large) {
-    return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
-  }
-
-  // Direct to implicit gemm conv
-  if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) &&
-      (conv_params.O <= 16 || conv_params.O % 16 == 0)) {
-    return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
-  }
-
-  else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
-    return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
-  }
-
-  // Direct to explicit gemm conv
-  else {
-    return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
-  }
+  dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
 }
 
 void conv_3D_gpu(
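Aside (not part of the diff): the 2-D convolution heuristics removed above presumably move behind the new dispatch_conv_2D_gpu helper. A simplified, hypothetical sketch of the selection they implemented (the grouped/depthwise branch omitted; names are illustrative):

enum class Conv2DPath { Winograd, ImplicitGemm, ImplicitGemmGeneral, ExplicitGemm };

// Mirrors the conditions in the removed block; not the actual MLX dispatcher.
Conv2DPath choose_conv_2d_path(
    bool stride_one, bool kdil_one, bool idil_one, bool flip,
    int C, int O, int wS0, int wS1, bool input_large, bool channels_large) {
  if (!flip && stride_one && kdil_one && idil_one && wS0 == 3 && wS1 == 3 &&
      C % 32 == 0 && O % 32 == 0 && input_large && channels_large) {
    return Conv2DPath::Winograd;            // 3x3, dense, large problem
  }
  if (idil_one && (C <= 4 || C % 16 == 0) && (O <= 16 || O % 16 == 0)) {
    return Conv2DPath::ImplicitGemm;        // channel counts fit tiled kernel
  }
  if (C % 16 == 0 && O % 16 == 0) {
    return Conv2DPath::ImplicitGemmGeneral; // general strided variant
  }
  return Conv2DPath::ExplicitGemm;          // fallback
}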
@@ -952,7 +981,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
         in,
         wt,
         out,
-        padding_,
+        padding_lo_,
         kernel_strides_,
         kernel_dilation_,
         input_dilation_,
@@ -967,7 +996,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
         in,
         wt,
         out,
-        padding_,
+        padding_lo_,
         kernel_strides_,
         kernel_dilation_,
         input_dilation_,
@@ -983,12 +1012,13 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
         in,
         wt,
         out,
-        padding_,
+        padding_lo_,
         kernel_strides_,
         kernel_dilation_,
         input_dilation_,
         groups_,
-        flip_);
+        flip_,
+        copies);
   }
   // Throw error
   else {
@@ -55,10 +55,10 @@ void copy_gpu_inplace(
   std::string kernel_name;
   switch (ctype) {
     case CopyType::Scalar:
-      kernel_name = (large ? "s2" : "s");
+      kernel_name = large ? "s2" : "s";
       break;
     case CopyType::Vector:
-      kernel_name = (large ? "v2" : "v");
+      kernel_name = large ? "v2" : "v";
       break;
    case CopyType::General:
      kernel_name = "g";
@@ -85,7 +85,10 @@ void copy_gpu_inplace(
      }
    }
  } else {
-    work_per_thread = get_work_per_thread(in.dtype());
+    work_per_thread = get_work_per_thread(out.dtype(), out.data_size());
+    if (work_per_thread > 1) {
+      kernel_name += "n";
+    }
  }
  concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
  auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out)
@@ -170,9 +173,10 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
   }
   out.set_data(allocator::malloc(out.nbytes()));
   bool large = out.data_size() > UINT32_MAX;
+  int work_per_thread = get_work_per_thread(out.dtype(), out.data_size());
   auto& d = metal::device(s.device);
-  std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +
-      type_to_name(val) + type_to_name(out);
+  std::string kernel_name = large ? "s2" : (work_per_thread > 1 ? "sn" : "s");
+  concatenate(kernel_name, "_copy", type_to_name(val), type_to_name(out));
   auto kernel = get_copy_kernel(d, kernel_name, val, out);
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
@@ -180,7 +184,6 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
   compute_encoder.set_input_array(val, 0);
   compute_encoder.set_output_array(out, 1);
 
-  int work_per_thread = get_work_per_thread(val.dtype());
   auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
   size_t nthreads = ceildiv(out.data_size(), work_per_thread);
   if (thread_group_size > nthreads) {
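Aside (not part of the diff): the copy kernel names built above follow a small scheme — "s"/"v" for scalar/vector copies, a "2" suffix when 64-bit indexing is needed for large buffers, and an "n" suffix when each thread handles several elements. A minimal illustrative sketch of that naming, with a hypothetical helper name:

#include <string>

// Illustrative only: mirrors the name construction in copy_gpu_inplace/fill_gpu.
// `base` is "s" (scalar source) or "v" (vector source).
std::string copy_kernel_name(
    const std::string& base,
    bool large,            // data_size() > UINT32_MAX -> 64-bit indexing
    int work_per_thread) { // from get_work_per_thread(dtype, size)
  std::string name = base;
  if (large) {
    name += "2";           // e.g. "s2", "v2"
  } else if (work_per_thread > 1) {
    name += "n";           // e.g. "sn", "vn"
  }
  return name;             // caller appends "_copy" plus the type names
}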
@@ -1,12 +1,326 @@
 // Copyright © 2024 Apple Inc.
 
+#include <iostream>
+#include <regex>
+
+#include "mlx/backend/common/compiled.h"
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/jit/includes.h"
 #include "mlx/backend/metal/utils.h"
+#include "mlx/fast.h"
 #include "mlx/fast_primitives.h"
+#include "mlx/utils.h"
 
 namespace mlx::core::fast {
 
+struct CustomKernelCache {
+  std::unordered_map<std::string, std::string> libraries;
+};
+
+static CustomKernelCache& cache() {
+  static CustomKernelCache cache_;
+  return cache_;
+};
+
+std::string write_signature(
+    std::string func_name,
+    const std::string& header,
+    const std::string& source,
+    const std::vector<std::string>& input_names,
+    const std::vector<array>& inputs,
+    const std::vector<std::string>& output_names,
+    const std::vector<Dtype>& output_dtypes,
+    const std::vector<std::pair<std::string, TemplateArg>>& template_args,
+    const std::vector<std::string>& attributes,
+    const std::vector<CustomKernelShapeInfo>& shape_infos,
+    bool atomic_outputs) {
+  std::string kernel_source;
+  kernel_source.reserve(header.size() + source.size() + 16384);
+  kernel_source += header;
+  // Auto-generate a function signature based on `template_args`
+  // and the dtype/shape of the arrays passed as `inputs`.
+  if (!template_args.empty()) {
+    kernel_source += "template <";
+    int i = 0;
+    for (const auto& [name, arg] : template_args) {
+      std::string param_type;
+      if (std::holds_alternative<int>(arg)) {
+        param_type = "int";
+      } else if (std::holds_alternative<bool>(arg)) {
+        param_type = "bool";
+      } else if (std::holds_alternative<Dtype>(arg)) {
+        param_type = "typename";
+      }
+      if (i > 0) {
+        kernel_source += ", ";
+      }
+      kernel_source += param_type;
+      kernel_source += " ";
+      kernel_source += name;
+      i++;
+    }
+    kernel_source += ">\n";
+  }
+  kernel_source += "[[kernel]] void ";
+  kernel_source += func_name;
+  kernel_source += "(\n";
+
+  int index = 0;
+  constexpr int max_constant_array_size = 8;
+  // Add inputs
+  for (int i = 0; i < inputs.size(); ++i) {
+    const auto& name = input_names[i];
+    const auto& arr = inputs[i];
+    auto dtype = get_type_string(arr.dtype());
+    std::string location =
+        arr.size() < max_constant_array_size ? "constant" : "device";
+    std::string ref = arr.ndim() == 0 ? "&" : "*";
+    kernel_source += "  const ";
+    kernel_source += location;
+    kernel_source += " ";
+    kernel_source += dtype;
+    kernel_source += ref;
+    kernel_source += " ";
+    kernel_source += name;
+    kernel_source += " [[buffer(";
+    kernel_source += std::to_string(index);
+    kernel_source += ")]],\n";
+    index++;
+    // Add input shape, strides and ndim if present in the source
+    if (arr.ndim() > 0) {
+      if (shape_infos[i].shape) {
+        kernel_source +=
+            ("  const constant int* " + name + "_shape [[buffer(" +
+             std::to_string(index) + ")]],\n");
+        index++;
+      }
+      if (shape_infos[i].strides) {
+        kernel_source +=
+            ("  const constant int64_t* " + name + "_strides [[buffer(" +
+             std::to_string(index) + ")]],\n");
+        index++;
+      }
+      if (shape_infos[i].ndim) {
+        kernel_source +=
+            ("  const constant int& " + name + "_ndim [[buffer(" +
+             std::to_string(index) + ")]],\n");
+        index++;
+      }
+    }
+  }
+  // Add outputs
+  for (int i = 0; i < output_names.size(); ++i) {
+    const auto& name = output_names[i];
+    const auto& dtype = output_dtypes[i];
+    kernel_source += "  device ";
+    auto type_string = get_type_string(dtype);
+    if (atomic_outputs) {
+      kernel_source += "atomic<";
+    }
+    kernel_source += type_string;
+    if (atomic_outputs) {
+      kernel_source += ">";
+    }
+    kernel_source += "* ";
+    kernel_source += name;
+    kernel_source += " [[buffer(";
+    kernel_source += std::to_string(index);
+    kernel_source += ")]]";
+    if (index < inputs.size() + output_names.size() - 1 ||
+        attributes.size() > 0) {
+      kernel_source += ",\n";
+    } else {
+      kernel_source += ") {\n";
+    }
+    index++;
+  }
+
+  index = 0;
+  for (const auto& attr : attributes) {
+    kernel_source += attr;
+    if (index < attributes.size() - 1) {
+      kernel_source += ",\n";
+    } else {
+      kernel_source += ") {\n";
+    }
+    index++;
+  }
+  kernel_source += source;
+  kernel_source += "\n}\n";
+  return kernel_source;
+}
+
+std::string write_template(
+    const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
+  std::ostringstream template_def;
+  template_def << "<";
+  int i = 0;
+  for (const auto& [name, arg] : template_args) {
+    if (i > 0) {
+      template_def << ", ";
+    }
+    if (std::holds_alternative<int>(arg)) {
+      template_def << std::get<int>(arg);
+    } else if (std::holds_alternative<bool>(arg)) {
+      template_def << std::get<bool>(arg);
+    } else if (std::holds_alternative<Dtype>(arg)) {
+      template_def << get_type_string(std::get<Dtype>(arg));
+    }
+    i++;
+  }
+  template_def << ">";
+  return template_def.str();
+}
+
+MetalKernelFunction metal_kernel(
+    const std::string& name,
+    const std::vector<std::string>& input_names,
+    const std::vector<std::string>& output_names,
+    const std::string& source,
+    const std::string& header /* = "" */,
+    bool ensure_row_contiguous /* = true */,
+    bool atomic_outputs /* = false */) {
+  if (output_names.empty()) {
+    throw std::invalid_argument(
+        "[metal_kernel] Must specify at least one output.");
+  }
+  std::vector<CustomKernelShapeInfo> shape_infos;
+  for (auto& n : input_names) {
+    CustomKernelShapeInfo shape_info;
+    shape_info.shape = source.find(n + "_shape") != std::string::npos;
+    shape_info.strides = source.find(n + "_strides") != std::string::npos;
+    shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
+    shape_infos.push_back(shape_info);
+  }
+  const std::vector<std::pair<std::string, std::string>> metal_attributes = {
+      {"dispatch_quadgroups_per_threadgroup", "uint"},
+      {"dispatch_simdgroups_per_threadgroup", "uint"},
+      {"dispatch_threads_per_threadgroup", "uint3"},
+      {"grid_origin", "uint3"},
+      {"grid_size", "uint3"},
+      {"quadgroup_index_in_threadgroup", "uint"},
+      {"quadgroups_per_threadgroup", "uint"},
+      {"simdgroup_index_in_threadgroup", "uint"},
+      {"simdgroups_per_threadgroup", "uint"},
+      {"thread_execution_width", "uint"},
+      {"thread_index_in_quadgroup", "uint"},
+      {"thread_index_in_simdgroup", "uint"},
+      {"thread_index_in_threadgroup", "uint"},
+      {"thread_position_in_grid", "uint3"},
+      {"thread_position_in_threadgroup", "uint3"},
+      {"threadgroup_position_in_grid", "uint3"},
+      {"threadgroups_per_grid", "uint3"},
+      {"threads_per_grid", "uint3"},
+      {"threads_per_simdgroup", "uint"},
+      {"threads_per_threadgroup", "uint3"},
+  };
+
+  std::vector<std::string> attributes;
+  for (const auto& [attr, dtype] : metal_attributes) {
+    if (source.find(attr) != std::string::npos) {
+      attributes.push_back("  " + dtype + " " + attr + " [[" + attr + "]]");
+    }
+  }
+
+  return [=,
+          shape_infos = std::move(shape_infos),
+          attributes = std::move(attributes)](
+             const std::vector<array>& inputs,
+             const std::vector<Shape>& output_shapes,
+             const std::vector<Dtype>& output_dtypes,
+             std::tuple<int, int, int> grid,
+             std::tuple<int, int, int> threadgroup,
+             const std::vector<std::pair<std::string, TemplateArg>>&
+                 template_args = {},
+             std::optional<float> init_value = std::nullopt,
+             bool verbose = false,
+             StreamOrDevice s_ = {}) {
+    if (inputs.size() != input_names.size()) {
+      std::ostringstream msg;
+      msg << "[metal_kernel] Expected `inputs` to have size "
+          << input_names.size() << " but got size " << inputs.size() << "."
+          << std::endl;
+      throw std::invalid_argument(msg.str());
+    }
+    if (output_shapes.size() != output_names.size()) {
+      std::ostringstream msg;
+      msg << "[metal_kernel] Expected `output_shapes` to have size "
+          << output_names.size() << " but got size " << output_shapes.size()
+          << "." << std::endl;
+      throw std::invalid_argument(msg.str());
+    }
+    if (output_dtypes.size() != output_names.size()) {
+      std::ostringstream msg;
+      msg << "[metal_kernel] Expected `output_dtypes` to have size "
+          << output_names.size() << " but got size " << output_dtypes.size()
+          << "." << std::endl;
+      throw std::invalid_argument(msg.str());
+    }
+
+    auto s = to_stream(s_);
+    if (s.device != Device::gpu) {
+      throw std::invalid_argument("[metal_kernel] Only supports the GPU.");
+    }
+
+    std::string kernel_name = "custom_kernel_" + name;
+    std::string template_def = "";
+    if (!template_args.empty()) {
+      std::regex disallowed_chars("\\<|\\>|(, )");
+      template_def = write_template(template_args);
+      auto template_hash =
+          std::regex_replace(template_def, disallowed_chars, "_");
+      template_hash.pop_back();
+      kernel_name += "_";
+      kernel_name += template_hash;
+    }
+
+    std::string kernel_source = write_signature(
+        kernel_name,
+        header,
+        source,
+        input_names,
+        inputs,
+        output_names,
+        output_dtypes,
+        template_args,
+        attributes,
+        shape_infos,
+        atomic_outputs);
+
+    if (!template_args.empty()) {
+      template_def = kernel_name + template_def;
+      kernel_source += "\ntemplate [[host_name(\"";
+      kernel_source += kernel_name;
+      kernel_source += "\")]] [[kernel]] decltype(";
+      kernel_source += template_def;
+      kernel_source += ") ";
+      kernel_source += template_def;
+      kernel_source += ";\n";
+    }
+
+    if (verbose) {
+      std::cout << "Generated source code for `" << name << "`:" << std::endl
+                << "```" << std::endl
+                << kernel_source << std::endl
+                << "```" << std::endl;
+    }
+
+    return array::make_arrays(
+        std::move(output_shapes),
+        std::move(output_dtypes),
+        std::make_shared<CustomKernel>(
+            s,
+            std::move(kernel_name),
+            std::move(kernel_source),
+            grid,
+            threadgroup,
+            shape_infos,
+            ensure_row_contiguous,
+            init_value),
+        std::move(inputs));
+  };
+}
+
 void CustomKernel::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
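Aside (not part of the diff): the factory above returns a callable that takes runtime inputs, output shapes/dtypes, and launch geometry. A rough, hypothetical C++ usage sketch — the kernel name and body are invented for illustration, and the umbrella include may differ:

#include "mlx/mlx.h" // assumed umbrella header

using namespace mlx::core;

int main() {
  // Hypothetical element-wise exp kernel. `inp` and `out` match the declared
  // input/output names; the full signature is generated by write_signature().
  auto kernel = fast::metal_kernel(
      "myexp",                                  // name (illustrative)
      {"inp"},                                  // input_names
      {"out"},                                  // output_names
      "uint elem = thread_position_in_grid.x;\n"
      "out[elem] = metal::exp(inp[elem]);\n");

  auto x = array({1.0f, 2.0f, 3.0f, 4.0f});
  auto outs = kernel(
      {x},            // inputs
      {x.shape()},    // output_shapes
      {float32},      // output_dtypes
      {4, 1, 1},      // grid
      {4, 1, 1},      // threadgroup
      {},             // template_args
      std::nullopt,   // init_value
      false,          // verbose
      {});            // stream/device (default)
  eval(outs);
  return 0;
}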
@@ -39,9 +353,23 @@ void CustomKernel::eval_gpu(
   }
 
   auto& d = metal::device(s.device);
-  const auto& lib_name = name_;
-  auto lib =
-      d.get_library(lib_name, [this] { return metal::utils() + source_; });
+  {
+    // Clear kernels from the device library cache if needed
+    auto& kernel_cache = cache();
+    if (auto it = kernel_cache.libraries.find(name_);
+        it != kernel_cache.libraries.end()) {
+      if (it->second != source_) {
+        auto& d = metal::device(s.device);
+        d.clear_library(name_);
+        it->second = source_;
+      }
+    } else {
+      kernel_cache.libraries.emplace(name_, source_);
+    }
+  }
+
+  auto lib = d.get_library(name_, [this] { return metal::utils() + source_; });
   auto kernel = d.get_kernel(name_, lib);
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
@@ -73,6 +401,16 @@ void CustomKernel::eval_gpu(
   }
 
   const auto [tx, ty, tz] = threadgroup_;
+  auto tg_size = tx * ty * tz;
+  auto max_tg_size = kernel->maxTotalThreadsPerThreadgroup();
+  if (tg_size > max_tg_size) {
+    std::ostringstream msg;
+    msg << "Thread group size (" << tg_size << ") is greater than "
+        << " the maximum allowed threads per threadgroup (" << max_tg_size
+        << ").";
+    throw std::invalid_argument(msg.str());
+  }
+
   const auto [gx, gy, gz] = grid_;
   MTL::Size group_dims =
       MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
@@ -295,7 +295,7 @@ void CommandEncoder::barrier() {
 Device::Device() {
   auto pool = new_scoped_memory_pool();
   device_ = load_device();
-  library_map_ = {{"mlx", load_default_library(device_)}};
+  default_library_ = load_default_library(device_);
   arch_ = std::string(device_->architecture()->name()->utf8String());
   auto arch = arch_.back();
   switch (arch) {
@@ -326,11 +326,11 @@ Device::Device() {
 
 Device::~Device() {
   auto pool = new_scoped_memory_pool();
-  for (auto& k : kernel_map_) {
-    k.second->release();
-  }
-  for (auto& l : library_map_) {
-    l.second->release();
+  for (auto& [l, kernel_map] : library_kernels_) {
+    l->release();
+    for (auto& [_, k] : kernel_map) {
+      k->release();
+    }
   }
   stream_map_.clear();
   device_->release();
@@ -474,13 +474,24 @@ CommandEncoder& Device::get_command_encoder(int index) {
   return *stream.encoder;
 }
 
-void Device::register_library(
-    const std::string& lib_name,
-    const std::string& lib_path) {
-  if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
-    auto new_lib = load_library(device_, lib_name, lib_path.c_str());
-    library_map_.insert({lib_name, new_lib});
+MTL::Library* Device::get_library(
+    const std::string& name,
+    const std::string& path /* = "" */) {
+  {
+    std::shared_lock rlock(library_mtx_);
+    if (auto it = library_map_.find(name); it != library_map_.end()) {
+      return it->second;
+    }
   }
+
+  std::unique_lock wlock(library_mtx_);
+  if (auto it = library_map_.find(name); it != library_map_.end()) {
+    return it->second;
+  }
+
+  auto new_lib = load_library(device_, name, path.c_str());
+  library_map_.insert({name, new_lib});
+  return new_lib;
 }
 
 MTL::Library* Device::build_library_(const std::string& source_string) {
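Aside (not part of the diff): the get_library rewrite above follows a read-mostly locking discipline — a shared lock for the fast path, then a re-check under an exclusive lock before inserting. A generic sketch of the same idea, with illustrative names:

#include <shared_mutex>
#include <string>
#include <unordered_map>

// Illustrative cache with the same locking pattern as Device::get_library.
class LibraryCache {
 public:
  int* get_or_create(const std::string& name) {
    {
      // Fast path: many readers may hold the shared lock concurrently.
      std::shared_lock rlock(mtx_);
      if (auto it = map_.find(name); it != map_.end()) {
        return &it->second;
      }
    }
    // Slow path: exclusive lock, then re-check in case another thread
    // inserted the entry between releasing and re-acquiring the lock.
    std::unique_lock wlock(mtx_);
    if (auto it = map_.find(name); it != map_.end()) {
      return &it->second;
    }
    return &map_.emplace(name, 0).first->second;
  }

 private:
  std::shared_mutex mtx_;
  std::unordered_map<std::string, int> map_;
};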
@@ -649,6 +660,19 @@ MTL::Library* Device::get_library(
   return mtl_lib;
 }
 
+void Device::clear_library(const std::string& name) {
+  std::unique_lock wlock(library_mtx_);
+  if (auto it = library_map_.find(name); it != library_map_.end()) {
+    auto kernel_map_it = library_kernels_.find(it->second);
+    for (auto& [_, kernel] : kernel_map_it->second) {
+      kernel->release();
+    }
+    library_kernels_.erase(kernel_map_it);
+    it->second->release();
+    library_map_.erase(it);
+  }
+}
+
 MTL::LinkedFunctions* Device::get_linked_functions_(
     const std::vector<MTL::Function*>& funcs) {
   if (funcs.empty()) {
@@ -679,6 +703,7 @@ MTL::ComputePipelineState* Device::get_kernel_(
   std::unique_lock wlock(kernel_mtx_);
 
   // Try loading again to avoid loading twice
+  auto& kernel_map_ = library_kernels_[mtl_lib];
   if (auto it = kernel_map_.find(hash_name); it != kernel_map_.end()) {
     return it->second;
   }
@@ -713,6 +738,7 @@ MTL::ComputePipelineState* Device::get_kernel(
   std::shared_lock lock(kernel_mtx_);
 
   // Look for cached kernel
+  auto& kernel_map_ = library_kernels_[mtl_lib];
   if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
     return it->second;
   }
@@ -722,23 +748,11 @@ MTL::ComputePipelineState* Device::get_kernel(
 
 MTL::ComputePipelineState* Device::get_kernel(
     const std::string& base_name,
-    const std::string& lib_name /* = "mlx" */,
     const std::string& hash_name /* = "" */,
     const MTLFCList& func_consts /* = {} */,
     const std::vector<MTL::Function*>& linked_functions /* = {} */) {
-  const auto& kname = hash_name.size() == 0 ? base_name : hash_name;
-  {
-    // Multiple readers allowed
-    std::shared_lock lock(kernel_mtx_);
-
-    // Look for cached kernel
-    if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
-      return it->second;
-    }
-  }
-  // Search for cached metal lib
-  MTL::Library* mtl_lib = get_library_(lib_name);
-  return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
+  return get_kernel(
+      base_name, default_library_, hash_name, func_consts, linked_functions);
 }
 
 void Device::set_residency_set(const MTL::ResidencySet* residency_set) {
@@ -95,6 +95,10 @@ struct CommandEncoder {
     return enc_->setBytes(&v, sizeof(T), idx);
   }
 
+  void set_threadgroup_memory_length(size_t length, int idx) {
+    enc_->setThreadgroupMemoryLength(length, idx);
+  }
+
   ConcurrentContext start_concurrent() {
     return ConcurrentContext(*this);
   }
@@ -183,14 +187,16 @@ class Device {
   CommandEncoder& get_command_encoder(int index);
   void end_encoding(int index);
 
-  void register_library(
-      const std::string& lib_name,
-      const std::string& lib_path = "");
+  MTL::Library* get_library(
+      const std::string& name,
+      const std::string& path = "");
 
   MTL::Library* get_library(
       const std::string& name,
       const std::function<std::string(void)>& builder);
 
+  void clear_library(const std::string& name);
+
   MTL::ComputePipelineState* get_kernel(
       const std::string& base_name,
       MTL::Library* mtl_lib,
@@ -200,7 +206,6 @@ class Device {
 
   MTL::ComputePipelineState* get_kernel(
       const std::string& base_name,
-      const std::string& lib_name = "mlx",
       const std::string& hash_name = "",
       const MTLFCList& func_consts = {},
       const std::vector<MTL::Function*>& linked_functions = {});
@@ -254,10 +259,13 @@ class Device {
   std::unordered_map<int32_t, DeviceStream> stream_map_;
 
   std::shared_mutex kernel_mtx_;
-  std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
-
   std::shared_mutex library_mtx_;
   std::unordered_map<std::string, MTL::Library*> library_map_;
+  MTL::Library* default_library_;
+  std::unordered_map<
+      MTL::Library*,
+      std::unordered_map<std::string, MTL::ComputePipelineState*>>
+      library_kernels_;
   const MTL::ResidencySet* residency_set_{nullptr};
   std::string arch_;
   int max_ops_per_buffer_;
@@ -632,7 +632,7 @@ void fft_op(
   func_consts.push_back(make_int(&rader_m, 3));
 
   // The overall number of FFTs we're going to compute for this input
-  int size = out.dtype() == float32 ? out.size() : in.size();
+  size_t size = out.dtype() == float32 ? out.size() : in.size();
   if (real && inverse && four_step_params.required) {
     size = out.size();
   }
@@ -659,8 +659,6 @@ void fft_op(
     // We can perform 2 RFFTs at once so the batch size is halved.
     batch_size = (batch_size + 2 - 1) / 2;
   }
-  int out_buffer_size = out.size();
-
   auto& compute_encoder = d.get_command_encoder(s.index);
   auto in_type_str = in.dtype() == float32 ? "float" : "float2";
   auto out_type_str = out.dtype() == float32 ? "float" : "float2";
@@ -2,6 +2,7 @@
 #include <fmt/format.h>
 
 #include "mlx/backend/common/compiled.h"
+#include "mlx/backend/common/utils.h"
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/jit/includes.h"
@@ -458,17 +459,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_output_array(out, 2);
 
   // Set source info
-  auto shape = idx.shape();
-  shape.erase(shape.begin() + axis_);
-  compute_encoder.set_vector_bytes(shape, 3);
-
-  auto strides = src.strides();
-  strides.erase(strides.begin() + axis_);
-  compute_encoder.set_vector_bytes(strides, 4);
-
-  strides = idx.strides();
-  strides.erase(strides.begin() + axis_);
-  compute_encoder.set_vector_bytes(strides, 5);
+  compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);
+  compute_encoder.set_vector_bytes(remove_index(src.strides(), axis_), 4);
+  compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
   compute_encoder.set_bytes(ndim - 1, 6);
   compute_encoder.set_bytes(axis_, 7);
   compute_encoder.set_bytes(src.shape(axis_), 8);
@@ -582,17 +575,9 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_output_array(out, 2);
 
   // Set source info
-  auto shape = idx.shape();
-  shape.erase(shape.begin() + axis_);
-  compute_encoder.set_vector_bytes(shape, 3);
-
-  auto strides = upd.strides();
-  strides.erase(strides.begin() + axis_);
-  compute_encoder.set_vector_bytes(strides, 4);
-
-  strides = idx.strides();
-  strides.erase(strides.begin() + axis_);
-  compute_encoder.set_vector_bytes(strides, 5);
+  compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);
+  compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4);
+  compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
   compute_encoder.set_bytes(ndim - 1, 6);
   compute_encoder.set_bytes(axis_, 7);
   compute_encoder.set_bytes(out.shape(axis_), 8);
@@ -41,7 +41,11 @@ MTL::ComputePipelineState* get_unary_kernel(
   std::string kernel_source = metal::utils();
   concatenate(kernel_source, metal::unary_ops(), metal::unary());
   kernel_source +=
-      get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op);
+      get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op, 1);
+  if (get_work_per_thread(in_type) > 1) {
+    kernel_source +=
+        get_template_definition("vn_" + lib_name, "unary_v", in_t, out_t, op);
+  }
   kernel_source +=
       get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op);
   kernel_source += get_template_definition(
@@ -59,11 +63,8 @@ void append_binary_kernels(
     Dtype out_type,
     const std::string op,
     std::string& kernel_source) {
-  const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{
+  const std::array<std::pair<std::string, std::string>, 7> kernel_types = {{
       {"ss", "binary_ss"},
-      {"vs", "binary_vs"},
-      {"sv", "binary_sv"},
-      {"vv", "binary_vv"},
       {"vs2", "binary_vs2"},
       {"sv2", "binary_sv2"},
       {"vv2", "binary_vv2"},
@@ -78,6 +79,22 @@ void append_binary_kernels(
     kernel_source +=
         get_template_definition(name + "_" + lib_name, func, in_t, out_t, op);
   }
+  kernel_source += get_template_definition(
+      "vs_" + lib_name, "binary_vs", in_t, out_t, op, 1);
+  kernel_source += get_template_definition(
+      "sv_" + lib_name, "binary_sv", in_t, out_t, op, 1);
+  kernel_source += get_template_definition(
+      "vv_" + lib_name, "binary_vv", in_t, out_t, op, 1);
+
+  if (get_work_per_thread(in_type) > 1) {
+    kernel_source += get_template_definition(
+        "vsn_" + lib_name, "binary_vs", in_t, out_t, op);
+    kernel_source += get_template_definition(
+        "svn_" + lib_name, "binary_sv", in_t, out_t, op);
+    kernel_source += get_template_definition(
+        "vvn_" + lib_name, "binary_vv", in_t, out_t, op);
+  }
+
   kernel_source += get_template_definition(
       "g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int");
   kernel_source += get_template_definition(
@@ -133,8 +150,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
   auto t_str = get_type_string(type);
   std::string kernel_source = metal::utils();
   concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
-  const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
-      {"v", "ternary_v"},
+  const std::array<std::pair<std::string, std::string>, 4> kernel_types = {{
       {"v2", "ternary_v2"},
       {"g1large", "ternary_g_nd1"},
       {"g2large", "ternary_g_nd2"},
@@ -144,6 +160,13 @@ MTL::ComputePipelineState* get_ternary_kernel(
     kernel_source +=
         get_template_definition(name + "_" + lib_name, func, t_str, op);
   }
+  if (get_work_per_thread(type) > 1) {
+    kernel_source +=
+        get_template_definition("vn_" + lib_name, "ternary_v", t_str, op);
+  }
+
+  kernel_source +=
+      get_template_definition("v_" + lib_name, "ternary_v", t_str, op, 1);
   kernel_source += get_template_definition(
       "g1_" + lib_name, "ternary_g_nd1", t_str, op, "int");
   kernel_source += get_template_definition(
@@ -170,15 +193,22 @@ MTL::ComputePipelineState* get_copy_kernel(
   kernel_source += metal::copy();
   auto in_type = get_type_string(in.dtype());
   auto out_type = get_type_string(out.dtype());
-  kernel_source +=
-      get_template_definition("s_" + lib_name, "copy_s", in_type, out_type);
+  kernel_source += get_template_definition(
+      "s_" + lib_name, "copy_s", in_type, out_type, 1);
   kernel_source +=
       get_template_definition("s2_" + lib_name, "copy_s2", in_type, out_type);
-  kernel_source +=
-      get_template_definition("v_" + lib_name, "copy_v", in_type, out_type);
+  kernel_source += get_template_definition(
+      "v_" + lib_name, "copy_v", in_type, out_type, 1);
   kernel_source +=
       get_template_definition("v2_" + lib_name, "copy_v2", in_type, out_type);
+
+  if (get_work_per_thread(out.dtype()) > 1) {
+    kernel_source += get_template_definition(
+        "sn_" + lib_name, "copy_s", in_type, out_type);
+    kernel_source += get_template_definition(
+        "vn_" + lib_name, "copy_v", in_type, out_type);
+  }
+
   kernel_source += get_template_definition(
       "g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int");
   kernel_source += get_template_definition(
@@ -697,6 +727,8 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
 MTL::ComputePipelineState* get_steel_conv_general_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string& hash_name,
+    const metal::MTLFCList& func_consts,
     const array& out,
     int bm,
     int bn,
@@ -719,7 +751,7 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
         wn);
     return kernel_source.str();
   });
-  return d.get_kernel(kernel_name, lib);
+  return d.get_kernel(kernel_name, lib, hash_name, func_consts);
 }
 
 MTL::ComputePipelineState* get_fft_kernel(
@@ -205,6 +205,8 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
 MTL::ComputePipelineState* get_steel_conv_general_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string& hash_name,
+    const metal::MTLFCList& func_consts,
     const array& out,
     int bm,
     int bn,
@@ -80,9 +80,10 @@ template <typename T, typename Op, int N_READS = 4>
     const constant size_t& ndim [[buffer(5)]],
     const constant int64_t& axis_stride [[buffer(6)]],
     const constant size_t& axis_size [[buffer(7)]],
-    uint gid [[thread_position_in_grid]],
-    uint lid [[thread_position_in_threadgroup]],
-    uint lsize [[threads_per_threadgroup]],
+    uint3 gid [[thread_position_in_grid]],
+    uint3 gsize [[threads_per_grid]],
+    uint3 lid [[thread_position_in_threadgroup]],
+    uint3 lsize [[threads_per_threadgroup]],
     uint simd_size [[threads_per_simdgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
@@ -104,17 +105,18 @@ template <typename T, typename Op, int N_READS = 4>
 
   // Compute the input/output index. There is one beginning and one output for
   // the whole threadgroup.
-  auto in_idx = elem_to_loc(gid / lsize, shape, in_strides, ndim);
-  auto out_idx = elem_to_loc(gid / lsize, shape, out_strides, ndim);
+  int64_t row_idx = gid.y + static_cast<int64_t>(gsize.y) * gid.z;
+  auto in_idx = elem_to_loc(row_idx, shape, in_strides, ndim);
+  auto out_idx = elem_to_loc(row_idx, shape, out_strides, ndim);
 
   IndexValPair<T> best{0, Op::init};
 
   threadgroup IndexValPair<T> local_data[32];
 
   // Loop over the reduction axis in lsize*N_READS buckets
-  for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
+  for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize.x); r++) {
     // Read the current value
-    uint32_t current_index = r * lsize * N_READS + lid * N_READS;
+    uint32_t current_index = r * lsize.x * N_READS + lid.x * N_READS;
     uint32_t offset = current_index;
     const device T* current_in = in + in_idx + current_index * axis_stride;
     T vals[N_READS];
@@ -144,7 +146,7 @@ template <typename T, typename Op, int N_READS = 4>
   }
 
   // Read the appropriate value from local data and perform one simd reduction
-  uint simd_groups = ceildiv(lsize, simd_size);
+  uint simd_groups = ceildiv(lsize.x, simd_size);
   if (simd_lane_id < simd_groups) {
     best = local_data[simd_lane_id];
   }
@@ -154,7 +156,7 @@ template <typename T, typename Op, int N_READS = 4>
   }
 
   // Finally write the output
-  if (lid == 0) {
+  if (lid.x == 0) {
     out[out_idx] = best.index;
   }
 }
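Aside (not part of the diff): the arg-reduce change above moves from a flat 1-D thread index to a 3-D grid, so the row an entire threadgroup works on is computed in 64 bits and cannot overflow for very large inputs. A minimal sketch of the same index math, with the kernel's names reused for clarity:

#include <cstdint>

// Illustrative: one threadgroup handles one output row. gid.y enumerates rows
// within a slab and gid.z counts slabs of gsize.y rows; widening to int64_t
// before the multiply avoids 32-bit overflow.
int64_t row_index(uint32_t gid_y, uint32_t gid_z, uint32_t gsize_y) {
  return static_cast<int64_t>(gsize_y) * gid_z + gid_y;
}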
@@ -17,8 +17,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
     constant uint& size,
     uint index [[thread_position_in_grid]]) {
   index *= N;
-  for (int i = 0; i < N && (index + i) < size; ++i) {
-    c[index + i] = Op()(a[0], b[index + i]);
+  if (N > 1 && index + N > size) {
+    for (int i = 0; index + i < size; ++i) {
+      c[index + i] = Op()(a[0], b[index + i]);
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      c[index + i] = Op()(a[0], b[index + i]);
+    }
   }
 }
 
@@ -30,8 +36,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
     constant uint& size,
     uint index [[thread_position_in_grid]]) {
   index *= N;
-  for (int i = 0; i < N && (index + i) < size; ++i) {
-    c[index + i] = Op()(a[index + i], b[0]);
+  if (N > 1 && index + N > size) {
+    for (int i = 0; index + i < size; ++i) {
+      c[index + i] = Op()(a[index + i], b[0]);
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      c[index + i] = Op()(a[index + i], b[0]);
+    }
   }
 }
 
@@ -43,8 +55,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
     constant uint& size,
     uint index [[thread_position_in_grid]]) {
   index *= N;
-  for (int i = 0; i < N && (index + i) < size; ++i) {
-    c[index + i] = Op()(a[index + i], b[index + i]);
+  if (N > 1 && index + N > size) {
+    for (int i = 0; index + i < size; ++i) {
+      c[index + i] = Op()(a[index + i], b[index + i]);
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      c[index + i] = Op()(a[index + i], b[index + i]);
+    }
   }
 }
 
@@ -57,8 +75,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
   int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
-  for (int i = 0; i < N && (offset + i) < size; ++i) {
-    c[offset + i] = Op()(a[0], b[offset + i]);
+  if (N > 1 && offset + N > size) {
+    for (int i = 0; offset + i < size; ++i) {
+      c[offset + i] = Op()(a[0], b[offset + i]);
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      c[offset + i] = Op()(a[0], b[offset + i]);
+    }
   }
 }
 
@@ -71,8 +95,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
   int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
-  for (int i = 0; i < N && (offset + i) < size; ++i) {
-    c[offset + i] = Op()(a[offset + i], b[0]);
+  if (N > 1 && offset + N > size) {
+    for (int i = 0; offset + i < size; ++i) {
+      c[offset + i] = Op()(a[offset + i], b[0]);
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      c[offset + i] = Op()(a[offset + i], b[0]);
+    }
   }
 }
 
@@ -85,8 +115,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
   int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
-  for (int i = 0; i < N && (offset + i) < size; ++i) {
-    c[offset + i] = Op()(a[offset + i], b[offset + i]);
+  if (N > 1 && offset + N > size) {
+    for (int i = 0; offset + i < size; ++i) {
+      c[offset + i] = Op()(a[offset + i], b[offset + i]);
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      c[offset + i] = Op()(a[offset + i], b[offset + i]);
+    }
   }
 }
 
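Aside (not part of the diff): the kernels above split the old bounds-checked loop into two cases. Only the last, partially filled block of N elements pays for per-element bounds checks; the common case runs a fixed-trip-count loop that the compiler can fully unroll. The same transformation in plain C++, purely as an illustration:

#include <cstdint>

// Before: every iteration evaluated (index + i) < size.
// After: only the ragged tail block does.
template <int N>
void scale_block(const float* a, float* c, uint32_t index, uint32_t size) {
  index *= N;
  if (N > 1 && index + N > size) {
    for (uint32_t i = 0; index + i < size; ++i) { // tail: bounds-checked
      c[index + i] = 2.0f * a[index + i];
    }
  } else {
    for (int i = 0; i < N; ++i) {                 // full block: unrollable
      c[index + i] = 2.0f * a[index + i];
    }
  }
}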
@@ -9,11 +9,16 @@
 #include "mlx/backend/metal/kernels/binary_ops.h"
 #include "mlx/backend/metal/kernels/binary.h"
 
-#define instantiate_binary_all(op, tname, itype, otype) \
+#define instantiate_binary_work_per_thread(op, tname, itype, otype) \
+  instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \
+  instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \
+  instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op) \
+
+#define instantiate_binary_base(op, tname, itype, otype) \
   instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
-  instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
-  instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
-  instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
+  instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1) \
+  instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1) \
+  instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1) \
   instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
   instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
   instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
@@ -26,15 +31,19 @@
   instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
   instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
 
-#define instantiate_binary_integer(op) \
-  instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
-  instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
-  instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
-  instantiate_binary_all(op, uint64, uint64_t, uint64_t) \
-  instantiate_binary_all(op, int8, int8_t, int8_t) \
-  instantiate_binary_all(op, int16, int16_t, int16_t) \
-  instantiate_binary_all(op, int32, int32_t, int32_t) \
-  instantiate_binary_all(op, int64, int64_t, int64_t)
+#define instantiate_binary_all(op, tname, itype, otype) \
+  instantiate_binary_base(op, tname, itype, otype) \
+  instantiate_binary_work_per_thread(op, tname, itype, otype)
+
+#define instantiate_binary_integer(op) \
+  instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
+  instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
+  instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
+  instantiate_binary_base(op, uint64, uint64_t, uint64_t) \
+  instantiate_binary_all(op, int8, int8_t, int8_t) \
+  instantiate_binary_all(op, int16, int16_t, int16_t) \
+  instantiate_binary_all(op, int32, int32_t, int32_t) \
+  instantiate_binary_base(op, int64, int64_t, int64_t)
 
 #define instantiate_binary_float(op) \
   instantiate_binary_all(op, float16, half, half) \
@@ -44,7 +53,7 @@
 #define instantiate_binary_types(op) \
   instantiate_binary_all(op, bool_, bool, bool) \
   instantiate_binary_integer(op) \
-  instantiate_binary_all(op, complex64, complex64_t, complex64_t) \
+  instantiate_binary_base(op, complex64, complex64_t, complex64_t)\
   instantiate_binary_float(op)
 
 #define instantiate_binary_types_bool(op) \
@@ -52,15 +61,15 @@
   instantiate_binary_all(op, uint8, uint8_t, bool) \
   instantiate_binary_all(op, uint16, uint16_t, bool) \
   instantiate_binary_all(op, uint32, uint32_t, bool) \
-  instantiate_binary_all(op, uint64, uint64_t, bool) \
+  instantiate_binary_base(op, uint64, uint64_t, bool) \
   instantiate_binary_all(op, int8, int8_t, bool) \
   instantiate_binary_all(op, int16, int16_t, bool) \
   instantiate_binary_all(op, int32, int32_t, bool) \
-  instantiate_binary_all(op, int64, int64_t, bool) \
+  instantiate_binary_base(op, int64, int64_t, bool) \
   instantiate_binary_all(op, float16, half, bool) \
   instantiate_binary_all(op, float32, float, bool) \
   instantiate_binary_all(op, bfloat16, bfloat16_t, bool) \
-  instantiate_binary_all(op, complex64, complex64_t, bool)
+  instantiate_binary_base(op, complex64, complex64_t, bool)
 
 instantiate_binary_types(Add)
 instantiate_binary_types(Divide)
@@ -71,7 +80,7 @@ instantiate_binary_types_bool(Less)
 instantiate_binary_types_bool(LessEqual)
 instantiate_binary_types_bool(NotEqual)
 instantiate_binary_float(LogAddExp)
-instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t)
+instantiate_binary_base(LogAddExp, complex64, complex64_t, complex64_t)
 instantiate_binary_types(Maximum)
 instantiate_binary_types(Minimum)
 instantiate_binary_types(Multiply)
@@ -84,7 +93,7 @@ instantiate_binary_float(ArcTan2)
 instantiate_binary_all(NaNEqual, float16, half, bool)
 instantiate_binary_all(NaNEqual, float32, float, bool)
 instantiate_binary_all(NaNEqual, bfloat16, bfloat16_t, bool)
-instantiate_binary_all(NaNEqual, complex64, complex64_t, bool)
+instantiate_binary_base(NaNEqual, complex64, complex64_t, bool)
 
 instantiate_binary_all(LogicalOr, bool_, bool, bool)
 instantiate_binary_all(LogicalAnd, bool_, bool, bool)
@@ -21,10 +21,18 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
     constant uint& size,
     uint index [[thread_position_in_grid]]) {
   index *= N;
-  for (int i = 0; i < N && (index + i) < size; ++i) {
-    auto out = Op()(a[0], b[index + i]);
-    c[index + i] = out[0];
-    d[index + i] = out[1];
+  if (N > 1 && index + N > size) {
+    for (int i = 0; index + i < size; ++i) {
+      auto out = Op()(a[0], b[index + i]);
+      c[index + i] = out[0];
+      d[index + i] = out[1];
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      auto out = Op()(a[0], b[index + i]);
+      c[index + i] = out[0];
+      d[index + i] = out[1];
+    }
   }
 }
 
@@ -37,10 +45,18 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
     constant uint& size,
     uint index [[thread_position_in_grid]]) {
   index *= N;
-  for (int i = 0; i < N && (index + i) < size; ++i) {
-    auto out = Op()(a[index + i], b[0]);
-    c[index + i] = out[0];
-    d[index + i] = out[1];
+  if (N > 1 && index + N > size) {
+    for (int i = 0; index + i < size; ++i) {
+      auto out = Op()(a[index + i], b[0]);
+      c[index + i] = out[0];
+      d[index + i] = out[1];
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      auto out = Op()(a[index + i], b[0]);
+      c[index + i] = out[0];
+      d[index + i] = out[1];
+    }
   }
 }
 
@@ -53,10 +69,18 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
     constant uint& size,
     uint index [[thread_position_in_grid]]) {
   index *= N;
-  for (int i = 0; i < N && (index + i) < size; ++i) {
-    auto out = Op()(a[index + i], b[index + i]);
-    c[index + i] = out[0];
-    d[index + i] = out[1];
+  if (N > 1 && index + N > size) {
+    for (int i = 0; index + i < size; ++i) {
+      auto out = Op()(a[index + i], b[index + i]);
+      c[index + i] = out[0];
+      d[index + i] = out[1];
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      auto out = Op()(a[index + i], b[index + i]);
+      c[index + i] = out[0];
+      d[index + i] = out[1];
+    }
   }
 }
 
@@ -69,11 +93,19 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
     constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
-  for (int i = 0; i < N && (offset + i) < size; ++i) {
-    auto out = Op()(a[0], b[offset + i]);
-    c[offset + i] = out[0];
-    d[offset + i] = out[1];
+  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  if (N > 1 && offset + N > size) {
+    for (int i = 0; offset + i < size; ++i) {
+      auto out = Op()(a[0], b[offset + i]);
+      c[offset + i] = out[0];
+      d[offset + i] = out[1];
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      auto out = Op()(a[0], b[offset + i]);
+      c[offset + i] = out[0];
+      d[offset + i] = out[1];
+    }
   }
 }
 
@@ -86,11 +118,19 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
     constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
-  for (int i = 0; i < N && (offset + i) < size; ++i) {
-    auto out = Op()(a[offset + i], b[0]);
-    c[offset + i] = out[0];
-    d[offset + i] = out[1];
+  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  if (N > 1 && offset + N > size) {
+    for (int i = 0; offset + i < size; ++i) {
+      auto out = Op()(a[offset + i], b[0]);
+      c[offset + i] = out[0];
+      d[offset + i] = out[1];
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      auto out = Op()(a[offset + i], b[0]);
+      c[offset + i] = out[0];
+      d[offset + i] = out[1];
+    }
   }
 }
 
@@ -103,11 +143,19 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
     constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
-  for (int i = 0; i < N && (offset + i) < size; ++i) {
-    auto out = Op()(a[offset + i], b[offset + i]);
-    c[offset + i] = out[0];
-    d[offset + i] = out[1];
+  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  if (N > 1 && offset + N > size) {
+    for (int i = 0; offset + i < size; ++i) {
+      auto out = Op()(a[offset + i], b[offset + i]);
+      c[offset + i] = out[0];
+      d[offset + i] = out[1];
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      auto out = Op()(a[offset + i], b[offset + i]);
+      c[offset + i] = out[0];
+      d[offset + i] = out[1];
+    }
   }
 }
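Note: the binary_sv/vs/vv rewrites above split each thread's work into two paths: a bounds-checked tail loop that runs only when the thread's N-element chunk crosses the end of the buffer, and an unchecked fixed-count loop for full chunks. The plain C++ sketch below is illustrative only (not part of the diff; the function name and setup are made up) and shows the same shape on a single array:

#include <cstddef>
#include <vector>

// Illustrative sketch: same tail-handling idea as the kernels above, on the CPU.
template <int N>
void scale_chunked(const std::vector<float>& a, std::vector<float>& out, float s) {
  std::size_t size = a.size();
  for (std::size_t index = 0; index < size; index += N) {
    if (index + N > size) {
      // Partial chunk at the end: per-element bounds check.
      for (std::size_t i = 0; index + i < size; ++i) {
        out[index + i] = s * a[index + i];
      }
    } else {
      // Full chunk: fixed trip count, easy for the compiler to unroll.
      for (int i = 0; i < N; ++i) {
        out[index + i] = s * a[index + i];
      }
    }
  }
}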
@@ -7,11 +7,16 @@
 #include "mlx/backend/metal/kernels/binary_ops.h"
 #include "mlx/backend/metal/kernels/binary_two.h"
 
-#define instantiate_binary_all(op, tname, itype, otype) \
+#define instantiate_binary_work_per_thread(op, tname, itype, otype) \
+  instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \
+  instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \
+  instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op)
+
+#define instantiate_binary_base(op, tname, itype, otype) \
   instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
-  instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
-  instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
-  instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
+  instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1) \
+  instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1) \
+  instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1) \
   instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
   instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
   instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
@@ -24,22 +29,26 @@
   instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
   instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
 
+#define instantiate_binary_all(op, tname, itype, otype) \
+  instantiate_binary_base(op, tname, itype, otype) \
+  instantiate_binary_work_per_thread(op, tname, itype, otype)
+
 #define instantiate_binary_float(op) \
   instantiate_binary_all(op, float16, half, half) \
   instantiate_binary_all(op, float32, float, float) \
   instantiate_binary_all(op, bfloat16, bfloat16_t, bfloat16_t)
 
 #define instantiate_binary_types(op) \
   instantiate_binary_all(op, bool_, bool, bool) \
   instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
   instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
   instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
-  instantiate_binary_all(op, uint64, uint64_t, uint64_t) \
+  instantiate_binary_base(op, uint64, uint64_t, uint64_t) \
   instantiate_binary_all(op, int8, int8_t, int8_t) \
   instantiate_binary_all(op, int16, int16_t, int16_t) \
   instantiate_binary_all(op, int32, int32_t, int32_t) \
-  instantiate_binary_all(op, int64, int64_t, int64_t) \
-  instantiate_binary_all(op, complex64, complex64_t, complex64_t) \
+  instantiate_binary_base(op, int64, int64_t, int64_t) \
+  instantiate_binary_base(op, complex64, complex64_t, complex64_t) \
   instantiate_binary_float(op)
 
 instantiate_binary_types(DivMod) // clang-format on
@@ -1,52 +1,76 @@
 // Copyright © 2024 Apple Inc.
 
-template <typename T, typename U, int N = WorkPerThread<T>::n>
+template <typename T, typename U, int N = WorkPerThread<U>::n>
 [[kernel]] void copy_s(
     device const T* src [[buffer(0)]],
     device U* dst [[buffer(1)]],
     constant uint& size,
     uint index [[thread_position_in_grid]]) {
   index *= N;
-  for (int i = 0; i < N && (index + i) < size; ++i) {
-    dst[index + i] = static_cast<U>(src[0]);
+  if (N > 1 && index + N > size) {
+    for (int i = 0; index + i < size; ++i) {
+      dst[index + i] = static_cast<U>(src[0]);
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      dst[index + i] = static_cast<U>(src[0]);
+    }
   }
 }
 
-template <typename T, typename U, int N = WorkPerThread<T>::n>
+template <typename T, typename U, int N = WorkPerThread<U>::n>
 [[kernel]] void copy_v(
     device const T* src [[buffer(0)]],
     device U* dst [[buffer(1)]],
     constant uint& size,
     uint index [[thread_position_in_grid]]) {
   index *= N;
-  for (int i = 0; i < N && (index + i) < size; ++i) {
-    dst[index + i] = static_cast<U>(src[index + i]);
+  if (N > 1 && index + N > size) {
+    for (int i = 0; index + i < size; ++i) {
+      dst[index + i] = static_cast<U>(src[index + i]);
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      dst[index + i] = static_cast<U>(src[index + i]);
+    }
   }
 }
 
-template <typename T, typename U, int N = WorkPerThread<T>::n>
+template <typename T, typename U, int N = WorkPerThread<U>::n>
 [[kernel]] void copy_s2(
     device const T* src [[buffer(0)]],
     device U* dst [[buffer(1)]],
     constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
-  for (int i = 0; i < N && (offset + i) < size; ++i) {
-    dst[offset + i] = static_cast<U>(src[0]);
+  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  if (N > 1 && offset + N > size) {
+    for (int i = 0; offset + i < size; ++i) {
+      dst[offset + i] = static_cast<U>(src[0]);
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      dst[offset + i] = static_cast<U>(src[0]);
+    }
   }
 }
 
-template <typename T, typename U, int N = WorkPerThread<T>::n>
+template <typename T, typename U, int N = WorkPerThread<U>::n>
 [[kernel]] void copy_v2(
     device const T* src [[buffer(0)]],
     device U* dst [[buffer(1)]],
     constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
-  for (int i = 0; i < N && (offset + i) < size; ++i) {
-    dst[offset + i] = static_cast<U>(src[offset + i]);
+  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  if (N > 1 && offset + N > size) {
+    for (int i = 0; offset + i < size; ++i) {
+      dst[offset + i] = static_cast<U>(src[offset + i]);
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      dst[offset + i] = static_cast<U>(src[offset + i]);
+    }
   }
 }
@@ -4,9 +4,13 @@
 #include "mlx/backend/metal/kernels/utils.h"
 #include "mlx/backend/metal/kernels/copy.h"
 
-#define instantiate_copy_all(tname, itype, otype) \
-  instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
-  instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
+#define instantiate_copy_work_per_thread(tname, itype, otype) \
+  instantiate_kernel("sn_copy" #tname, copy_s, itype, otype) \
+  instantiate_kernel("vn_copy" #tname, copy_v, itype, otype)
+
+#define instantiate_copy_base(tname, itype, otype) \
+  instantiate_kernel("s_copy" #tname, copy_s, itype, otype, 1) \
+  instantiate_kernel("v_copy" #tname, copy_v, itype, otype, 1) \
   instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
   instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
   instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \
@@ -18,6 +22,10 @@
   instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
   instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4)
 
+#define instantiate_copy_all(tname, itype, otype) \
+  instantiate_copy_base(tname, itype, otype) \
+  instantiate_copy_work_per_thread(tname, itype, otype)
+
 #define instantiate_copy_same(tname, type) \
   instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int) \
   instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, type, type, int) \
@@ -42,15 +50,15 @@
   instantiate_copy_all(itname ##uint8, itype, uint8_t) \
   instantiate_copy_all(itname ##uint16, itype, uint16_t) \
   instantiate_copy_all(itname ##uint32, itype, uint32_t) \
-  instantiate_copy_all(itname ##uint64, itype, uint64_t) \
+  instantiate_copy_base(itname ##uint64, itype, uint64_t) \
   instantiate_copy_all(itname ##int8, itype, int8_t) \
   instantiate_copy_all(itname ##int16, itype, int16_t) \
   instantiate_copy_all(itname ##int32, itype, int32_t) \
-  instantiate_copy_all(itname ##int64, itype, int64_t) \
+  instantiate_copy_base(itname ##int64, itype, int64_t) \
   instantiate_copy_all(itname ##float16, itype, half) \
   instantiate_copy_all(itname ##float32, itype, float) \
   instantiate_copy_all(itname ##bfloat16, itype, bfloat16_t) \
-  instantiate_copy_all(itname ##complex64, itype, complex64_t)
+  instantiate_copy_base(itname ##complex64, itype, complex64_t)
 
 instantiate_copy_itype(bool_, bool)
 instantiate_copy_itype(uint8, uint8_t)
@@ -98,7 +98,7 @@ struct ReadWriter {
   }
 
   METAL_FUNC void load() const {
-    int batch_idx = elem.x * grid.y * n;
+    size_t batch_idx = size_t(elem.x * grid.y) * n;
     short tg_idx = elem.y * grid.z + elem.z;
     short max_index = grid.y * n - 2;
 
@@ -121,7 +121,7 @@ struct ReadWriter {
   }
 
   METAL_FUNC void write() const {
-    int batch_idx = elem.x * grid.y * n;
+    size_t batch_idx = size_t(elem.x * grid.y) * n;
     short tg_idx = elem.y * grid.z + elem.z;
     short max_index = grid.y * n - 2;
 
@@ -144,7 +144,7 @@ struct ReadWriter {
 
   // Padded IO for Bluestein's algorithm
   METAL_FUNC void load_padded(int length, const device float2* w_k) const {
-    int batch_idx = elem.x * grid.y * length + elem.y * length;
+    size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length;
     int fft_idx = elem.z;
     int m = grid.z;
 
@@ -161,7 +161,7 @@ struct ReadWriter {
   }
 
   METAL_FUNC void write_padded(int length, const device float2* w_k) const {
-    int batch_idx = elem.x * grid.y * length + elem.y * length;
+    size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length;
     int fft_idx = elem.z;
     int m = grid.z;
     float2 inv_factor = {1.0f / n, -1.0f / n};
@@ -261,7 +261,7 @@ METAL_FUNC bool ReadWriter<float, float2>::out_of_bounds() const {
 
 template <>
 METAL_FUNC void ReadWriter<float, float2>::load() const {
-  int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2;
+  size_t batch_idx = size_t(elem.x * grid.y) * n * 2 + elem.y * n * 2;
   threadgroup float2* seq_buf = buf + elem.y * n;
 
   // No out of bounds accesses on odd batch sizes
@@ -283,7 +283,8 @@ template <>
 METAL_FUNC void ReadWriter<float, float2>::write() const {
   short n_over_2 = (n / 2) + 1;
 
-  int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
+  size_t batch_idx =
+      size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2;
   threadgroup float2* seq_buf = buf + elem.y * n;
 
   int grid_index = elem.x * grid.y + elem.y;
@@ -317,7 +318,7 @@ template <>
 METAL_FUNC void ReadWriter<float, float2>::load_padded(
     int length,
     const device float2* w_k) const {
-  int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
+  size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2;
   threadgroup float2* seq_buf = buf + elem.y * n;
 
   // No out of bounds accesses on odd batch sizes
@@ -345,8 +346,8 @@ METAL_FUNC void ReadWriter<float, float2>::write_padded(
     int length,
     const device float2* w_k) const {
   int length_over_2 = (length / 2) + 1;
-  int batch_idx =
-      elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
+  size_t batch_idx =
+      size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2;
   threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
 
   int grid_index = elem.x * grid.y + elem.y;
@@ -397,7 +398,8 @@ METAL_FUNC bool ReadWriter<float2, float>::out_of_bounds() const {
 template <>
 METAL_FUNC void ReadWriter<float2, float>::load() const {
   short n_over_2 = (n / 2) + 1;
-  int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
+  size_t batch_idx =
+      size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2;
   threadgroup float2* seq_buf = buf + elem.y * n;
 
   // No out of bounds accesses on odd batch sizes
@@ -458,8 +460,8 @@ METAL_FUNC void ReadWriter<float2, float>::load_padded(
   int n_over_2 = (n / 2) + 1;
   int length_over_2 = (length / 2) + 1;
 
-  int batch_idx =
-      elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
+  size_t batch_idx =
+      size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2;
   threadgroup float2* seq_buf = buf + elem.y * n;
 
   // No out of bounds accesses on odd batch sizes
@@ -503,7 +505,7 @@ template <>
 METAL_FUNC void ReadWriter<float2, float>::write_padded(
     int length,
     const device float2* w_k) const {
-  int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
+  size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2;
   threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
 
   int grid_index = elem.x * grid.y + elem.y;
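Note: the ReadWriter hunks above only widen the batch-index arithmetic from int to size_t, casting before the multiply so the product is computed in 64 bits. A standalone example of why the cast has to come first (hypothetical sizes, for illustration only):

#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  uint32_t batches = 70000; // plays the role of elem.x * grid.y
  uint32_t n = 65536;       // FFT length
  uint32_t narrow = batches * n;     // 32-bit multiply wraps around
  size_t wide = size_t(batches) * n; // widen first, then multiply
  std::printf("narrow=%u wide=%zu\n", (unsigned)narrow, wide);
  return 0;
}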
@@ -9,7 +9,41 @@ using namespace metal;
 
 constant bool has_w [[function_constant(20)]];
 
-template <typename T, int N_READS = RMS_N_READS>
+template <int N = 1>
+inline void initialize_buffer(
+    threadgroup float* xs,
+    uint simd_lane_id [[thread_index_in_simdgroup]],
+    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
+  if (simd_group_id == 0) {
+    for (int i = 0; i < N; i++) {
+      xs[N * simd_lane_id + i] = 0;
+    }
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+}
+
+template <int N = 1>
+inline void threadgroup_sum(
+    thread float* x,
+    threadgroup float* xs,
+    uint simd_lane_id [[thread_index_in_simdgroup]],
+    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
+  for (int i = 0; i < N; i++) {
+    x[i] = simd_sum(x[i]);
+  }
+  if (simd_lane_id == 0) {
+    for (int i = 0; i < N; i++) {
+      xs[N * simd_group_id + i] = x[i];
+    }
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  for (int i = 0; i < N; i++) {
+    x[i] = xs[N * simd_lane_id + i];
+    x[i] = simd_sum(x[i]);
+  }
+}
+
+template <typename T, int N_READS = 8>
 [[kernel]] void layer_norm_single_row(
     const device T* x,
     const device T* w,
@@ -23,90 +57,71 @@ template <typename T, int N_READS = RMS_N_READS>
     uint lid [[thread_position_in_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
-  float sumx = 0;
-  float sumx2 = 0;
-  float thread_x[N_READS];
 
   constexpr int SIMD_SIZE = 32;
 
-  threadgroup float local_sumx[SIMD_SIZE];
-  threadgroup float local_sumx2[SIMD_SIZE];
-  threadgroup float local_mean[1];
-  threadgroup float local_normalizer[1];
+  // Initialize the registers and threadgroup memory
+  float thread_x[N_READS] = {0};
+  threadgroup float local_buffer[SIMD_SIZE] = {0};
+  initialize_buffer(local_buffer, simd_lane_id, simd_group_id);
 
+  // Advance the pointers
   x += gid * size_t(axis_size) + lid * N_READS;
   w += w_stride * lid * N_READS;
   b += b_stride * lid * N_READS;
+  out += gid * size_t(axis_size) + lid * N_READS;
 
-  if (lid * N_READS + N_READS <= axis_size) {
+  // Compute some variables for reading writing etc
+  const bool safe = lid * N_READS + N_READS <= axis_size;
+  const int n = axis_size - lid * N_READS;
+
+  // Read the inputs
+  if (safe) {
     for (int i = 0; i < N_READS; i++) {
       thread_x[i] = x[i];
-      sumx2 += thread_x[i] * thread_x[i];
-      sumx += thread_x[i];
     }
   } else {
-    for (int i = 0; i < N_READS; i++) {
-      if ((lid * N_READS + i) < axis_size) {
-        thread_x[i] = x[i];
-        sumx2 += thread_x[i] * thread_x[i];
-        sumx += thread_x[i];
-      }
+    for (int i = 0; i < n; i++) {
+      thread_x[i] = x[i];
     }
   }
 
-  sumx = simd_sum(sumx);
-  sumx2 = simd_sum(sumx2);
-
-  // Initialize shared memory
-  if (simd_group_id == 0) {
-    local_sumx[simd_lane_id] = 0;
-    local_sumx2[simd_lane_id] = 0;
+  // Compute the mean
+  float mean = 0;
+  for (int i = 0; i < N_READS; i++) {
+    mean += thread_x[i];
   }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
+  threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id);
+  mean /= axis_size;
 
-  // Write simd accumulations into shared memory
-  if (simd_lane_id == 0) {
-    local_sumx[simd_group_id] = sumx;
-    local_sumx2[simd_group_id] = sumx2;
-  }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  // Accumulate over simd groups
-  if (simd_group_id == 0) {
-    sumx = simd_sum(local_sumx[simd_lane_id]);
-    sumx2 = simd_sum(local_sumx2[simd_lane_id]);
-    if (simd_lane_id == 0) {
-      float mean = sumx / axis_size;
-      float variance = sumx2 / axis_size - mean * mean;
-
-      local_mean[0] = mean;
-      local_normalizer[0] = metal::precise::rsqrt(variance + eps);
+  // Compute the normalizer
+  float normalizer = 0;
+  if (!safe) {
+    for (int i = n; i < N_READS; i++) {
+      thread_x[i] = mean;
     }
   }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  float mean = local_mean[0];
-  float normalizer = local_normalizer[0];
+  for (int i = 0; i < N_READS; i++) {
+    thread_x[i] -= mean;
+    normalizer += thread_x[i] * thread_x[i];
+  }
+  threadgroup_sum(&normalizer, local_buffer, simd_lane_id, simd_group_id);
+  normalizer = metal::precise::rsqrt(normalizer / axis_size + eps);
 
   // Write the outputs
-  out += gid * size_t(axis_size) + lid * N_READS;
-  if (lid * N_READS + N_READS <= axis_size) {
+  if (safe) {
     for (int i = 0; i < N_READS; i++) {
-      thread_x[i] = (thread_x[i] - mean) * normalizer;
+      thread_x[i] *= normalizer;
       out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
     }
   } else {
-    for (int i = 0; i < N_READS; i++) {
-      if ((lid * N_READS + i) < axis_size) {
-        thread_x[i] = (thread_x[i] - mean) * normalizer;
-        out[i] =
-            w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
-      }
+    for (int i = 0; i < n; i++) {
+      thread_x[i] *= normalizer;
+      out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
     }
   }
 }
 
-template <typename T, int N_READS = RMS_N_READS>
+template <typename T, int N_READS = 4>
 [[kernel]] void layer_norm_looped(
     const device T* x,
     const device T* w,
@@ -121,71 +136,52 @@ template <typename T, int N_READS = RMS_N_READS>
     uint lsize [[threads_per_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
-  float sumx = 0;
-  float sumx2 = 0;
-
   constexpr int SIMD_SIZE = 32;
 
-  threadgroup float local_sumx[SIMD_SIZE];
-  threadgroup float local_sumx2[SIMD_SIZE];
-  threadgroup float local_mean[1];
-  threadgroup float local_normalizer[1];
+  threadgroup float local_buffer[SIMD_SIZE];
+  initialize_buffer(local_buffer, simd_lane_id, simd_group_id);
 
   x += gid * size_t(axis_size) + lid * N_READS;
   w += w_stride * lid * N_READS;
   b += b_stride * lid * N_READS;
 
+  // Compute the mean
+  float mean = 0;
   for (uint r = 0; r < axis_size; r += lsize * N_READS) {
     if (r + lid * N_READS + N_READS <= axis_size) {
       for (int i = 0; i < N_READS; i++) {
-        float xi = x[i + r];
-        sumx2 += xi * xi;
-        sumx += xi;
+        mean += x[i + r];
       }
     } else {
       for (int i = 0; i < N_READS; i++) {
         if ((r + lid * N_READS + i) < axis_size) {
-          float xi = x[i + r];
-          sumx2 += xi * xi;
-          sumx += xi;
+          mean += x[i + r];
         }
       }
     }
   }
+  threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id);
+  mean /= axis_size;
 
-  sumx = simd_sum(sumx);
-  sumx2 = simd_sum(sumx2);
-
-  // Initialize shared memory
-  if (simd_group_id == 0) {
-    local_sumx[simd_lane_id] = 0;
-    local_sumx2[simd_lane_id] = 0;
-  }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  // Write simd accumulations into shared memory
-  if (simd_lane_id == 0) {
-    local_sumx[simd_group_id] = sumx;
-    local_sumx2[simd_group_id] = sumx2;
-  }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  // Accumulate over simd groups
-  if (simd_group_id == 0) {
-    sumx = simd_sum(local_sumx[simd_lane_id]);
-    sumx2 = simd_sum(local_sumx2[simd_lane_id]);
-    if (simd_lane_id == 0) {
-      float mean = sumx / axis_size;
-      float variance = sumx2 / axis_size - mean * mean;
-
-      local_mean[0] = mean;
-      local_normalizer[0] = metal::precise::rsqrt(variance + eps);
+  // Compute the normalizer
+  float normalizer = 0;
+  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
+    if (r + lid * N_READS + N_READS <= axis_size) {
+      for (int i = 0; i < N_READS; i++) {
+        float t = x[i + r] - mean;
+        normalizer += t * t;
+      }
+    } else {
+      for (int i = 0; i < N_READS; i++) {
+        if ((r + lid * N_READS + i) < axis_size) {
+          float t = x[i + r] - mean;
+          normalizer += t * t;
+        }
+      }
     }
   }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  float mean = local_mean[0];
-  float normalizer = local_normalizer[0];
+  threadgroup_sum(&normalizer, local_buffer, simd_lane_id, simd_group_id);
+  normalizer = metal::precise::rsqrt(normalizer / axis_size + eps);
 
   // Write the outputs
   out += gid * size_t(axis_size) + lid * N_READS;
@@ -208,7 +204,7 @@ template <typename T, int N_READS = RMS_N_READS>
   }
 }
 
-template <typename T, int N_READS = RMS_N_READS>
+template <typename T, int N_READS = 8>
 [[kernel]] void vjp_layer_norm_single_row(
     const device T* x,
     const device T* w,
@@ -222,133 +218,96 @@ template <typename T, int N_READS = RMS_N_READS>
     uint lid [[thread_position_in_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
+  constexpr int SIMD_SIZE = 32;
+
   // Advance the input pointers
   x += gid * size_t(axis_size) + lid * N_READS;
   g += gid * size_t(axis_size) + lid * N_READS;
   w += w_stride * lid * N_READS;
 
-  // Allocate registers for the computation and accumulators
-  float thread_x[N_READS];
-  float thread_w[N_READS];
-  float thread_g[N_READS];
-  float sumx = 0;
-  float sumx2 = 0;
-  float sumwg = 0;
-  float sumwgx = 0;
-
-  constexpr int SIMD_SIZE = 32;
-
-  threadgroup float local_sumx[SIMD_SIZE];
-  threadgroup float local_sumx2[SIMD_SIZE];
-  threadgroup float local_sumwg[SIMD_SIZE];
-  threadgroup float local_sumwgx[SIMD_SIZE];
-  threadgroup float local_mean[1];
-  threadgroup float local_normalizer[1];
-  threadgroup float local_meanwg[1];
-  threadgroup float local_meanwgx[1];
-
-  if (lid * N_READS + N_READS <= axis_size) {
+  // Initialize the registers and threadgroup memory
+  float thread_x[N_READS] = {0};
+  float thread_w[N_READS] = {0};
+  float thread_g[N_READS] = {0};
+  threadgroup float local_buffer[3 * SIMD_SIZE];
+  initialize_buffer<3>(local_buffer, simd_lane_id, simd_group_id);
+
+  // Compute some variables for reading writing etc
+  const bool safe = lid * N_READS + N_READS <= axis_size;
+  const int n = axis_size - lid * N_READS;
+
+  // Read the inputs
+  if (safe) {
     for (int i = 0; i < N_READS; i++) {
       thread_x[i] = x[i];
-      thread_w[i] = w[i * w_stride];
       thread_g[i] = g[i];
-      float wg = thread_w[i] * thread_g[i];
-      sumx += thread_x[i];
-      sumx2 += thread_x[i] * thread_x[i];
-      sumwg += wg;
-      sumwgx += wg * thread_x[i];
+      thread_w[i] = w[i * w_stride];
     }
   } else {
-    for (int i = 0; i < N_READS; i++) {
-      if ((lid * N_READS + i) < axis_size) {
-        thread_x[i] = x[i];
-        thread_w[i] = w[i * w_stride];
-        thread_g[i] = g[i];
-        float wg = thread_w[i] * thread_g[i];
-        sumx += thread_x[i];
-        sumx2 += thread_x[i] * thread_x[i];
-        sumwg += wg;
-        sumwgx += wg * thread_x[i];
-      }
+    for (int i = 0; i < n; i++) {
+      thread_x[i] = x[i];
+      thread_g[i] = g[i];
+      thread_w[i] = w[i * w_stride];
     }
   }
 
-  sumx = simd_sum(sumx);
-  sumx2 = simd_sum(sumx2);
-  sumwg = simd_sum(sumwg);
-  sumwgx = simd_sum(sumwgx);
-
-  // Initialize shared memory
-  if (simd_group_id == 0) {
-    local_sumx[simd_lane_id] = 0;
-    local_sumx2[simd_lane_id] = 0;
-    local_sumwg[simd_lane_id] = 0;
-    local_sumwgx[simd_lane_id] = 0;
-  }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  // Write simd accumulations into shared memory
-  if (simd_lane_id == 0) {
-    local_sumx[simd_group_id] = sumx;
-    local_sumx2[simd_group_id] = sumx2;
-    local_sumwg[simd_group_id] = sumwg;
-    local_sumwgx[simd_group_id] = sumwgx;
-  }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  // Accumulate over simd groups
-  if (simd_group_id == 0) {
-    sumx = simd_sum(local_sumx[simd_lane_id]);
-    sumx2 = simd_sum(local_sumx2[simd_lane_id]);
-    sumwg = simd_sum(local_sumwg[simd_lane_id]);
-    sumwgx = simd_sum(local_sumwgx[simd_lane_id]);
-    if (simd_lane_id == 0) {
-      float mean = sumx / axis_size;
-      float variance = sumx2 / axis_size - mean * mean;
-
-      local_mean[0] = mean;
-      local_normalizer[0] = metal::precise::rsqrt(variance + eps);
-      local_meanwg[0] = sumwg / axis_size;
-      local_meanwgx[0] = sumwgx / axis_size;
+  // Compute the mean
+  float mean = 0;
+  for (int i = 0; i < N_READS; i++) {
+    mean += thread_x[i];
+  }
+  threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id);
+  mean /= axis_size;
+
+  // Compute the neccesary scaling factors using the mean
+  if (!safe) {
+    for (int i = n; i < N_READS; i++) {
+      thread_x[i] = mean;
     }
   }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  float mean = local_mean[0];
-  float normalizer = local_normalizer[0];
-  float meanwg = local_meanwg[0];
-  float meanwgxc = local_meanwgx[0] - meanwg * mean;
-  float normalizer2 = normalizer * normalizer;
+  float factors[3] = {0};
+  constexpr int meanwg = 0;
+  constexpr int meanwgxc = 1;
+  constexpr int normalizer2 = 2;
+  for (int i = 0; i < N_READS; i++) {
+    thread_x[i] -= mean;
+    factors[meanwg] += thread_w[i] * thread_g[i];
+    factors[meanwgxc] += thread_w[i] * thread_g[i] * thread_x[i];
+    factors[normalizer2] += thread_x[i] * thread_x[i];
+  }
+  threadgroup_sum<3>(factors, local_buffer, simd_lane_id, simd_group_id);
+  factors[meanwg] /= axis_size;
+  factors[meanwgxc] /= axis_size;
+  factors[normalizer2] = 1 / (factors[normalizer2] / axis_size + eps);
+  float normalizer = metal::precise::sqrt(factors[normalizer2]);
 
   // Write the outputs
   gx += gid * size_t(axis_size) + lid * N_READS;
   gw += gid * size_t(axis_size) + lid * N_READS;
-  if (lid * N_READS + N_READS <= axis_size) {
+  if (safe) {
     for (int i = 0; i < N_READS; i++) {
-      thread_x[i] = (thread_x[i] - mean) * normalizer;
+      thread_x[i] *= normalizer;
       gx[i] = static_cast<T>(
-          normalizer * (thread_w[i] * thread_g[i] - meanwg) -
-          thread_x[i] * meanwgxc * normalizer2);
+          normalizer * (thread_w[i] * thread_g[i] - factors[meanwg]) -
+          thread_x[i] * factors[meanwgxc] * factors[normalizer2]);
       if (has_w) {
         gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
       }
     }
   } else {
-    for (int i = 0; i < N_READS; i++) {
-      if ((lid * N_READS + i) < axis_size) {
-        thread_x[i] = (thread_x[i] - mean) * normalizer;
-        gx[i] = static_cast<T>(
-            normalizer * (thread_w[i] * thread_g[i] - meanwg) -
-            thread_x[i] * meanwgxc * normalizer2);
-        if (has_w) {
-          gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
-        }
-      }
+    for (int i = 0; i < n; i++) {
+      thread_x[i] *= normalizer;
+      gx[i] = static_cast<T>(
+          normalizer * (thread_w[i] * thread_g[i] - factors[meanwg]) -
+          thread_x[i] * factors[meanwgxc] * factors[normalizer2]);
+      if (has_w) {
+        gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
+      }
     }
   }
 }
 
-template <typename T, int N_READS = RMS_N_READS>
+template <typename T, int N_READS = 4>
 [[kernel]] void vjp_layer_norm_looped(
     const device T* x,
     const device T* w,
@@ -363,102 +322,69 @@ template <typename T, int N_READS = RMS_N_READS>
     uint lsize [[threads_per_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
+  constexpr int SIMD_SIZE = 32;
+
   // Advance the input pointers
   x += gid * size_t(axis_size) + lid * N_READS;
   g += gid * size_t(axis_size) + lid * N_READS;
   w += w_stride * lid * N_READS;
 
-  // Allocate registers for the accumulators
-  float sumx = 0;
-  float sumx2 = 0;
-  float sumwg = 0;
-  float sumwgx = 0;
-
-  constexpr int SIMD_SIZE = 32;
-
-  threadgroup float local_sumx[SIMD_SIZE];
-  threadgroup float local_sumx2[SIMD_SIZE];
-  threadgroup float local_sumwg[SIMD_SIZE];
-  threadgroup float local_sumwgx[SIMD_SIZE];
-  threadgroup float local_mean[1];
-  threadgroup float local_normalizer[1];
-  threadgroup float local_meanwg[1];
-  threadgroup float local_meanwgx[1];
-
-  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
-    if (r + lid * N_READS + N_READS <= axis_size) {
-      for (int i = 0; i < N_READS; i++) {
-        float xi = x[i + r];
-        float wi = w[(i + r) * w_stride];
-        float gi = g[i + r];
-        float wg = wi * gi;
-        sumx += xi;
-        sumx2 += xi * xi;
-        sumwg += wg;
-        sumwgx += wg * xi;
-      }
-    } else {
-      for (int i = 0; i < N_READS; i++) {
-        if ((r + lid * N_READS + i) < axis_size) {
-          float xi = x[i + r];
-          float wi = w[(i + r) * w_stride];
-          float gi = g[i + r];
-          float wg = wi * gi;
-          sumx += xi;
-          sumx2 += xi * xi;
-          sumwg += wg;
-          sumwgx += wg * xi;
-        }
-      }
-    }
-  }
-
-  sumx = simd_sum(sumx);
-  sumx2 = simd_sum(sumx2);
-  sumwg = simd_sum(sumwg);
-  sumwgx = simd_sum(sumwgx);
-
-  // Initialize shared memory
-  if (simd_group_id == 0) {
-    local_sumx[simd_lane_id] = 0;
-    local_sumx2[simd_lane_id] = 0;
-    local_sumwg[simd_lane_id] = 0;
-    local_sumwgx[simd_lane_id] = 0;
-  }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  // Write simd accumulations into shared memory
-  if (simd_lane_id == 0) {
-    local_sumx[simd_group_id] = sumx;
-    local_sumx2[simd_group_id] = sumx2;
-    local_sumwg[simd_group_id] = sumwg;
-    local_sumwgx[simd_group_id] = sumwgx;
-  }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  // Accumulate over simd groups
-  if (simd_group_id == 0) {
-    sumx = simd_sum(local_sumx[simd_lane_id]);
-    sumx2 = simd_sum(local_sumx2[simd_lane_id]);
-    sumwg = simd_sum(local_sumwg[simd_lane_id]);
-    sumwgx = simd_sum(local_sumwgx[simd_lane_id]);
-    if (simd_lane_id == 0) {
-      float mean = sumx / axis_size;
-      float variance = sumx2 / axis_size - mean * mean;
-
-      local_mean[0] = mean;
-      local_normalizer[0] = metal::precise::rsqrt(variance + eps);
-      local_meanwg[0] = sumwg / axis_size;
-      local_meanwgx[0] = sumwgx / axis_size;
-    }
-  }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  float mean = local_mean[0];
-  float normalizer = local_normalizer[0];
-  float meanwg = local_meanwg[0];
-  float meanwgxc = local_meanwgx[0] - meanwg * mean;
-  float normalizer2 = normalizer * normalizer;
+  threadgroup float local_buffer[3 * SIMD_SIZE];
+  initialize_buffer<3>(local_buffer, simd_lane_id, simd_group_id);
+
+  // Compute the mean
+  float mean = 0;
+  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
+    if (r + lid * N_READS + N_READS <= axis_size) {
+      for (int i = 0; i < N_READS; i++) {
+        mean += x[i + r];
+      }
+    } else {
+      for (int i = 0; i < N_READS; i++) {
+        if ((r + lid * N_READS + i) < axis_size) {
+          mean += x[i + r];
+        }
+      }
+    }
+  }
+  threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id);
+  mean /= axis_size;
+
+  // Compute the neccesary scaling factors using the mean
+  float factors[3] = {0};
+  constexpr int meanwg = 0;
+  constexpr int meanwgxc = 1;
+  constexpr int normalizer2 = 2;
+  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
+    if (r + lid * N_READS + N_READS <= axis_size) {
+      for (int i = 0; i < N_READS; i++) {
+        float t = x[i + r] - mean;
+        float wi = w[(i + r) * w_stride];
+        float gi = g[i + r];
+        float wg = wi * gi;
+        factors[meanwg] += wg;
+        factors[meanwgxc] += wg * t;
+        factors[normalizer2] += t * t;
+      }
+    } else {
+      for (int i = 0; i < N_READS; i++) {
+        if ((r + lid * N_READS + i) < axis_size) {
+          float t = x[i + r] - mean;
+          float wi = w[(i + r) * w_stride];
+          float gi = g[i + r];
+          float wg = wi * gi;
+          factors[meanwg] += wg;
+          factors[meanwgxc] += wg * t;
+          factors[normalizer2] += t * t;
+        }
+      }
+    }
+  }
+  threadgroup_sum<3>(factors, local_buffer, simd_lane_id, simd_group_id);
+  factors[meanwg] /= axis_size;
+  factors[meanwgxc] /= axis_size;
+  factors[normalizer2] = 1 / (factors[normalizer2] / axis_size + eps);
+  float normalizer = metal::precise::sqrt(factors[normalizer2]);
 
   // Write the outputs
   gx += gid * size_t(axis_size) + lid * N_READS;
@@ -470,7 +396,8 @@ template <typename T, int N_READS = RMS_N_READS>
         float wi = w[(i + r) * w_stride];
         float gi = g[i + r];
         gx[i + r] = static_cast<T>(
-            normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
+            normalizer * (wi * gi - factors[meanwg]) -
+            xi * factors[meanwgxc] * factors[normalizer2]);
         if (has_w) {
           gw[i + r] = static_cast<T>(gi * xi);
         }
@@ -482,7 +409,8 @@ template <typename T, int N_READS = RMS_N_READS>
           float wi = w[(i + r) * w_stride];
           float gi = g[i + r];
           gx[i + r] = static_cast<T>(
-              normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
+              normalizer * (wi * gi - factors[meanwg]) -
+              xi * factors[meanwgxc] * factors[normalizer2]);
           if (has_w) {
             gw[i + r] = static_cast<T>(gi * xi);
           }
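Note: the layer-norm kernels above now funnel every reduction through the new threadgroup_sum helper: each SIMD group reduces its values with simd_sum, lane 0 writes the per-group partial into threadgroup memory, and after a barrier every thread reduces the partials again. The sketch below is a CPU-side model of that two-level reduction only (sizes and names are illustrative, not MLX API):

#include <array>
#include <numeric>

constexpr int kSimdSize = 32;
constexpr int kSimdGroups = 8; // threads per threadgroup / kSimdSize

// Model of threadgroup_sum: reduce within each 32-wide group, then reduce the partials.
float threadgroup_sum_model(const std::array<float, kSimdSize * kSimdGroups>& vals) {
  std::array<float, kSimdSize> partials{}; // threadgroup buffer, zero-initialized
  for (int g = 0; g < kSimdGroups; ++g) {
    // Step 1: per-group reduction (simd_sum on the GPU).
    float s = std::accumulate(vals.begin() + g * kSimdSize,
                              vals.begin() + (g + 1) * kSimdSize, 0.0f);
    partials[g] = s; // lane 0 of each group writes its partial
  }
  // (threadgroup_barrier happens here on the GPU.)
  // Step 2: reduce the per-group partials.
  return std::accumulate(partials.begin(), partials.end(), 0.0f);
}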
@@ -103,8 +103,8 @@ template <typename T, typename AccT = float, int N_READS = 4>
       }
     } else {
       for (int i = 0; i < N_READS; i++) {
-        vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
-                                           : Limits<AccT>::finite_min;
+        vals[i] =
+            (offset + i < axis_size) ? AccT(in[offset + i]) : Limits<AccT>::min;
       }
     }
     prevmax = maxval;
@@ -134,10 +134,7 @@ template <typename T, typename AccT = float, int N_READS = 4>
   threadgroup_barrier(mem_flags::mem_threadgroup);
   normalizer = simd_sum(local_normalizer[simd_lane_id]);
 
-  if (simd_group_id == 0) {
-    normalizer = simd_sum(local_normalizer[simd_lane_id]);
-    if (simd_lane_id == 0) {
-      out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
-    }
+  if (lid == 0) {
+    out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
   }
 }
@@ -14,11 +14,23 @@ using namespace metal;
|
|||||||
MLX_MTL_CONST int SIMD_SIZE = 32;
|
MLX_MTL_CONST int SIMD_SIZE = 32;
|
||||||
MLX_MTL_CONST int QUAD_SIZE = 4;
|
MLX_MTL_CONST int QUAD_SIZE = 4;
|
||||||
|
|
||||||
|
template <int bits, int wsize = 8>
|
||||||
|
inline constexpr short get_pack_factor() {
|
||||||
|
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int bits, int wsize = 8>
|
||||||
|
inline constexpr short get_bytes_per_pack() {
|
||||||
|
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
||||||
|
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T, typename U, int values_per_thread, int bits>
|
template <typename T, typename U, int values_per_thread, int bits>
|
||||||
inline U load_vector(const device T* x, thread U* x_thread) {
|
inline U load_vector(const device T* x, thread U* x_thread) {
|
||||||
static_assert(
|
static_assert(
|
||||||
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
|
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
|
||||||
"Template undefined for bits not in {2, 3, 4, 6, 8}");
|
bits == 8,
|
||||||
|
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
|
||||||
|
|
||||||
U sum = 0;
|
U sum = 0;
|
||||||
|
|
||||||
@@ -57,6 +69,21 @@ inline U load_vector(const device T* x, thread U* x_thread) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
else if (bits == 5) {
|
||||||
|
for (int i = 0; i < values_per_thread; i += 8) {
|
||||||
|
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
|
||||||
|
x[i + 6] + x[i + 7];
|
||||||
|
x_thread[i] = x[i];
|
||||||
|
x_thread[i + 1] = x[i + 1] / 32.0f;
|
||||||
|
x_thread[i + 2] = x[i + 2] / 4.0f;
|
||||||
|
x_thread[i + 3] = x[i + 3] / 128.0f;
|
||||||
|
x_thread[i + 4] = x[i + 4] / 16.0f;
|
||||||
|
x_thread[i + 5] = x[i + 5] / 2.0f;
|
||||||
|
x_thread[i + 6] = x[i + 6] / 64.0f;
|
||||||
|
x_thread[i + 7] = x[i + 7] / 8.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
else if (bits == 6) {
|
else if (bits == 6) {
|
||||||
for (int i = 0; i < values_per_thread; i += 4) {
|
for (int i = 0; i < values_per_thread; i += 4) {
|
||||||
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
|
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
|
||||||
@@ -80,8 +107,9 @@ inline U load_vector(const device T* x, thread U* x_thread) {
 template <typename T, typename U, int values_per_thread, int bits>
 inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
   static_assert(
-      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
-      "Template undefined for bits not in {2, 3, 4, 6, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
+          bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
 
   U sum = 0;
 
@@ -121,6 +149,21 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
     }
   }
 
+  else if (bits == 5) {
+    for (int i = 0; i < N; i += 8) {
+      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
+          x[i + 6] + x[i + 7];
+      x_thread[i] = x[i];
+      x_thread[i + 1] = x[i + 1] / 32.0f;
+      x_thread[i + 2] = x[i + 2] / 4.0f;
+      x_thread[i + 3] = x[i + 3] / 128.0f;
+      x_thread[i + 4] = x[i + 4] / 16.0f;
+      x_thread[i + 5] = x[i + 5] / 2.0f;
+      x_thread[i + 6] = x[i + 6] / 64.0f;
+      x_thread[i + 7] = x[i + 7] / 8.0f;
+    }
+  }
+
   else if (bits == 6) {
     for (int i = 0; i < N; i += 4) {
       sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
@@ -153,8 +196,9 @@ inline U qdot(
     U bias,
     U sum) {
   static_assert(
-      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
-      "Template undefined for bits not in {2, 3, 4, 6, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
+          bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
 
   U accum = 0;
 
@@ -199,6 +243,26 @@ inline U qdot(
     }
   }
 
+  else if (bits == 5) {
+    for (int i = 0; i < (values_per_thread / 8); i++) {
+      x_thread += 8 * i;
+      w += 5 * i;
+
+      accum += (w[0] & 0x1f) * x_thread[0];
+      accum += (w[0] & 0xe0) * x_thread[1];
+      accum += (w[1] & 0x3) * (x_thread[1] * 256.0f);
+      accum += (w[1] & 0x7c) * x_thread[2];
+      accum += (w[1] & 0x80) * x_thread[3];
+      accum += (w[2] & 0xf) * (x_thread[3] * 256.0f);
+      accum += (w[2] & 0xf0) * x_thread[4];
+      accum += (w[3] & 0x1) * (x_thread[4] * 256.0f);
+      accum += (w[3] & 0x3e) * x_thread[5];
+      accum += (w[3] & 0xc0) * x_thread[6];
+      accum += (w[4] & 0x7) * (x_thread[6] * 256.0f);
+      accum += (w[4] & 0xf8) * x_thread[7];
+    }
+  }
+
   else if (bits == 6) {
     for (int i = 0; i < (values_per_thread / 4); i++) {
       x_thread += 4 * i;
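The 5-bit qdot path works hand in hand with load_vector above: the activations are pre-divided by 2, 4, ..., 128 according to where each 5-bit value starts inside its byte, so the dot product can multiply by the raw masked bytes without any shifts, and the bits of a value that spill into the next byte are recovered with an extra factor of 256. The following host-side C++ sketch (illustrative only; the little-endian packing in main mirrors what the dequantization code in this diff implies) checks that this reproduces the plain sum of w[i] * x[i] for one group of eight values:

#include <cassert>
#include <cmath>
#include <cstdint>

// One 8-value group of the 5-bit qdot arithmetic, using the same masks and
// the same pre-scaled activations as load_vector/qdot above.
float qdot5_group(const uint8_t* w, const float* x) {
  float xt[8] = {x[0],         x[1] / 32.0f, x[2] / 4.0f,  x[3] / 128.0f,
                 x[4] / 16.0f, x[5] / 2.0f,  x[6] / 64.0f, x[7] / 8.0f};
  float accum = 0.0f;
  accum += (w[0] & 0x1f) * xt[0];
  accum += (w[0] & 0xe0) * xt[1];
  accum += (w[1] & 0x3) * (xt[1] * 256.0f);
  accum += (w[1] & 0x7c) * xt[2];
  accum += (w[1] & 0x80) * xt[3];
  accum += (w[2] & 0xf) * (xt[3] * 256.0f);
  accum += (w[2] & 0xf0) * xt[4];
  accum += (w[3] & 0x1) * (xt[4] * 256.0f);
  accum += (w[3] & 0x3e) * xt[5];
  accum += (w[3] & 0xc0) * xt[6];
  accum += (w[4] & 0x7) * (xt[6] * 256.0f);
  accum += (w[4] & 0xf8) * xt[7];
  return accum;
}

int main() {
  uint8_t vals[8] = {7, 19, 31, 0, 12, 25, 3, 30};  // arbitrary 5-bit values
  uint64_t bitstream = 0;                           // pack little-endian
  for (int i = 0; i < 8; ++i) bitstream |= uint64_t(vals[i]) << (5 * i);
  uint8_t w[5];
  for (int i = 0; i < 5; ++i) w[i] = (bitstream >> (8 * i)) & 0xff;

  float x[8] = {1.f, 2.f, -3.f, 4.f, 0.5f, -1.5f, 8.f, 2.5f};
  float expected = 0.f;
  for (int i = 0; i < 8; ++i) expected += float(vals[i]) * x[i];
  assert(std::fabs(qdot5_group(w, x) - expected) < 1e-4f);
  return 0;
}

The scale and bias are applied outside this accumulation, exactly as in the other bit widths.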
@@ -234,8 +298,9 @@ inline U qdot_safe(
     U sum,
     int N) {
   static_assert(
-      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
-      "Template undefined for bits not in {2, 3, 4, 6, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
+          bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
 
   U accum = 0;
 
@@ -280,6 +345,26 @@ inline U qdot_safe(
     }
   }
 
+  else if (bits == 5) {
+    for (int i = 0; i < (N / 8); i++) {
+      x_thread += 8 * i;
+      w += 5 * i;
+
+      accum += (w[0] & 0x1f) * x_thread[0];
+      accum += (w[0] & 0xe0) * x_thread[1];
+      accum += (w[1] & 0x3) * (x_thread[1] * 256.0f);
+      accum += (w[1] & 0x7c) * x_thread[2];
+      accum += (w[1] & 0x80) * x_thread[3];
+      accum += (w[2] & 0xf) * (x_thread[3] * 256.0f);
+      accum += (w[2] & 0xf0) * x_thread[4];
+      accum += (w[3] & 0x1) * (x_thread[4] * 256.0f);
+      accum += (w[3] & 0x3e) * x_thread[5];
+      accum += (w[3] & 0xc0) * x_thread[6];
+      accum += (w[4] & 0x7) * (x_thread[6] * 256.0f);
+      accum += (w[4] & 0xf8) * x_thread[7];
+    }
+  }
+
   else if (bits == 6) {
     for (int i = 0; i < (N / 4); i++) {
       x_thread += 4 * i;
@@ -310,8 +395,9 @@ template <typename U, int values_per_thread, int bits>
 inline void
 qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
   static_assert(
-      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
-      "Template undefined for bits not in {2, 3, 4, 6, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
+          bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
 
   if (bits == 2) {
     U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
@@ -348,8 +434,31 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
       result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
       result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
     }
-  } else if (bits == 6) {
+  }
+
+  else if (bits == 5) {
+    for (int i = 0; i < (values_per_thread / 8); i++) {
+      uint8_t w0 = w[5 * i];
+      uint8_t w1 = w[5 * i + 1];
+      uint8_t w2 = w[5 * i + 2];
+      uint8_t w3 = w[5 * i + 3];
+      uint8_t w4 = w[5 * i + 4];
+      result[8 * i] += x * ((w0 & 0x1f) * scale + bias);
+      result[8 * i + 1] +=
+          x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias);
+      result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias);
+      result[8 * i + 3] +=
+          x * ((((w1 & 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias);
+      result[8 * i + 4] +=
+          x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias);
+      result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias);
+      result[8 * i + 6] +=
+          x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias);
+      result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias);
+    }
+  }
+
+  else if (bits == 6) {
     for (int i = 0; i < (values_per_thread / 4); i++) {
       uint8_t w0 = w[3 * i];
       uint8_t w1 = w[3 * i + 1];
@@ -375,8 +484,9 @@ template <typename U, int N, int bits>
 inline void
 dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
   static_assert(
-      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
-      "Template undefined for bits not in {2, 3, 4, 6, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
+          bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
 
   if (bits == 2) {
     U s[4] = {
@@ -416,11 +526,26 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
     }
   }
 
+  else if (bits == 5) {
+    for (int i = 0; i < (N / 8); i++) {
+      w_local += 8 * i;
+      w += 5 * i;
+
+      w_local[0] = (w[0] & 0x1f) * scale + bias;
+      w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;
+      w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias;
+      w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias;
+      w_local[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias;
+      w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias;
+      w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias;
+      w_local[7] = ((w[4] & 0xf8) >> 3) * scale + bias;
+    }
+  }
+
   else if (bits == 6) {
     for (int i = 0; i < (N / 4); i++) {
       w_local += 4 * i;
       w += 3 * i;
 
       w_local[0] = (w[0] & 0x3f) * scale + bias;
       w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
       w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
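dequantize() (and qouter() above) spell out the 5-bit layout directly: eight consecutive values occupy five bytes, and each value is reassembled from one or two bytes with shift-and-or before the scale and bias are applied. A small host-side round-trip sketch of that layout (plain C++, illustration only):

#include <cassert>
#include <cstdint>

// Pack eight 5-bit values little-endian into five bytes...
void pack5(const uint8_t v[8], uint8_t w[5]) {
  uint64_t bits = 0;
  for (int i = 0; i < 8; ++i) bits |= uint64_t(v[i] & 0x1f) << (5 * i);
  for (int i = 0; i < 5; ++i) w[i] = (bits >> (8 * i)) & 0xff;
}

// ...and unpack them with the same expressions the new dequantize() path uses
// (minus the scale/bias affine step).
void unpack5(const uint8_t w[5], uint8_t v[8]) {
  v[0] = w[0] & 0x1f;
  v[1] = ((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3);
  v[2] = (w[1] & 0x7c) >> 2;
  v[3] = ((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1);
  v[4] = ((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4);
  v[5] = (w[3] & 0x3e) >> 1;
  v[6] = ((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2);
  v[7] = (w[4] & 0xf8) >> 3;
}

int main() {
  uint8_t v[8] = {0, 31, 17, 4, 22, 9, 30, 1}, w[5], r[8];
  pack5(v, w);
  unpack5(w, r);
  for (int i = 0; i < 8; ++i) assert(v[i] == r[i]);
  return 0;
}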
@@ -452,11 +577,12 @@ struct QuantizedBlockLoader {
       group_size % BCOLS == 0,
       "The group size should be divisible by the columns");
   static_assert(
-      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
-      "Template undefined for bits not in {2, 3, 4, 6, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
+          bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
 
-  MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
-  MLX_MTL_CONST short bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
+  MLX_MTL_CONST short pack_factor = get_pack_factor<bits, 8>();
+  MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack<bits>();
   MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
   MLX_MTL_CONST short n_reads =
       (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
@@ -632,12 +758,11 @@ METAL_FUNC void qmv_fast_impl(
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
   constexpr int packs_per_thread = bits == 2 ? 1 : 2;
   constexpr int num_simdgroups = 2;
   constexpr int results_per_simdgroup = 4;
-  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
-  constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
+  constexpr int pack_factor = get_pack_factor<bits, 32>();
+  constexpr int bytes_per_pack = get_bytes_per_pack<bits, 32>();
   constexpr int values_per_thread = pack_factor * packs_per_thread;
   constexpr int block_size = values_per_thread * SIMD_SIZE;
   constexpr int scale_step_per_thread = group_size / values_per_thread;
@@ -700,12 +825,12 @@ METAL_FUNC void qmv_impl(
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
   constexpr int num_simdgroups = 2;
   constexpr int results_per_simdgroup = 4;
   constexpr int packs_per_thread = 1;
-  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
-  constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
+  constexpr int pack_factor = get_pack_factor<bits, 32>();
+  constexpr int bytes_per_pack = get_bytes_per_pack<bits, 32>();
 
   constexpr int values_per_thread = pack_factor * packs_per_thread;
   constexpr int block_size = values_per_thread * SIMD_SIZE;
   constexpr int scale_step_per_thread = group_size / values_per_thread;
@@ -857,8 +982,9 @@ METAL_FUNC void qvm_impl(
     uint simd_lid [[thread_index_in_simdgroup]]) {
   constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
   constexpr int num_simdgroups = 2;
-  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
-  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
+  constexpr int pack_factor = get_pack_factor<bits, 32>();
+  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
 
   constexpr int tn = 32 / pack_factor;
   constexpr int block_size = SIMD_SIZE;
 
@@ -981,9 +1107,10 @@ METAL_FUNC void qmm_t_impl(
 
   constexpr int WM = 2;
   constexpr int WN = 2;
-  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
+  constexpr int pack_factor = get_pack_factor<bits, 8>();
+  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
 
   constexpr int BK_padded = (BK + 16 / sizeof(T));
-  constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
 
   // Instantiate the appropriate BlockMMA and Loader
   using mma_t = mlx::steel::
@@ -1106,11 +1233,11 @@ METAL_FUNC void qmm_n_impl(
 
   constexpr int WM = 2;
   constexpr int WN = 2;
-  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
+  constexpr int pack_factor = get_pack_factor<bits, 8>();
+  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
 
   constexpr int BK_padded = (BK + 16 / sizeof(T));
   constexpr int BN_padded = (BN + 16 / sizeof(T));
-  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
-  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
 
   // Instantiate the appropriate BlockMMA and Loader
   using mma_t = mlx::steel::
@@ -2120,11 +2247,10 @@ template <
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]]) {
-  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
+  constexpr int pack_factor = get_pack_factor<bits, 8>();
+  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
   constexpr int BK_padded = (BK + 16 / sizeof(T));
   constexpr int BN_padded = (BN + 16 / sizeof(T));
-  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
-  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
 
   using mma_t = mlx::steel::BlockMMA<
       T,
@@ -2305,13 +2431,13 @@ template <typename T, const int group_size, const int bits>
   constexpr float eps = 1e-7;
   constexpr int simd_size = 32;
   constexpr float n_bins = (1 << bits) - 1;
-  constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
+  constexpr int pack_factor = get_pack_factor<bits, 8>();
+  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
   constexpr int values_per_reduce = group_size / simd_size;
-  constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
+  constexpr int writes_per_reduce = pack_factor / values_per_reduce;
   constexpr int writes_per_pack =
-      writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
+      writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor;
   constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
-  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
 
   static_assert(
       group_size % simd_size == 0,
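Reading these constants off for a concrete case: with bits = 5 and group_size = 64, pack_factor = 8 and bytes_per_pack = 5, so values_per_reduce = 64 / 32 = 2, writes_per_reduce = 8 / 2 = 4 and writes_per_pack = 1; in other words each simdgroup lane quantizes two values and four neighbouring lanes together contribute the eight values of one 40-bit (five-byte) pack.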
@@ -2354,8 +2480,8 @@ template <typename T, const int group_size, const int bits>
     biases[gindex] = static_cast<T>(bias);
   }
 
-  // We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t
-  uint32_t output = 0;
+  using OutType = metal::conditional_t<bits == 5, uint64_t, uint32_t>;
+  OutType output = 0;
 
 #pragma clang loop unroll(full)
   for (int i = 0; i < values_per_reduce; i++) {
@@ -2363,27 +2489,35 @@ template <typename T, const int group_size, const int bits>
     if (bits == 8) {
       output = val;
     } else {
-      output += val << (bits * (i % packs_per_int));
+      output |= val << (bits * (i % pack_factor));
     }
 
-    if (packs_per_int < values_per_reduce &&
-        i % packs_per_int == packs_per_int - 1) {
-      out[out_index + i / packs_per_int] = output;
+    if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) {
+      out[out_index + i / pack_factor] = output;
       output = 0;
     } else {
 #pragma clang loop unroll(full)
       for (int j = 1; j < writes_per_reduce; j++) {
         uint8_t sval = simd_shuffle_down(val, j);
-        output += sval << (bits * (j * values_per_reduce + i));
+        output |= static_cast<OutType>(sval)
+            << (bits * (j * values_per_reduce + i));
       }
     }
   }
   if (bits == 3 || bits == 6) {
-    if (in_index % packs_per_int == 0 && out_index % bytes_per_pack == 0) {
+    if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {
       out[out_index] = output & 0xff;
       out[out_index + 1] = (output & 0xff00) >> 8;
       out[out_index + 2] = (output & 0xff0000) >> 16;
     }
+  } else if (bits == 5) {
+    if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {
+      out[out_index] = output & 0xff;
+      out[out_index + 1] = (output & 0xff00) >> 8;
+      out[out_index + 2] = (output & 0xff0000) >> 16;
+      out[out_index + 3] = (output & 0xff000000) >> 24;
+      out[out_index + 4] = (output & 0xff00000000) >> 32;
+    }
   } else {
     if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
       out[out_index / writes_per_reduce] = output;
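The switch from a fixed uint32_t to the conditional OutType is what makes the 5-bit case fit: one pack holds pack_factor = 8 values of 5 bits each, i.e. 40 bits, which overflows the 32-bit accumulator that suffices for the 1- and 3-byte packs of the other widths. A host-side C++ sketch (illustration only) of the accumulate-then-write-bytes pattern used above:

#include <cassert>
#include <cstdint>

int main() {
  constexpr int bits = 5;
  constexpr int pack_factor = 8;    // values per pack, as above
  constexpr int bytes_per_pack = 5;
  static_assert(bits * pack_factor == 8 * bytes_per_pack, "40-bit pack");

  // OR the quantized values into a 64-bit accumulator, low bits first.
  uint8_t q[pack_factor] = {1, 30, 7, 22, 0, 19, 12, 31};
  uint64_t output = 0;
  for (int i = 0; i < pack_factor; ++i)
    output |= uint64_t(q[i]) << (bits * i);

  // Write the pack out one byte at a time, as the kernel does for bits == 5.
  uint8_t out[bytes_per_pack];
  for (int b = 0; b < bytes_per_pack; ++b)
    out[b] = (output >> (8 * b)) & 0xff;

  // Round trip: the low 5 bits of the reassembled bitstream give back q[0], etc.
  uint64_t check = 0;
  for (int b = 0; b < bytes_per_pack; ++b) check |= uint64_t(out[b]) << (8 * b);
  for (int i = 0; i < pack_factor; ++i)
    assert(((check >> (bits * i)) & 0x1f) == q[i]);
  return 0;
}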
@@ -2399,12 +2533,11 @@ template <typename T, const int group_size, const int bits>
     device T* out [[buffer(3)]],
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
-  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
-  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
+  constexpr int pack_factor = get_pack_factor<bits, 8>();
+  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
 
   size_t offset = index.x + grid_dim.x * size_t(index.y);
-  size_t oindex = offset * packs_per_int;
+  size_t oindex = offset * pack_factor;
   size_t gindex = oindex / group_size;
   T scale = scales[gindex];
   T bias = biases[gindex];
@@ -2421,7 +2554,16 @@ template <typename T, const int group_size, const int bits>
     out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
     out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
     out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
+  } else if (bits == 5) {
+    w += offset * bytes_per_pack;
+    out[0] = (w[0] & 0x1f) * scale + bias;
+    out[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;
+    out[2] = ((w[1] & 0x7c) >> 2) * scale + bias;
+    out[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias;
+    out[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias;
+    out[5] = ((w[3] & 0x3e) >> 1) * scale + bias;
+    out[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias;
+    out[7] = ((w[4] & 0xf8) >> 3) * scale + bias;
   } else if (bits == 6) {
     w += offset * bytes_per_pack;
     out[0] = (w[0] & 0x3f) * scale + bias;
@@ -2431,7 +2573,7 @@ template <typename T, const int group_size, const int bits>
   } else {
     uint val = w[offset];
 #pragma clang loop unroll(full)
-    for (int i = 0; i < packs_per_int; i++) {
+    for (int i = 0; i < pack_factor; i++) {
       uint8_t d;
       if (bits == 2) {
         d = (val >> (bits * i)) & 0x03;
@@ -136,6 +136,7 @@
   instantiate_quantized_groups(2) \
   instantiate_quantized_groups(3) \
   instantiate_quantized_groups(4) \
+  instantiate_quantized_groups(5) \
   instantiate_quantized_groups(6) \
   instantiate_quantized_groups(8)
 
@@ -224,7 +224,7 @@ template <
 
   if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {
     // Simple loop over non_row_reductions and reduce the row in the thread.
-    IdxT out_idx = tid.x + tsize.y * IdxT(tid.y);
+    IdxT out_idx = tid.x + tsize.x * IdxT(tid.y);
     in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
 
     for (uint r = 0; r < non_row_reductions; r++) {
@@ -128,8 +128,8 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
       }
     } else {
       for (int i = 0; i < N_READS; i++) {
-        vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
-                                           : Limits<AccT>::finite_min;
+        vals[i] =
+            (offset + i < axis_size) ? AccT(in[offset + i]) : Limits<AccT>::min;
       }
     }
     prevmax = maxval;
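This hunk moves the out-of-bounds fill value of the softmax read loop from Limits<AccT>::finite_min to Limits<AccT>::min. For a max-then-exp pass, padding with negative infinity (which is what the sketch below assumes the fill value amounts to for floating-point AccT) is the neutral element: it can never win the running maximum and it contributes exactly zero to the sum of exponentials. A standalone C++ check of that identity:

#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>

int main() {
  const float pad = -std::numeric_limits<float>::infinity();
  float vals[4] = {0.5f, 2.0f, pad, pad};  // last two lanes are padding

  float m = *std::max_element(vals, vals + 4);
  assert(m == 2.0f);  // padding never becomes the maximum

  float denom = 0.0f;
  for (float v : vals) denom += std::exp(v - m);
  // exp(-inf - m) == 0, so the padded lanes do not perturb the denominator.
  assert(std::fabs(denom - (std::exp(0.5f - 2.0f) + 1.0f)) < 1e-6f);
  return 0;
}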
@@ -2,6 +2,8 @@
 
 #include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
 
+constant bool align_C [[function_constant(200)]];
+
 template <
     typename T,
     int BM,
@@ -118,23 +120,58 @@ implicit_gemm_conv_2d_general(
   // Prepare threadgroup mma operation
   mma_t mma_op(simd_gid, simd_lid);
 
-  int gemm_k_iterations =
-      base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
+  if (align_C) {
+    int gemm_k_iterations =
+        base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
 
-  for (int k = 0; k < gemm_k_iterations; k++) {
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    // Load elements into threadgroup
-    loader_a.load_unsafe();
-    loader_b.load_unsafe();
+    for (int k = 0; k < gemm_k_iterations; k++) {
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+      // Load elements into threadgroup
+      loader_a.load_unsafe();
+      loader_b.load_unsafe();
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+      threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    // Multiply and accumulate threadgroup elements
-    mma_op.mma(As, Bs);
+      // Multiply and accumulate threadgroup elements
+      mma_op.mma(As, Bs);
 
-    // Prepare for next iteration
-    loader_a.next();
-    loader_b.next();
+      // Prepare for next iteration
+      loader_a.next();
+      loader_b.next();
+    }
+  }
+
+  else {
+    for (int k = 1; k < gemm_params->gemm_k_iterations; k++) {
+      for (int j = 0; j < base_wh_size * base_ww_size; j++) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        // Load elements into threadgroup
+        loader_a.load_unsafe();
+        loader_b.load_unsafe();
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // Multiply and accumulate threadgroup elements
+        mma_op.mma(As, Bs);
+
+        // Prepare for next iteration
+        loader_a.next();
+        loader_b.next();
+      }
+    }
+    const short remaining_k = params->C % BK;
+    for (int j = 0; j < base_wh_size * base_ww_size; j++) {
+      // Load elements into threadgroup
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+      loader_a.load_safe(remaining_k);
+      loader_b.load_safe(remaining_k);
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+      // Multiply and accumulate threadgroup elements
+      mma_op.mma(As, Bs);
+      // Prepare for next iteration
+      loader_a.next();
+      loader_b.next();
+    }
   }
 
   threadgroup_barrier(mem_flags::mem_none);
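The restructured loop specializes on the new align_C function constant: when the channel count fills every K block exactly, the original single loop of unchecked loads is kept; otherwise the last K block of every filter tap is peeled off and loaded with load_safe using remaining_k = params->C % BK channels. A host-side sketch of that split (illustrative only; the ceil-division used for the per-tap K-block count is an assumption, not read from the kernel):

#include <cstdio>

int main() {
  const int C = 70, BK = 32;               // channels and K block size (examples)
  const int taps = 3 * 3;                  // filter spatial positions (wH * wW)
  const int k_blocks = (C + BK - 1) / BK;  // assumed per-tap K-block count
  const int remaining_k = C % BK;
  const bool align_C = (remaining_k == 0);

  if (align_C) {
    // Fast path: every K block is full, so every load can skip bounds checks.
    printf("%d unchecked K blocks\n", taps * k_blocks);
  } else {
    // Peeled path: full blocks use unchecked loads, then one bounds-checked
    // block per tap covers the last remaining_k channels.
    printf("%d unchecked + %d checked K blocks (%d channels each)\n",
           taps * (k_blocks - 1), taps, remaining_k);
  }
  return 0;
}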
@@ -381,6 +381,7 @@ struct Conv2DWeightBlockLoader {
   const constant MLXConvParams<2>* params;
 
   int weight_hw;
+  int weight_step;
 
   const int read_n;
   const bool do_read;
@@ -402,6 +403,7 @@ struct Conv2DWeightBlockLoader {
         src(src_ + bi * src_ld + bj),
         params(params_),
         weight_hw(0),
+        weight_step(params->C / params->groups),
         read_n(offsets.y + bi),
         do_read(read_n + n_rows * TROWS <= gemm_params_->N) {}
 
@@ -435,15 +437,15 @@ struct Conv2DWeightBlockLoader {
   /* Iteration helper */
   METAL_FUNC void next() {
     if (++weight_hw < (params->wS[1] * params->wS[0])) {
-      src += params->wt_strides[2];
+      src += weight_step;
       return;
     }
 
     weight_hw = 0;
 
-    src += BK - (params->wS[1] * params->wS[0] - 1) * params->wt_strides[2];
+    src += BK - (params->wS[1] * params->wS[0] - 1) * weight_step;
   }
 };
 
 } // namespace steel
 } // namespace mlx
@@ -272,7 +272,7 @@ struct Conv2DWeightBlockLoaderSmallChannels {
       return;
     }
 
-    const device T* curr_src = src + weight_hw * params->wt_strides[2];
+    const device T* curr_src = src + weight_hw * (params->C / params->groups);
 
     if (BN != 8 || do_read) {
       STEEL_PRAGMA_UNROLL
@@ -316,4 +316,4 @@ struct Conv2DWeightBlockLoaderSmallChannels {
 };
 
 } // namespace steel
 } // namespace mlx
Some files were not shown because too many files have changed in this diff.