Mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-07 17:44:38 +08:00)

Compare commits

68 Commits
Commits in this comparison:

716277d246, ed4fb26cb9, 4640f865cc, 0404037ea6, 990b1acc75, d571366250,
791f50d9f3, 140301aea8, 0c22440c75, c9ab537b9a, f1d87a2d3e, 83c4f6bde6,
c1dc852995, 2cd1de0e47, d927ed9e32, 168a3a464a, ad5b58b34e, 0c5eea226b,
dcca0d7477, 0d5e7716ad, d8c824c594, cb431dfc9f, 61d787726a, 5e89aace9b,
2af7e8a9a6, 2419edd5b2, bf481e8e5d, 9d7fa6b8e6, 073076ac7d, 9bd03dd9b4,
6931f84412, 16ec0556a0, 610af352d4, b35f1e3c9c, dfa0b9aab4, a4c47b0276,
111fefd5e9, c1fe1ef081, 8c34c9dac4, 91c0277356, 9f0d5c12fc, 59247c2b62,
9a3842a2d9, 726dbd9267, 54f05e7195, 26be608470, 248431eb3c, 76f275b4df,
f1951d6cce, 62f297b51d, 09bc32f62f, 46d8b16ab4, 42533931fa, 9bd3a7102f,
9e516b71ea, eac961ddb1, 57c6aa7188, cde5b4ad80, 4f72c66911, 960e3f0f05,
884af42da2, 048fabdabd, 917252a5a1, 1a992e31e8, d2ff04a4f2, 015c247393,
d3cd26820e, 91f6c499d7
@@ -349,7 +349,7 @@ workflows:
 ignore: /.*/
 matrix:
 parameters:
-python_version: ["3.9", "3.10", "3.11", "3.12"]
+python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
 xcode_version: ["15.0.0", "15.2.0"]
 build_env: ["PYPI_RELEASE=1"]
 - build_documentation:
@@ -386,7 +386,7 @@ workflows:
 - build_release:
 matrix:
 parameters:
-python_version: ["3.9", "3.10", "3.11", "3.12"]
+python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
 xcode_version: ["15.0.0", "15.2.0"]
 weekly_build:
 when:
@@ -397,7 +397,7 @@ workflows:
 - build_release:
 matrix:
 parameters:
-python_version: ["3.9", "3.10", "3.11", "3.12"]
+python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
 xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
 build_env: ["DEV_RELEASE=1"]
 linux_test_release:
@@ -409,5 +409,5 @@ workflows:
 - build_linux_release:
 matrix:
 parameters:
-python_version: ["3.9", "3.10", "3.11", "3.12"]
+python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
 extra_env: ["PYPI_RELEASE=1"]
@@ -24,7 +24,7 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)

 if(NOT MLX_VERSION)
-  set(MLX_VERSION 0.19.1)
+  set(MLX_VERSION 0.20.0)
 endif()

 # --------------------- Processor tests -------------------------
@@ -89,25 +89,27 @@ elseif(MLX_BUILD_METAL)
   # Throw an error if xcrun not found
   execute_process(
     COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
-    OUTPUT_VARIABLE MACOS_VERSION COMMAND_ERROR_IS_FATAL ANY)
+    OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)

-  if(${MACOS_VERSION} LESS 14.0)
+  if(${MACOS_SDK_VERSION} LESS 14.0)
     message(
       FATAL_ERROR
       "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
   endif()
-  message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
+  message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")

   set(METAL_CPP_URL
     https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip
   )
   # Get the metal version

+  if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
+    set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
+  endif()
   execute_process(
     COMMAND
       zsh "-c"
-      "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
+      "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
     OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)

   FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})

   FetchContent_MakeAvailable(metal_cpp)
@@ -115,8 +117,6 @@ elseif(MLX_BUILD_METAL)
     mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
                $<INSTALL_INTERFACE:include/metal_cpp>)
   target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
-
-  add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
 endif()

 if(MLX_BUILD_CPU)
|
@@ -6,7 +6,7 @@
 [](https://circleci.com/gh/ml-explore/mlx)

-MLX is an array framework for machine learning research on Apple silicon,
+MLX is an array framework for machine learning on Apple silicon,
 brought to you by Apple machine learning research.

 Some key features of MLX include:
@@ -144,6 +144,13 @@ def reduction(op, axis, x):
     mx.eval(ys)


+def sum_and_add(axis, x, y):
+    z = x.sum(axis=axis, keepdims=True)
+    for i in range(50):
+        z = (z + y).sum(axis=axis, keepdims=True)
+    mx.eval(z)
+
+
 def softmax(axis, x):
     ys = []
     for i in range(100):
@@ -505,5 +512,8 @@ if __name__ == "__main__":
     elif args.benchmark == "selu":
         print(bench(selu, x))

+    elif args.benchmark == "sum_and_add":
+        print(bench(sum_and_add, axis, *xs))
+
     else:
         raise ValueError("Unknown benchmark")
@@ -9,7 +9,7 @@ from time_utils import measure_runtime

 def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
     def scatter(dst, x, idx):
-        dst[*idx] = x
+        dst[tuple(idx)] = x
         mx.eval(dst)

     idx = []
@@ -23,8 +23,8 @@ def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):


 def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
-    def gather(dst, x, idx, device):
-        dst[*idx] = x
+    def scatter(dst, x, idx, device):
+        dst[tuple(idx)] = x
         if device == torch.device("mps"):
             torch.mps.synchronize()

@@ -34,7 +34,7 @@ def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
     x = torch.randn(x_shape, dtype=torch.float32).to(device)
     dst = torch.randn(dst_shape, dtype=torch.float32).to(device)

-    runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
+    runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device)
     print(f"PyTorch: {runtime:.3f}ms")


@@ -54,7 +54,7 @@ if __name__ == "__main__":
        (100_000, 64),
        (1_000_000, 64),
        (100_000,),
-       (2_000_00,),
+       (200_000,),
        (20_000_000,),
        (10000, 64),
        (100, 64),
@@ -91,6 +91,6 @@ if __name__ == "__main__":

    for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
        print("=" * 20)
-       print(f"X {x_shape}, Indices {idx_shape}")
+       print(f"Dst: {dst_shape}, X {x_shape}, Indices {idx_shape}")
        benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
        benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)
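The scatter change above swaps starred subscript unpacking for a plain tuple: `dst[*idx] = x` is only valid syntax on Python 3.11 and newer, while indexing with a tuple of index arrays behaves identically on every supported Python version. A minimal sketch of the equivalent form (array shapes here are illustrative, not taken from the benchmark):

```python
import mlx.core as mx

dst = mx.zeros((8, 8))
idx = [mx.array([0, 1, 2]), mx.array([3, 4, 5])]
x = mx.ones((3,))

# Equivalent to dst[idx[0], idx[1]] = x, and works on Python < 3.11,
# unlike the starred form dst[*idx] = x.
dst[tuple(idx)] = x
mx.eval(dst)
```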
@@ -1,62 +1,189 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
from time_utils import time_fn
|
||||
import numpy as np
|
||||
|
||||
MAX_SEQ = 300
|
||||
START_SEQ = 100
|
||||
SEQ_INCREMENT = 50
|
||||
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
|
||||
device_name = device_name.decode("utf-8").strip("\n")
|
||||
|
||||
N_warmup = 5
|
||||
N_iter_bench = 40
|
||||
N_iter_func = 8
|
||||
|
||||
|
||||
def time_self_attention_primitives():
|
||||
mx.random.seed(3)
|
||||
B = 2
|
||||
H = 38
|
||||
D = 64
|
||||
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
|
||||
q = mx.random.uniform(shape=(B, H, R, D))
|
||||
k = mx.random.uniform(shape=(B, H, R, D))
|
||||
v = mx.random.uniform(shape=(B, H, R, D))
|
||||
scale = 1.0 / math.sqrt(float(D))
|
||||
mx.eval(q, k, v)
|
||||
def bench(f, *args):
|
||||
for i in range(N_warmup):
|
||||
f(*args)
|
||||
|
||||
def sdpa_primitives(qs, ks, vs, alpha):
|
||||
s = (alpha * qs) @ ks.transpose(0, 1, 3, 2)
|
||||
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
|
||||
o = p @ vs
|
||||
return o
|
||||
|
||||
time_fn(sdpa_primitives, q, k, v, scale)
|
||||
s = time.perf_counter_ns()
|
||||
for i in range(N_iter_bench):
|
||||
f(*args)
|
||||
e = time.perf_counter_ns()
|
||||
return (e - s) * 1e-9
|
||||
|
||||
|
||||
def time_self_attention_sdpa():
|
||||
mx.random.seed(3)
|
||||
B = 2
|
||||
H = 38
|
||||
D = 64
|
||||
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
|
||||
q = mx.random.uniform(shape=(B, H, R, D))
|
||||
k = mx.random.uniform(shape=(B, H, R, D))
|
||||
v = mx.random.uniform(shape=(B, H, R, D))
|
||||
scale = 1.0 / math.sqrt(float(D))
|
||||
mx.eval(q, k, v)
|
||||
def mlx_sdpa_fused_inner(q, k, v, scale):
|
||||
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
|
||||
|
||||
def sdpa_fused(qs, ks, vs, alpha):
|
||||
o = mx.fast.scaled_dot_product_attention(qs, ks, vs, scale=alpha)
|
||||
return o
|
||||
|
||||
time_fn(sdpa_fused, q, k, v, scale)
|
||||
def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
|
||||
q_dtype = q.dtype
|
||||
q = q * mx.array(scale, q_dtype)
|
||||
n_q_heads = q.shape[-3]
|
||||
n_kv_heads = k.shape[-3]
|
||||
n_repeats = n_q_heads // n_kv_heads
|
||||
|
||||
B = q.shape[0]
|
||||
L = q.shape[2]
|
||||
|
||||
if n_repeats > 1:
|
||||
q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
|
||||
k = mx.expand_dims(k, 2)
|
||||
v = mx.expand_dims(v, 2)
|
||||
|
||||
scores = q @ mx.swapaxes(k, -1, -2)
|
||||
if f32softmax:
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype)
|
||||
else:
|
||||
scores = mx.softmax(scores, axis=-1)
|
||||
|
||||
out = scores @ v
|
||||
if n_repeats > 1:
|
||||
out = mx.reshape(out, [B, n_q_heads, L, -1])
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def mlx_spda_unfused(q, k, v, scale, transpose):
|
||||
q_out = q
|
||||
if transpose:
|
||||
k = mx.transpose(k, (0, 2, 1, 3))
|
||||
v = mx.transpose(v, (0, 2, 1, 3))
|
||||
|
||||
for i in range(N_iter_func):
|
||||
if transpose:
|
||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
||||
q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
|
||||
if transpose:
|
||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
||||
|
||||
mx.eval(q_out)
|
||||
return q_out
|
||||
|
||||
|
||||
def mlx_spda_fused(q, k, v, scale, transpose):
|
||||
q_out = q
|
||||
if transpose:
|
||||
k = mx.transpose(k, (0, 2, 1, 3))
|
||||
v = mx.transpose(v, (0, 2, 1, 3))
|
||||
|
||||
for i in range(N_iter_func):
|
||||
if transpose:
|
||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
||||
q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
|
||||
if transpose:
|
||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
||||
|
||||
mx.eval(q_out)
|
||||
return q_out
|
||||
|
||||
|
||||
def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
|
||||
shape_q = (
|
||||
(B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
|
||||
)
|
||||
shape_kv = (
|
||||
(B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
|
||||
)
|
||||
|
||||
q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype)
|
||||
k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
|
||||
v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
|
||||
|
||||
scale = math.sqrt(1.0 / head_dim)
|
||||
|
||||
q_mx = mx.array(q_np)
|
||||
k_mx = mx.array(k_np)
|
||||
v_mx = mx.array(v_np)
|
||||
|
||||
time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose)
|
||||
time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose)
|
||||
|
||||
if transpose:
|
||||
q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
|
||||
k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
|
||||
v_mx = mx.transpose(v_mx, (0, 2, 1, 3))
|
||||
|
||||
o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
|
||||
o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)
|
||||
|
||||
atol = 1e-5 if np_dtype == np.float32 else 1e-4
|
||||
|
||||
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
|
||||
print(
|
||||
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
|
||||
)
|
||||
|
||||
return time_mlx_fused, time_mlx_unfused
|
||||
|
||||
|
||||
def get_gflop_count(B, M, N, K):
|
||||
return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("MLX benchmarks.")
|
||||
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
||||
args = parser.parse_args()
|
||||
if args.gpu:
|
||||
mx.set_default_device(mx.gpu)
|
||||
else:
|
||||
mx.set_default_device(mx.cpu)
|
||||
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
|
||||
|
||||
time_self_attention_sdpa()
|
||||
time_self_attention_primitives()
|
||||
dtypes = ("float16", "float32")[:1]
|
||||
transposes = (False,)
|
||||
|
||||
# fmt: off
|
||||
shapes_64 = (
|
||||
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||
( 1, 32, 32, 64, 32, 32),
|
||||
( 1, 64, 64, 64, 32, 32),
|
||||
( 1, 128, 128, 64, 32, 32),
|
||||
( 1, 256, 256, 64, 32, 32),
|
||||
( 1, 512, 512, 64, 32, 32),
|
||||
( 1, 1024, 1024, 64, 32, 32),
|
||||
( 1, 2048, 2048, 64, 32, 32),
|
||||
( 1, 4096, 4096, 64, 32, 32),
|
||||
)
|
||||
|
||||
shapes_80 = (
|
||||
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||
( 1, 1024, 1024, 80, 32, 32),
|
||||
( 1, 2048, 2048, 80, 32, 32),
|
||||
( 1, 4096, 4096, 80, 32, 32),
|
||||
)
|
||||
|
||||
shapes_128 = (
|
||||
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||
( 1, 1024, 1024, 128, 32, 32),
|
||||
( 1, 2048, 2048, 128, 32, 32),
|
||||
( 1, 4096, 4096, 128, 32, 32),
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
shapes = shapes_64 + shapes_80 + shapes_128
|
||||
|
||||
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")
|
||||
|
||||
for dtype in dtypes:
|
||||
for transpose in transposes:
|
||||
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
|
||||
np_dtype = getattr(np, dtype)
|
||||
time_mlx_fused, time_mlx_unfused = bench_shape(
|
||||
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose
|
||||
)
|
||||
diff = time_mlx_unfused / time_mlx_fused - 1.0
|
||||
t_str = 1 if transpose else 0
|
||||
print(
|
||||
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
|
@@ -4,42 +4,51 @@ import math
|
||||
import mlx.core as mx
|
||||
from time_utils import time_fn
|
||||
|
||||
L = 1024
|
||||
L = 16384
|
||||
H = 32
|
||||
H_k = 32 // 4
|
||||
H_k = H // 4
|
||||
D = 128
|
||||
dtype = mx.float16
|
||||
loops = 10
|
||||
|
||||
|
||||
def attention(q, k, v):
|
||||
B, Hq, L, D = q.shape
|
||||
_, Hk, S, _ = k.shape
|
||||
q = q.reshape(B, Hk, Hq // Hk, L, D)
|
||||
k = k[:, :, None, :, :]
|
||||
v = v[:, :, None, :, :]
|
||||
s = q @ k.transpose(0, 1, 2, 4, 3)
|
||||
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
|
||||
o = p @ v
|
||||
return o.reshape(B, Hq, L, D)
|
||||
def _sdpa(q, k, v):
|
||||
B, Hq, L, D = q.shape
|
||||
_, Hk, S, _ = k.shape
|
||||
q = q.reshape(B, Hk, Hq // Hk, L, D)
|
||||
k = k[:, :, None, :, :]
|
||||
v = v[:, :, None, :, :]
|
||||
s = q @ k.transpose(0, 1, 2, 4, 3)
|
||||
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
|
||||
o = p @ v
|
||||
return o.reshape(B, Hq, L, D)
|
||||
|
||||
for i in range(loops):
|
||||
q = _sdpa(q, k, v)
|
||||
return q
|
||||
|
||||
|
||||
def sdpa(q, k, v):
|
||||
return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
|
||||
for i in range(loops):
|
||||
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
|
||||
return q
|
||||
|
||||
|
||||
def time_self_attention_primitives():
|
||||
mx.random.seed(3)
|
||||
q = mx.random.uniform(shape=(1, H, 1, D))
|
||||
k = mx.random.uniform(shape=(1, H_k, L, D))
|
||||
v = mx.random.uniform(shape=(1, H_k, L, D))
|
||||
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
|
||||
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
||||
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
||||
mx.eval(q, k, v)
|
||||
time_fn(attention, q, k, v)
|
||||
|
||||
|
||||
def time_self_attention_sdpa():
|
||||
mx.random.seed(3)
|
||||
q = mx.random.uniform(shape=(1, H, 1, D))
|
||||
k = mx.random.uniform(shape=(1, H_k, L, D))
|
||||
v = mx.random.uniform(shape=(1, H_k, L, D))
|
||||
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
|
||||
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
||||
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
||||
mx.eval(q, k, v)
|
||||
time_fn(sdpa, q, k, v)
|
||||
|
||||
|
@@ -60,6 +60,7 @@ html_theme_options = {
     },
 }

+html_favicon = html_theme_options["logo"]["image_light"]

 # -- Options for HTMLHelp output ---------------------------------------------
@@ -1,3 +1,5 @@
+.. _custom_metal_kernels:
+
 Custom Metal Kernels
 ====================

@@ -76,6 +78,10 @@ Putting this all together, the generated function signature for ``myexp`` is as

   template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;

 Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
 This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
 For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.

 Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.

 Using Shape/Strides
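For context, the ``myexp`` kernel that this documentation hunk refers to is declared and launched through ``mx.fast.metal_kernel``. A minimal sketch along the lines of the documented example follows; the exact call signature (for example the ``template`` and ``verbose`` keyword arguments) may differ slightly between MLX versions:

```python
import mlx.core as mx

source = """
    uint elem = thread_position_in_grid.x;
    T tmp = inp[elem];
    out[elem] = metal::exp(tmp);
"""

kernel = mx.fast.metal_kernel(
    name="myexp",
    input_names=["inp"],
    output_names=["out"],
    source=source,
)

a = mx.random.normal(shape=(4096,))
outputs = kernel(
    inputs=[a],
    template=[("T", mx.float32)],
    grid=(a.size, 1, 1),        # mx.prod(grid) threads are launched
    threadgroup=(256, 1, 1),    # subdivided into threadgroups of this size
    output_shapes=[a.shape],
    output_dtypes=[a.dtype],
    verbose=True,               # print the generated Metal source
)
mx.eval(outputs[0])
```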
@@ -494,7 +494,7 @@ below.

     // Prepare to encode kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
-    compute_encoder->setComputePipelineState(kernel);
+    compute_encoder.set_compute_pipeline_state(kernel);

     // Kernel parameters are registered with buffer indices corresponding to
     // those in the kernel declaration at axpby.metal
@@ -509,14 +509,14 @@ below.
     compute_encoder.set_output_array(out, 2);

     // Encode alpha and beta
-    compute_encoder->setBytes(&alpha_, sizeof(float), 3);
-    compute_encoder->setBytes(&beta_, sizeof(float), 4);
+    compute_encoder.set_bytes(alpha_, 3);
+    compute_encoder.set_bytes(beta_, 4);

     // Encode shape, strides and ndim
-    compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
-    compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
-    compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
-    compute_encoder->setBytes(&ndim, sizeof(int), 8);
+    compute_encoder.set_vector_bytes(x.shape(), 5);
+    compute_encoder.set_vector_bytes(x.strides(), 6);
+    compute_encoder.set_bytes(y.strides(), 7);
+    compute_encoder.set_bytes(ndim, 8);

     // We launch 1 thread for each input and make sure that the number of
     // threads in any given threadgroup is not higher than the max allowed
@@ -530,7 +530,7 @@ below.

     // Launch the grid with the given number of threads divided among
     // the given threadgroups
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
 }

We can now call the :meth:`axpby` operation on both the CPU and the GPU!
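For completeness, this is roughly how the finished extension is used from Python once built. The package name ``mlx_sample_extensions`` follows the extension example this documentation page walks through; treat the import path and keyword arguments as assumptions if your extension is packaged differently:

```python
import mlx.core as mx

# Assumed import path from the sample extension in the docs; adjust as needed.
from mlx_sample_extensions import axpby

x = mx.random.normal(shape=(3, 4))
y = mx.random.normal(shape=(3, 4))

# z = alpha * x + beta * y, dispatched to the default device (CPU or GPU)
z = axpby(x, y, alpha=2.0, beta=0.5)
mx.eval(z)
```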
@@ -209,7 +209,7 @@ Metal library by run-time compiling kernels the first time they are used in MLX
 on a given machine. Note run-time compilation incurs a cold-start cost which can
 be anwywhere from a few hundred millisecond to a few seconds depending on the
 application. Once a kernel is compiled, it will be cached by the system. The
-Metal kernel cache persists accross reboots.
+Metal kernel cache persists across reboots.

 Troubleshooting
 ^^^^^^^^^^^^^^^
@@ -161,7 +161,7 @@ A naive way to add the elements from two sets of vectors is with a loop:
   ys = mx.random.uniform(shape=(100, 4096))

   def naive_add(xs, ys):
-      return [xs[i] + ys[:, i] for i in range(xs.shape[1])]
+      return [xs[i] + ys[:, i] for i in range(xs.shape[0])]

 Instead you can use :func:`vmap` to automatically vectorize the addition:

@@ -169,7 +169,7 @@ Instead you can use :func:`vmap` to automatically vectorize the addition:

   # Vectorize over the second dimension of x and the
   # first dimension of y
-  vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
+  vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))

 The ``in_axes`` parameter can be used to specify which dimensions of the
 corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
@@ -184,8 +184,8 @@ Let's time these two different versions:
   print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
   print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))

-On an M1 Max the naive version takes in total ``0.390`` seconds whereas the
-vectorized version takes only ``0.025`` seconds, more than ten times faster.
+On an M1 Max the naive version takes in total ``5.639`` seconds whereas the
+vectorized version takes only ``0.024`` seconds, more than 200 times faster.

 Of course, this operation is quite contrived. A better approach is to simply do
 ``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.
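Put together, the corrected example from this documentation page looks like the sketch below. The shape of ``xs`` does not appear in the hunks above, so ``(4096, 100)`` here is an assumption chosen so that the loop and the ``in_axes=(0, 1)`` mapping line up:

```python
import timeit
import mlx.core as mx

# Assumed shape for xs; ys matches the docs snippet above.
xs = mx.random.uniform(shape=(4096, 100))
ys = mx.random.uniform(shape=(100, 4096))

def naive_add(xs, ys):
    # One small addition per row of xs / column of ys
    return [xs[i] + ys[:, i] for i in range(xs.shape[0])]

# Vectorize over the first dimension of xs and the second dimension of ys
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))

print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))
```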
@@ -77,7 +77,7 @@ from the GPU. Performing bounds checking for array indices before launching the
 kernel would be extremely inefficient.

 Indexing with boolean masks is something that MLX may support in the future. In
-general, MLX has limited support for operations for which outputs
+general, MLX has limited support for operations for which output
 *shapes* are dependent on input *data*. Other examples of these types of
 operations which MLX does not yet support include :func:`numpy.nonzero` and the
 single input version of :func:`numpy.where`.
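To make the shape constraint concrete, here is a small illustrative sketch (not from the docs): the three-argument ``mx.where``, whose output shape is known up front, works fine, while boolean-mask indexing would not; when a data-dependent result is genuinely needed, one option is to round-trip through NumPy, which forces evaluation:

```python
import mlx.core as mx
import numpy as np

x = mx.array([1, -2, 3, -4])
mask = x > 0

# Supported: the output shape equals the input shape regardless of the data.
y = mx.where(mask, x, mx.zeros_like(x))

# Not supported: x[mask] or a single-input where would have a
# data-dependent shape. A workaround is to filter on the NumPy side:
filtered = mx.array(np.array(x)[np.array(mask)])
```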
@@ -109,7 +109,7 @@ Here is a concrete example:

 An important behavior to be aware of is when the graph will be implicitly
 evaluated. Anytime you ``print`` an array, convert it to an
-:obj:`numpy.ndarray`, or otherwise access it's memory via :obj:`memoryview`,
+:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,
 the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
 saving functions) will also evaluate the array.
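A brief sketch of the behavior described above (explicit versus implicit evaluation of the lazy graph):

```python
import mlx.core as mx
import numpy as np

a = mx.random.uniform(shape=(4, 4))
b = a @ a.T          # only records the computation; nothing runs yet

mx.eval(b)           # explicit evaluation
print(b)             # printing also evaluates implicitly
b_np = np.array(b)   # so does converting to a numpy.ndarray
mx.save("b.npy", b)  # and so does saving
```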
@@ -257,7 +257,7 @@ void Axpby::eval_gpu(

   // Prepare to encode kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);

   // Kernel parameters are registered with buffer indices corresponding to
   // those in the kernel declaration at axpby.metal
@@ -272,15 +272,15 @@ void Axpby::eval_gpu(
   compute_encoder.set_output_array(out, 2);

   // Encode alpha and beta
-  compute_encoder->setBytes(&alpha_, sizeof(float), 3);
-  compute_encoder->setBytes(&beta_, sizeof(float), 4);
+  compute_encoder.set_bytes(alpha_, 3);
+  compute_encoder.set_bytes(beta_, 4);

   // Encode shape, strides and ndim if needed
   if (!contiguous_kernel) {
-    compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
-    compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
-    compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
-    compute_encoder->setBytes(&ndim, sizeof(int), 8);
+    compute_encoder.set_vector_bytes(x.shape(), 5);
+    compute_encoder.set_vector_bytes(x.strides(), 6);
+    compute_encoder.set_bytes(y.strides(), 7);
+    compute_encoder.set_bytes(ndim, 8);
   }

   // We launch 1 thread for each input and make sure that the number of
@@ -295,7 +295,7 @@ void Axpby::eval_gpu(

   // Launch the grid with the given number of threads divided among
   // the given threadgroups
-  compute_encoder.dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatch_threads(grid_dims, group_dims);
 }

 #else // Metal is not available
@@ -2,7 +2,6 @@

 #include <metal_stdlib>

-#include "mlx/backend/metal/kernels/bf16.h"
 #include "mlx/backend/metal/kernels/utils.h"

 template <typename T>
@@ -60,4 +59,4 @@ template <typename T>
 instantiate_axpby(float32, float);
 instantiate_axpby(float16, half);
 instantiate_axpby(bfloat16, bfloat16_t);
-instantiate_axpby(complex64, complex64_t);
+instantiate_axpby(complex64, complex64_t);
mlx.pc.in

@@ -28,10 +28,19 @@ endif()
 if (@MLX_BUILD_METAL@)
   set(MLX_BUILD_METAL @MLX_BUILD_METAL@)
   set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_)
-  set_and_check(MLX_INCLUDE_DIRS
-    ${MLX_INCLUDE_DIRS}
+  set(MLX_INCLUDE_DIRS
+    "${MLX_INCLUDE_DIRS};"
     @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp
   )
+  if(@MLX_METAL_VERSION@ GREATER_EQUAL 310)
+    set(MLX_INCLUDE_DIRS
+      "${MLX_INCLUDE_DIRS};"
+      @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_1)
+  else()
+    set(MLX_INCLUDE_DIRS
+      "${MLX_INCLUDE_DIRS};"
+      @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_0)
+  endif()
 endif()

 set_target_properties(mlx PROPERTIES
@@ -40,4 +49,4 @@ set_target_properties(mlx PROPERTIES
 )

 include(FindPackageHandleStandardArgs)
-find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)
+find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)
@@ -19,7 +19,7 @@ Buffer malloc(size_t size) {
 }

 void free(Buffer buffer) {
-  return allocator().free(buffer);
+  allocator().free(buffer);
 }

 Buffer CommonAllocator::malloc(size_t size, bool) {
@@ -214,6 +214,8 @@ array::~array() {
   if (do_detach) {
     for (auto& s : siblings()) {
       for (auto& ss : s.siblings()) {
+        // Set to null here to avoid descending into array destructor
+        // for siblings
         ss.array_desc_ = nullptr;
       }
       s.array_desc_->siblings.clear();
@@ -271,6 +273,9 @@ array::ArrayDesc::~ArrayDesc() {
     for (array& a : ad.inputs) {
       if (a.array_desc_) {
         input_map.insert({a.id(), a});
+        for (auto& s : a.siblings()) {
+          input_map.insert({s.id(), s});
+        }
       }
     }
     ad.inputs.clear();
@@ -289,6 +294,14 @@ array::ArrayDesc::~ArrayDesc() {
     auto top = std::move(for_deletion.back());
     for_deletion.pop_back();
     append_deletable_inputs(*top);
+
+    // Clear out possible siblings to break circular references
+    for (auto& s : top->siblings) {
+      // Set to null here to avoid descending into top-level
+      // array destructor for siblings
+      s.array_desc_ = nullptr;
+    }
+    top->siblings.clear();
   }
 }
@@ -39,7 +39,7 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
   // rely on data_size anyway.
   size_t data_size = out.size();

-  return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
+  return move_or_copy(in, out, strides_, flags, data_size, offset_);
 }

 void Broadcast::eval(const std::vector<array>& inputs, array& out) {
@@ -58,12 +58,12 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
   if (out.size() > in.size()) {
     flags.row_contiguous = flags.col_contiguous = false;
   }
-  out.copy_shared_buffer(in, strides, flags, in.data_size());
+  move_or_copy(in, out, strides, flags, in.data_size());
 }

 void Copy::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
-  out.copy_shared_buffer(inputs[0]);
+  move_or_copy(inputs[0], out);
 }

 void CustomTransforms::eval(
@@ -72,7 +72,7 @@ void CustomTransforms::eval(
   assert(inputs.size() > outputs.size());
   for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
        i++, j++) {
-    outputs[i].copy_shared_buffer(inputs[j]);
+    move_or_copy(inputs[j], outputs[i]);
   }
 }

@@ -81,7 +81,7 @@ void Depends::eval(
     std::vector<array>& outputs) {
   assert(inputs.size() > outputs.size());
   for (int i = 0; i < outputs.size(); i++) {
-    outputs[i].copy_shared_buffer(inputs[i]);
+    move_or_copy(inputs[i], outputs[i]);
   }
 }

@@ -194,7 +194,7 @@ void Reshape::shared_buffer_reshape(
     auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
     flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
   }
-  out.copy_shared_buffer(in, out_strides, flags, in.data_size());
+  move_or_copy(in, out, out_strides, flags, in.data_size());
 }

 void Split::eval(
@@ -263,7 +263,7 @@ std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(

 void StopGradient::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
-  out.copy_shared_buffer(inputs[0]);
+  move_or_copy(inputs[0], out);
 }

 void Transpose::eval(const std::vector<array>& inputs, array& out) {
@@ -297,7 +297,7 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
       b_stride *= out.shape(ri);
     }
   }
-  out.copy_shared_buffer(in, out_strides, flags, in.data_size());
+  move_or_copy(in, out, out_strides, flags, in.data_size());
 }

 } // namespace mlx::core
@@ -279,7 +279,7 @@ void Compiled::eval_cpu(

   // Figure out which kernel we are using
   auto& shape = outputs[0].shape();
-  bool contiguous = compiled_check_contiguity(inputs, shape);
+  auto contiguous = compiled_check_contiguity(inputs, shape);

   // Handle all broadcasting and collect function input arguments
   std::vector<void*> args;
@@ -159,6 +159,17 @@ void Conjugate::eval(const std::vector<array>& inputs, array& out) {
   }
 }

+void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  auto& in = inputs[0];
+  if (in.flags().row_contiguous ||
+      (allow_col_major_ && in.flags().col_contiguous)) {
+    out.copy_shared_buffer(in);
+  } else {
+    copy(in, out, CopyType::General);
+  }
+}
+
 void Cos::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
@@ -606,7 +617,7 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
   if (ibytes == obytes || obytes < ibytes && in.strides().back() == 1 ||
       in.flags().row_contiguous) {
     auto strides = in.strides();
-    for (int i = 0; i < strides.size() - 1; ++i) {
+    for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
       strides[i] *= ibytes;
       strides[i] /= obytes;
     }
|
@@ -2,7 +2,9 @@
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/ops.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -404,4 +406,103 @@ void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
transpose_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void quantize(
|
||||
const array& w_,
|
||||
array& out_,
|
||||
array& scales_,
|
||||
array& biases_,
|
||||
int bits,
|
||||
int group_size,
|
||||
bool compute_scale_bias) {
|
||||
const T* w = w_.data<T>();
|
||||
T* scales = scales_.data<T>();
|
||||
T* biases = biases_.data<T>();
|
||||
auto out = out_.data<uint32_t>();
|
||||
|
||||
T n_bins = (1 << bits) - 1;
|
||||
T eps = 1e-7;
|
||||
int el_per_int = 32 / bits;
|
||||
int int_per_group = group_size / el_per_int;
|
||||
size_t n_groups = w_.size() / group_size;
|
||||
|
||||
for (size_t i = 0; i < n_groups; ++i) {
|
||||
size_t w_idx = i * group_size;
|
||||
if (compute_scale_bias) {
|
||||
T w_min = std::numeric_limits<float>::infinity();
|
||||
T w_max = -w_min;
|
||||
for (int j = 0; j < group_size; ++j) {
|
||||
w_max = std::max(w_max, w[w_idx + j]);
|
||||
w_min = std::min(w_min, w[w_idx + j]);
|
||||
}
|
||||
bool mask = std::abs(w_min) > std::abs(w_max);
|
||||
T scale = std::max(T((w_max - w_min) / n_bins), eps);
|
||||
scale = mask ? scale : -scale;
|
||||
|
||||
auto edge = mask ? w_min : w_max;
|
||||
auto q0 = std::rint(edge / scale);
|
||||
if (q0 == 0) {
|
||||
scales[i] = scale;
|
||||
biases[i] = 0;
|
||||
} else {
|
||||
scales[i] = edge / q0;
|
||||
biases[i] = edge;
|
||||
}
|
||||
}
|
||||
size_t out_idx = i * int_per_group;
|
||||
for (int j = 0; j < int_per_group; ++j) {
|
||||
uint32_t out_el = 0;
|
||||
for (int k = 0; k < el_per_int; ++k) {
|
||||
T w_el = w[w_idx + j * el_per_int + k];
|
||||
w_el = std::rint((w_el - biases[i]) / scales[i]);
|
||||
w_el = std::min(std::max(w_el, T(0)), n_bins);
|
||||
out_el |= static_cast<uint32_t>(w_el) << (k * bits);
|
||||
}
|
||||
out[out_idx + j] = out_el;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void fast::AffineQuantize::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
bool compute_scale_bias = inputs.size() == 1;
|
||||
|
||||
auto ensure_row_contiguous = [](const array& arr) {
|
||||
if (arr.flags().row_contiguous) {
|
||||
return arr;
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
return arr_copy;
|
||||
}
|
||||
};
|
||||
auto w = ensure_row_contiguous(inputs[0]);
|
||||
|
||||
auto& out = outputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& scales =
|
||||
compute_scale_bias ? outputs[1] : const_cast<array&>(inputs[1]);
|
||||
auto& biases =
|
||||
compute_scale_bias ? outputs[2] : const_cast<array&>(inputs[2]);
|
||||
if (compute_scale_bias) {
|
||||
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
|
||||
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
|
||||
}
|
||||
if (w.dtype() == float16) {
|
||||
quantize<float16_t>(
|
||||
w, out, scales, biases, bits_, group_size_, compute_scale_bias);
|
||||
} else if (w.dtype() == bfloat16) {
|
||||
quantize<bfloat16_t>(
|
||||
w, out, scales, biases, bits_, group_size_, compute_scale_bias);
|
||||
} else if (w.dtype() == float32) {
|
||||
quantize<float>(
|
||||
w, out, scales, biases, bits_, group_size_, compute_scale_bias);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -120,48 +120,56 @@ struct MinReduce {
|
||||
};
|
||||
|
||||
template <typename InT>
|
||||
void reduce_dispatch_out(
|
||||
void reduce_dispatch_and_or(
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType rtype,
|
||||
const std::vector<int>& axes) {
|
||||
switch (rtype) {
|
||||
case Reduce::And: {
|
||||
reduction_op<InT, bool>(in, out, axes, true, AndReduce());
|
||||
break;
|
||||
if (rtype == Reduce::And) {
|
||||
reduction_op<InT, bool>(in, out, axes, true, AndReduce());
|
||||
} else {
|
||||
reduction_op<InT, bool>(in, out, axes, false, OrReduce());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InT>
|
||||
void reduce_dispatch_sum_prod(
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType rtype,
|
||||
const std::vector<int>& axes) {
|
||||
if (rtype == Reduce::Sum) {
|
||||
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
|
||||
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
|
||||
reduction_op<InT, int32_t>(in, out, axes, 0, op);
|
||||
} else {
|
||||
reduction_op<InT, InT>(in, out, axes, 0, op);
|
||||
}
|
||||
case Reduce::Or: {
|
||||
reduction_op<InT, bool>(in, out, axes, false, OrReduce());
|
||||
break;
|
||||
}
|
||||
case Reduce::Sum: {
|
||||
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
|
||||
if (out.dtype() == int32) {
|
||||
// special case since the input type can be bool
|
||||
reduction_op<InT, int32_t>(in, out, axes, 0, op);
|
||||
} else {
|
||||
reduction_op<InT, InT>(in, out, axes, 0, op);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Reduce::Prod: {
|
||||
auto op = [](auto y, auto x) { (*y) *= x; };
|
||||
} else {
|
||||
auto op = [](auto y, auto x) { (*y) *= x; };
|
||||
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
|
||||
reduction_op<InT, int32_t>(in, out, axes, 1, op);
|
||||
} else {
|
||||
reduction_op<InT, InT>(in, out, axes, 1, op);
|
||||
break;
|
||||
}
|
||||
case Reduce::Max: {
|
||||
auto init = Limits<InT>::min;
|
||||
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
|
||||
break;
|
||||
}
|
||||
case Reduce::Min: {
|
||||
auto init = Limits<InT>::max;
|
||||
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InT>
|
||||
void reduce_dispatch_min_max(
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType rtype,
|
||||
const std::vector<int>& axes) {
|
||||
if (rtype == Reduce::Max) {
|
||||
auto init = Limits<InT>::min;
|
||||
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
|
||||
} else {
|
||||
auto init = Limits<InT>::max;
|
||||
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void nd_loop(
|
||||
@@ -190,46 +198,114 @@ void nd_loop(
|
||||
void Reduce::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
reduce_dispatch_out<bool>(in, out, reduce_type_, axes_);
|
||||
switch (reduce_type_) {
|
||||
case Reduce::And:
|
||||
case Reduce::Or: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case int8:
|
||||
reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
case float16:
|
||||
case bfloat16:
|
||||
reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
case int32:
|
||||
case float32:
|
||||
reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
case int64:
|
||||
case complex64:
|
||||
reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case uint8:
|
||||
reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_);
|
||||
}
|
||||
case Reduce::Sum:
|
||||
case Reduce::Prod: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case int8:
|
||||
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
case uint32:
|
||||
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
case uint64:
|
||||
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case bfloat16:
|
||||
reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float32:
|
||||
reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case complex64:
|
||||
reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case uint16:
|
||||
reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
reduce_dispatch_out<uint32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
reduce_dispatch_out<uint64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int8:
|
||||
reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
reduce_dispatch_out<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
reduce_dispatch_out<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
reduce_dispatch_out<float16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float32:
|
||||
reduce_dispatch_out<float>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case bfloat16:
|
||||
reduce_dispatch_out<bfloat16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case complex64:
|
||||
reduce_dispatch_out<complex64_t>(in, out, reduce_type_, axes_);
|
||||
}
|
||||
case Reduce::Max:
|
||||
case Reduce::Min: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint8:
|
||||
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint16:
|
||||
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int8:
|
||||
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float32:
|
||||
reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case bfloat16:
|
||||
reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case complex64:
|
||||
reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -34,7 +34,7 @@ void shared_buffer_slice(
|
||||
flags.col_contiguous = is_col_contiguous;
|
||||
flags.contiguous = (no_bsx_size == data_size);
|
||||
|
||||
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
|
||||
move_or_copy(in, out, out_strides, flags, data_size, data_offset);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -4,6 +4,28 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void move_or_copy(const array& in, array& out) {
|
||||
if (in.is_donatable()) {
|
||||
out.move_shared_buffer(in);
|
||||
} else {
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void move_or_copy(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::vector<size_t>& strides,
|
||||
array::Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset /* = 0 */) {
|
||||
if (in.is_donatable()) {
|
||||
out.move_shared_buffer(in, strides, flags, data_size, offset);
|
||||
} else {
|
||||
out.copy_shared_buffer(in, strides, flags, data_size, offset);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename StrideT>
|
||||
std::tuple<std::vector<int>, std::vector<std::vector<StrideT>>>
|
||||
collapse_contiguous_dims_impl(
|
||||
|
@@ -178,4 +178,13 @@ inline bool is_donatable(const array& in, const array& out) {
|
||||
in.buffer_size() <= out.nbytes() + donation_extra;
|
||||
}
|
||||
|
||||
void move_or_copy(const array& in, array& out);
|
||||
void move_or_copy(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::vector<size_t>& strides,
|
||||
array::Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset = 0);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -14,20 +14,27 @@ function(make_jit_source SRC_FILE)
|
||||
COMMAND
|
||||
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
|
||||
${CMAKE_CURRENT_BINARY_DIR}/jit ${CMAKE_C_COMPILER} ${PROJECT_SOURCE_DIR}
|
||||
${SRC_FILE} "-DMLX_METAL_VERSION=${MLX_METAL_VERSION}"
|
||||
${SRC_FILE}
|
||||
DEPENDS make_compiled_preamble.sh kernels/${SRC_FILE}.h ${ARGN})
|
||||
add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
|
||||
add_dependencies(mlx ${SRC_NAME})
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp)
|
||||
endfunction(make_jit_source)
|
||||
|
||||
make_jit_source(utils kernels/bf16.h kernels/complex.h kernels/defines.h)
|
||||
make_jit_source(
|
||||
utils
|
||||
kernels/jit/bf16.h
|
||||
kernels/metal_3_0/bf16.h
|
||||
kernels/metal_3_1/bf16.h
|
||||
kernels/bf16_math.h
|
||||
kernels/complex.h
|
||||
kernels/defines.h)
|
||||
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
|
||||
make_jit_source(binary_ops)
|
||||
make_jit_source(ternary_ops)
|
||||
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
|
||||
make_jit_source(scatter)
|
||||
make_jit_source(gather)
|
||||
make_jit_source(scatter kernels/indexing.h)
|
||||
make_jit_source(gather kernels/indexing.h)
|
||||
make_jit_source(hadamard)
|
||||
|
||||
if(MLX_METAL_JIT)
|
||||
|
@@ -242,6 +242,9 @@ void MetalAllocator::clear_cache() {
|
||||
|
||||
void MetalAllocator::free(Buffer buffer) {
|
||||
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
|
||||
if (buf == nullptr) {
|
||||
return;
|
||||
}
|
||||
std::unique_lock lk(mutex_);
|
||||
residency_set_.erase(buf);
|
||||
active_memory_ -= buf->length();
|
||||
|
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
@@ -23,37 +22,37 @@ std::string get_kernel_name(
|
||||
BinaryOpType bopt,
|
||||
const std::string& op,
|
||||
const array& a,
|
||||
bool use_2d,
|
||||
bool large,
|
||||
int ndim,
|
||||
int work_per_thread) {
|
||||
std::ostringstream kname;
|
||||
std::string kname;
|
||||
switch (bopt) {
|
||||
case BinaryOpType::ScalarScalar:
|
||||
kname << "ss";
|
||||
kname = "ss";
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
kname << (use_2d ? "sv2" : "sv");
|
||||
kname = (large ? "sv2" : "sv");
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
kname << (use_2d ? "vs2" : "vs");
|
||||
kname = (large ? "vs2" : "vs");
|
||||
break;
|
||||
case BinaryOpType::VectorVector:
|
||||
kname << (use_2d ? "vv2" : "vv");
|
||||
kname = (large ? "vv2" : "vv");
|
||||
break;
|
||||
case BinaryOpType::General:
|
||||
kname << "g";
|
||||
kname = "g";
|
||||
if (ndim <= 3) {
|
||||
kname << ndim;
|
||||
kname += std::to_string(ndim);
|
||||
} else {
|
||||
kname << "n";
|
||||
if (work_per_thread > 1) {
|
||||
kname << work_per_thread;
|
||||
}
|
||||
concatenate(kname, "n", std::to_string(work_per_thread));
|
||||
}
|
||||
if (large) {
|
||||
kname += "large";
|
||||
}
|
||||
break;
|
||||
}
|
||||
kname << "_" << op << type_to_name(a);
|
||||
return kname.str();
|
||||
concatenate(kname, "_", op, type_to_name(a));
|
||||
return kname;
|
||||
}
|
||||
|
||||
void binary_op_gpu_inplace(
|
||||
@@ -82,18 +81,23 @@ void binary_op_gpu_inplace(
|
||||
};
|
||||
auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();
|
||||
|
||||
bool use_2d = out.data_size() > UINT32_MAX;
|
||||
bool large = out.data_size() > UINT32_MAX;
|
||||
auto ndim = shape.size();
|
||||
int work_per_thread = (bopt == BinaryOpType::General) ? 4 : 1;
|
||||
int work_per_thread;
|
||||
if (bopt == BinaryOpType::General) {
|
||||
work_per_thread = large ? 4 : 2;
|
||||
} else {
|
||||
work_per_thread = 1;
|
||||
}
|
||||
std::string kernel_name =
|
||||
get_kernel_name(bopt, op, a, use_2d, shape.size(), work_per_thread);
|
||||
get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto kernel = outputs.size() == 2
|
||||
? get_binary_two_kernel(d, kernel_name, a.dtype(), out.dtype(), op)
|
||||
: get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// - If a is donated it goes to the first output
|
||||
// - If b is donated it goes to the first output if a was not donated
|
||||
@@ -110,6 +114,7 @@ void binary_op_gpu_inplace(
|
||||
compute_encoder.set_output_array(outputs[1], arg_idx++);
|
||||
}
|
||||
|
||||
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (bopt == BinaryOpType::General) {
|
||||
// Launch up to 3D grid of threads
|
||||
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
@@ -117,39 +122,33 @@ void binary_op_gpu_inplace(
|
||||
size_t rest = out.size() / (dim0 * dim1);
|
||||
|
||||
if (ndim > 3) {
|
||||
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), arg_idx++);
|
||||
compute_encoder->setBytes(
|
||||
strides_a.data(), ndim * sizeof(size_t), arg_idx++);
|
||||
compute_encoder->setBytes(
|
||||
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), arg_idx++);
|
||||
compute_encoder.set_vector_bytes(shape, arg_idx++);
|
||||
compute_encoder.set_vector_bytes(strides_a, arg_idx++);
|
||||
compute_encoder.set_vector_bytes(strides_b, arg_idx++);
|
||||
compute_encoder.set_bytes<int>(ndim, arg_idx++);
|
||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||
} else {
|
||||
// The shape is implicit in the grid for <= 3D
|
||||
compute_encoder->setBytes(
|
||||
strides_a.data(), ndim * sizeof(size_t), arg_idx++);
|
||||
compute_encoder->setBytes(
|
||||
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
|
||||
compute_encoder.set_vector_bytes(strides_a, arg_idx++);
|
||||
compute_encoder.set_vector_bytes(strides_b, arg_idx++);
|
||||
}
|
||||
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size != 1024) {
|
||||
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
|
||||
}
|
||||
auto group_dims = get_block_dims(dim0, dim1, rest);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
} else {
|
||||
// Launch a 1D or 2D grid of threads
|
||||
size_t nthreads = out.data_size();
|
||||
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -1,5 +1,6 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <iostream> //TODO
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
@@ -11,12 +12,12 @@
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
using namespace fmt::literals;
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
constexpr int WORK_PER_THREAD = 4;
|
||||
|
||||
inline void build_kernel(
|
||||
std::ostream& os,
|
||||
std::string& os,
|
||||
const std::string& kernel_name,
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
@@ -41,8 +42,8 @@ inline void build_kernel(
|
||||
int cnt = 0;
|
||||
|
||||
// Start the kernel
|
||||
os << "[[host_name(\"" << kernel_name << "\")]]\n"
|
||||
<< "[[kernel]] void " << kernel_name << "(\n";
|
||||
os += fmt::format(
|
||||
"[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);
|
||||
|
||||
// Add the input arguments
|
||||
for (auto& x : inputs) {
|
||||
@@ -54,51 +55,61 @@ inline void build_kernel(
|
||||
}
|
||||
|
||||
// Scalars and contiguous need no strides
|
||||
if (is_scalar(x) || contiguous) {
|
||||
os << " device const " << get_type_string(x.dtype()) << "* " << xname
|
||||
<< " [[buffer(" << cnt++ << ")]],\n";
|
||||
} else {
|
||||
if (!is_scalar(x) && !contiguous) {
|
||||
add_indices = true;
|
||||
os << " device const " << get_type_string(x.dtype()) << "* " << xname
|
||||
<< " [[buffer(" << cnt++ << ")]],\n";
|
||||
}
|
||||
os += fmt::format(
|
||||
" device const {0}* {1} [[buffer({2})]],\n",
|
||||
get_type_string(x.dtype()),
|
||||
xname,
|
||||
cnt++);
|
||||
}
|
||||
|
||||
if (add_indices) {
|
||||
os << " constant const size_t* in_strides [[buffer(" << cnt++
|
||||
<< ")]],\n";
|
||||
os += fmt::format(
|
||||
" constant const size_t* in_strides [[buffer({0})]],\n", cnt++);
|
||||
}
|
||||
|
||||
// Add the output arguments
|
||||
for (auto& x : outputs) {
|
||||
os << " device " << get_type_string(x.dtype()) << "* "
|
||||
<< namer.get_name(x) << " [[buffer(" << cnt++ << ")]],\n";
|
||||
os += fmt::format(
|
||||
" device {0}* {1} [[buffer({2})]],\n",
|
||||
get_type_string(x.dtype()),
|
||||
namer.get_name(x),
|
||||
cnt++);
|
||||
}
|
||||
// Add output strides and shape to extract the indices.
|
||||
if (!contiguous) {
|
||||
os << " constant const size_t* output_strides [[buffer(" << cnt++
|
||||
<< ")]],\n"
|
||||
<< " constant const int* output_shape [[buffer(" << cnt++ << ")]],\n";
|
||||
os += fmt::format(
|
||||
" constant const size_t* output_strides [[buffer({0})]],\n", cnt++);
|
||||
os += fmt::format(
|
||||
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
|
||||
}
|
||||
if (dynamic_dims) {
|
||||
os << " constant const int& ndim [[buffer(" << cnt++ << ")]],\n";
|
||||
os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
|
||||
}
|
||||
|
||||
// The thread index in the whole grid
|
||||
os << " uint3 pos [[thread_position_in_grid]],\n"
|
||||
<< " uint3 grid [[threads_per_grid]]) {\n";
|
||||
os += " uint3 pos [[thread_position_in_grid]],\n";
|
||||
os += " uint3 grid [[threads_per_grid]]) {\n";
|
||||
|
||||
if (use_big_index) {
|
||||
std::string idx_type = use_big_index ? "size_t" : "uint";
|
||||
if (contiguous && use_big_index) {
|
||||
// This is only used for contiguous kernels which don't have
|
||||
// a third grid dimension
|
||||
os << " size_t index = pos.x + grid.x * size_t(pos.y);\n";
|
||||
os += " size_t index = pos.x + grid.x * size_t(pos.y);\n";
|
||||
} else if (work_per_thread > 1) {
|
||||
os << " constexpr int N_ = " << std::to_string(work_per_thread) << ";\n"
|
||||
<< " int xshape = output_shape["
|
||||
<< (dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)) << "];\n"
|
||||
<< " size_t index = N_ * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n";
|
||||
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
|
||||
os += fmt::format(
|
||||
" int xshape = output_shape[{0}];\n",
|
||||
dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
|
||||
os += fmt::format(
|
||||
" {0} index = N_ * pos.x + xshape * (pos.y + {0}(grid.y) * pos.z);\n",
|
||||
idx_type);
|
||||
} else {
|
||||
os << " size_t index = pos.x + grid.x * (pos.y + size_t(grid.y) * pos.z);\n";
|
||||
os += fmt::format(
|
||||
" {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
|
||||
idx_type);
|
||||
}
|
||||
|
||||
// Read constant / contiguous inputs in tmps
|
||||
@@ -109,16 +120,19 @@ inline void build_kernel(
|
||||
|
||||
if (is_constant(x)) {
|
||||
auto type_str = get_type_string(x.dtype());
|
||||
os << " auto tmp_" << xname << " = static_cast<"
|
||||
<< get_type_string(x.dtype()) << ">(";
|
||||
print_constant(os, x);
|
||||
os << ");\n";
|
||||
std::ostringstream ss;
|
||||
print_constant(ss, x);
|
||||
os += fmt::format(
|
||||
" auto tmp_{0} = static_cast<{1}>({2});\n",
|
||||
xname,
|
||||
get_type_string(x.dtype()),
|
||||
ss.str());
|
||||
} else if (is_scalar(x)) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[0];\n";
|
||||
os += fmt::format(
|
||||
" {0} tmp_{1} = {1}[0];\n", get_type_string(x.dtype()), xname);
|
||||
} else if (contiguous) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[index];\n";
|
||||
os += fmt::format(
|
||||
" {0} tmp_{1} = {1}[index];\n", get_type_string(x.dtype()), xname);
|
||||
} else {
|
||||
nc_inputs.push_back(x);
|
||||
}
|
||||
@@ -127,83 +141,98 @@ inline void build_kernel(
|
||||
// Initialize the indices for non-contiguous inputs
|
||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||
auto& xname = namer.get_name(nc_inputs[i]);
|
||||
os += fmt::format(" {0} index_{1} = ", idx_type, xname);
|
||||
if (ndim == 1) {
|
||||
int offset = i * ndim;
|
||||
os << " size_t index_" << xname << " = elem_to_loc_1(pos.x, "
|
||||
<< "in_strides[" << offset << "]);\n";
|
||||
os += fmt::format(
|
||||
"elem_to_loc_1<size_t, uint>(pos.x, in_strides[{0}]);\n", offset);
|
||||
} else if (ndim == 2) {
|
||||
int offset = i * ndim;
|
||||
os << " size_t index_" << xname << " = elem_to_loc_2({pos.x, pos.y}, "
|
||||
<< "in_strides + " << offset << ");\n";
|
||||
os += fmt::format(
|
||||
"elem_to_loc_2<size_t, {0}>({{pos.x, pos.y}}, in_strides + {1});\n",
|
||||
idx_type,
|
||||
offset);
|
||||
} else if (ndim == 3) {
|
||||
int offset = i * ndim;
|
||||
os << " size_t index_" << xname << " = elem_to_loc_3(pos, "
|
||||
<< "in_strides + " << offset << ");\n";
|
||||
os += fmt::format(
|
||||
"elem_to_loc_3<size_t, {0}>(pos, in_strides + {1});\n",
|
||||
idx_type,
|
||||
offset);
|
||||
} else if (!dynamic_dims) {
|
||||
int offset = i * ndim;
|
||||
os << " size_t index_" << xname << " = N_ * pos.x * in_strides["
|
||||
<< offset + ndim - 1 << "]"
|
||||
<< " + pos.y * in_strides[" << offset + ndim - 2 << "];\n";
|
||||
int offset = (i + 1) * ndim;
|
||||
os += fmt::format(
|
||||
"N_ * pos.x * {0}(in_strides[{1}]) + pos.y * {0}(in_strides[{2}]);\n",
|
||||
idx_type,
|
||||
offset - 1,
|
||||
offset - 2);
|
||||
} else {
|
||||
os << " size_t index_" << xname << " = N_ * pos.x * in_strides[ndim * "
|
||||
<< i << " + ndim - 1]"
|
||||
<< " + pos.y * in_strides[ndim * " << i << " + ndim - 2];\n";
|
||||
os += fmt::format(
|
||||
"N_ * pos.x * {0}(in_strides[ndim * {1} + ndim - 1]) + pos.y * {0}(in_strides[ndim * {1} + ndim - 2]);\n",
|
||||
idx_type,
|
||||
i);
|
||||
}
|
||||
}
|
||||
|
||||
if (!nc_inputs.empty() && (ndim > 3 || dynamic_dims)) {
|
||||
os << " uint zpos = pos.z;\n";
|
||||
os += " uint zpos = pos.z;\n";
|
||||
if (dynamic_dims) {
|
||||
os << " for (int d = ndim - 3; d >= 0; --d) {\n";
|
||||
os += " for (int d = ndim - 3; d >= 0; --d) {\n";
|
||||
} else {
|
||||
os << " for (int d = " << ndim - 3 << "; d >= 0; --d) {\n";
|
||||
os += fmt::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3);
|
||||
}
|
||||
os << " uint l = zpos % output_shape[d];\n";
|
||||
os += " uint l = zpos % output_shape[d];\n";
|
||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||
auto& xname = namer.get_name(nc_inputs[i]);
|
||||
os << " index_" << xname << " += ";
|
||||
os += fmt::format(" index_{0} += ", xname);
|
||||
if (dynamic_dims) {
|
||||
os << "l * in_strides[" << i << " * ndim + d];\n";
|
||||
os +=
|
||||
fmt::format("l * {0}(in_strides[{1} * ndim + d]);\n", idx_type, i);
|
||||
} else {
|
||||
os << "l * in_strides[" << i * ndim << " + d];\n";
|
||||
os +=
|
||||
fmt::format("l * {0}(in_strides[{1} + d]);\n", idx_type, i * ndim);
|
||||
}
|
||||
}
|
||||
os << " zpos /= output_shape[d];\n }\n";
|
||||
os += " zpos /= output_shape[d];\n }\n";
|
||||
}
|
||||
|
||||
// Open per-thread loop
|
||||
if (work_per_thread > 1) {
|
||||
os << " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
|
||||
os +=
|
||||
" for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
|
||||
}
|
||||
|
||||
// Read non-contiguous inputs into tmps
|
||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||
auto& x = nc_inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[index_" << xname << "];\n";
|
||||
os += fmt::format(
|
||||
" {0} tmp_{1} = {1}[index_{1}];\n", get_type_string(x.dtype()), xname);
|
||||
}
|
||||
|
||||
// Actually write the computation
|
||||
for (auto& x : tape) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
|
||||
<< " = ";
|
||||
os += fmt::format(
|
||||
" {0} tmp_{1} = ", get_type_string(x.dtype()), namer.get_name(x));
|
||||
if (is_static_cast(x.primitive())) {
|
||||
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
|
||||
<< namer.get_name(x.inputs()[0]) << ");\n";
|
||||
os += fmt::format(
|
||||
"static_cast<{0}>(tmp_{1});\n",
|
||||
get_type_string(x.dtype()),
|
||||
namer.get_name(x.inputs()[0]));
|
||||
} else {
|
||||
x.primitive().print(os);
|
||||
os << "()(";
|
||||
std::ostringstream ss;
|
||||
x.primitive().print(ss);
|
||||
os += ss.str();
|
||||
os += "()(";
|
||||
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
||||
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
|
||||
os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
|
||||
}
|
||||
os << "tmp_" << namer.get_name(x.inputs().back()) << ");\n";
|
||||
os += fmt::format("tmp_{0});\n", namer.get_name(x.inputs().back()));
|
||||
}
|
||||
}
|
||||
|
||||
// Write the outputs from tmps
|
||||
for (auto& x : outputs) {
|
||||
os << " " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x)
|
||||
<< ";\n";
|
||||
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
|
||||
}
|
||||
// Increment indices and close per thread loop
|
||||
if (work_per_thread > 1) {
|
||||
@@ -211,18 +240,18 @@ inline void build_kernel(
|
||||
auto& x = nc_inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
if (!dynamic_dims) {
|
||||
os << " index_" << xname << " += "
|
||||
<< "in_strides[" << i * ndim + ndim - 1 << "];\n";
|
||||
os += fmt::format(
|
||||
" index_{0} += in_strides[{1}];\n", xname, i * ndim + ndim - 1);
|
||||
} else {
|
||||
os << " index_" << xname << " += "
|
||||
<< "in_strides[" << i << " * ndim + ndim - 1];\n";
|
||||
os += fmt::format(
|
||||
" index_{0} += in_strides[{1} * ndim + ndim - 1];\n", xname, i);
|
||||
}
|
||||
}
|
||||
os << " index++;\n }\n";
|
||||
os += " index++;\n }\n";
|
||||
}
|
||||
|
||||
// Finish the kernel
|
||||
os << "}\n";
|
||||
os += "}\n";
|
||||
|
||||
if (cnt > 31) {
|
||||
std::ostringstream msg;
|
||||
@@ -246,9 +275,9 @@ void Compiled::eval_gpu(
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto lib = d.get_library(kernel_lib_, [&]() {
|
||||
std::ostringstream kernel;
|
||||
kernel << metal::utils() << metal::unary_ops() << metal::binary_ops()
|
||||
<< metal::ternary_ops();
|
||||
std::string kernel = metal::utils();
|
||||
concatenate(
|
||||
kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops());
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_lib_ + "_contiguous",
|
||||
@@ -261,7 +290,7 @@ void Compiled::eval_gpu(
|
||||
/* dynamic_dims = */ false);
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_lib_ + "_contiguous_big",
|
||||
kernel_lib_ + "_contiguous_large",
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
@@ -282,7 +311,21 @@ void Compiled::eval_gpu(
|
||||
/* ndim = */ i,
|
||||
/* dynamic_dims = */ false,
|
||||
/* use_big_index = */ false,
|
||||
/* work_per_thread = */ i > 3 ? WORK_PER_THREAD : 1);
|
||||
/* work_per_thread = */ i > 3 ? 2 : 1);
|
||||
if (i > 1) {
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_lib_ + "_strided_" + std::to_string(i) + "_large",
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
/* contiguous = */ false,
|
||||
/* ndim = */ i,
|
||||
/* dynamic_dims = */ false,
|
||||
/* use_big_index = */ true,
|
||||
/* work_per_thread = */ i > 3 ? 4 : 1);
|
||||
}
|
||||
}
|
||||
build_kernel(
|
||||
kernel,
|
||||
@@ -295,13 +338,25 @@ void Compiled::eval_gpu(
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ true,
|
||||
/* use_big_index = */ false,
|
||||
/* work_per_thread = */ WORK_PER_THREAD);
|
||||
return kernel.str();
|
||||
/* work_per_thread = */ 2);
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_lib_ + "_strided_dynamic_large",
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
/* contiguous = */ false,
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ true,
|
||||
/* use_big_index = */ true,
|
||||
/* work_per_thread = */ 4);
|
||||
return kernel;
|
||||
});
|
||||
|
||||
// Figure out which kernel we are using
|
||||
auto& output_shape = outputs[0].shape();
|
||||
bool contiguous = compiled_check_contiguity(inputs, output_shape);
|
||||
auto contiguous = compiled_check_contiguity(inputs, output_shape);
|
||||
|
||||
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
||||
// handle all broadcasting.
|
||||
@@ -349,13 +404,19 @@ void Compiled::eval_gpu(
|
||||
collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX);
|
||||
}
|
||||
|
||||
bool use_2d = false;
|
||||
bool large;
|
||||
if (contiguous) {
|
||||
size_t max_size = 0;
|
||||
for (auto& in : inputs) {
|
||||
max_size = std::max(max_size, in.data_size());
|
||||
}
|
||||
use_2d = (max_size > UINT32_MAX);
|
||||
large = (max_size > UINT32_MAX);
|
||||
} else {
|
||||
size_t max_size = 0;
|
||||
for (auto& o : outputs) {
|
||||
max_size = std::max(max_size, o.size());
|
||||
}
|
||||
large = (max_size > UINT32_MAX);
|
||||
}
|
||||
|
||||
// Get the kernel from the lib
|
||||
@@ -368,12 +429,13 @@ void Compiled::eval_gpu(
|
||||
} else {
|
||||
kernel_name += std::to_string(shape.size());
|
||||
}
|
||||
} else if (use_2d) {
|
||||
kernel_name += "_big";
|
||||
}
|
||||
if (large) {
|
||||
kernel_name += "_large";
|
||||
}
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Put the inputs in
|
||||
int cnt = 0;
|
||||
@@ -394,8 +456,7 @@ void Compiled::eval_gpu(
|
||||
}
|
||||
}
|
||||
if (!in_strides.empty()) {
|
||||
compute_encoder->setBytes(
|
||||
in_strides.data(), in_strides.size() * sizeof(size_t), cnt++);
|
||||
compute_encoder.set_vector_bytes(in_strides, cnt++);
|
||||
}
|
||||
|
||||
compiled_allocate_outputs(
|
||||
@@ -408,30 +469,30 @@ void Compiled::eval_gpu(
|
||||
|
||||
// Put the output shape and strides in
|
||||
if (!contiguous) {
|
||||
compute_encoder->setBytes(
|
||||
strides[0].data(), strides[0].size() * sizeof(size_t), cnt++);
|
||||
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), cnt++);
|
||||
compute_encoder.set_vector_bytes(strides[0], cnt++);
|
||||
compute_encoder.set_vector_bytes(shape, cnt++);
|
||||
}
|
||||
|
||||
// Put the number of dims in if it is dynamic
|
||||
if (dynamic) {
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), cnt++);
|
||||
compute_encoder.set_bytes(ndim, cnt++);
|
||||
}
|
||||
|
||||
// Launch the kernel
|
||||
if (contiguous) {
|
||||
size_t nthreads = outputs[0].data_size();
|
||||
MTL::Size grid_dims = use_2d
|
||||
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size group_dims(
|
||||
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
MTL::Size grid_dims = large
|
||||
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
} else {
|
||||
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||
size_t rest = outputs[0].size() / (dim0 * dim1);
|
||||
int work_per_thread = ndim > 3 ? WORK_PER_THREAD : 1;
|
||||
int work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
|
||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
int pow2;
|
||||
@@ -444,7 +505,7 @@ void Compiled::eval_gpu(
|
||||
}
|
||||
auto group_dims = get_block_dims(dim0, dim1, rest, pow2);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
|
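The compiled.cpp hunks above switch kernel-source generation from streaming into a std::ostringstream to appending fmt::format results to a plain std::string. A minimal, self-contained sketch of that pattern follows; it is not MLX code, and the kernel name, argument types, and buffer indices are made up for illustration.

    // Minimal sketch (not the MLX implementation) of the string-building
    // pattern used above: append fmt::format results to a std::string
    // instead of streaming into std::ostringstream. Requires the {fmt} library.
    #include <fmt/format.h>
    #include <iostream>
    #include <string>

    int main() {
      std::string os;
      std::string kernel_name = "add_contiguous";  // hypothetical name
      os += fmt::format(
          "[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);
      os += fmt::format(
          "  device const {0}* {1} [[buffer({2})]],\n", "float", "in_0", 0);
      os += fmt::format(
          "  device {0}* {1} [[buffer({2})]],\n", "float", "out_0", 1);
      os += "  uint3 pos [[thread_position_in_grid]]) {\n}\n";
      std::cout << os;
      return 0;
    }

The positional {0} placeholders let the same argument (here the kernel name) appear several times in one format string, which keeps the host_name attribute and the kernel identifier in sync.
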
@@ -44,23 +44,24 @@ void explicit_gemm_conv_ND_gpu(
|
||||
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(in_unfolded, 1);
|
||||
|
||||
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
|
||||
compute_encoder.set_bytes(conv_params, 2);
|
||||
|
||||
// Launch unfolding kernel
|
||||
int tgp_x = std::min(conv_params.C, 64);
|
||||
size_t tgp_x = std::min(conv_params.C, 64);
|
||||
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
|
||||
int tgp_y = 256 / tgp_x;
|
||||
size_t tgp_y = 256 / tgp_x;
|
||||
|
||||
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
|
||||
MTL::Size grid_dims = MTL::Size(
|
||||
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
|
||||
MTL::Size group_dims = MTL::Size(
|
||||
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
|
||||
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
|
||||
// Reshape weight
|
||||
std::vector<int> wt_reshape{implicit_K, implicit_N};
|
||||
@@ -122,23 +123,24 @@ void explicit_gemm_conv_group_ND_gpu(
|
||||
<< N;
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(in_unfolded, 1);
|
||||
|
||||
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
|
||||
compute_encoder.set_bytes(conv_params, 2);
|
||||
|
||||
// Launch unfolding kernel
|
||||
int tgp_x = std::min(conv_params.C, 64);
|
||||
size_t tgp_x = std::min(conv_params.C, 64);
|
||||
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
|
||||
int tgp_y = 256 / tgp_x;
|
||||
size_t tgp_y = 256 / tgp_x;
|
||||
|
||||
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
|
||||
MTL::Size grid_dims = MTL::Size(
|
||||
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
|
||||
MTL::Size group_dims = MTL::Size(
|
||||
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
|
||||
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
|
||||
// Transpose kernel weights so that we can slice them by contiguous chunks
|
||||
// of channel groups.
|
||||
@@ -237,7 +239,7 @@ void slow_conv_2D_gpu(
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
size_t n_pixels = conv_params.oS[0] * conv_params.oS[1];
|
||||
|
||||
@@ -252,8 +254,8 @@ void slow_conv_2D_gpu(
|
||||
compute_encoder.set_input_array(wt, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.set_bytes(conv_params, 3);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void implicit_gemm_conv_2D_gpu(
|
||||
@@ -352,7 +354,7 @@ void implicit_gemm_conv_2D_gpu(
|
||||
wn,
|
||||
n_channel_specialization,
|
||||
small_filter);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Deduce grid launch dimensions
|
||||
int tile = 1 << swizzle_log;
|
||||
@@ -368,11 +370,11 @@ void implicit_gemm_conv_2D_gpu(
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Encode params
|
||||
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
|
||||
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
|
||||
compute_encoder.set_bytes(conv_params, 3);
|
||||
compute_encoder.set_bytes(gemm_params, 4);
|
||||
|
||||
// Launch kernel
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void implicit_gemm_conv_2D_general_gpu(
|
||||
@@ -506,7 +508,7 @@ void implicit_gemm_conv_2D_general_gpu(
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel =
|
||||
get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Deduce grid launch dimensions
|
||||
int tile = 1 << swizzle_log;
|
||||
@@ -523,17 +525,15 @@ void implicit_gemm_conv_2D_general_gpu(
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Encode params
|
||||
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
|
||||
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
|
||||
compute_encoder->setBytes(&jump_params, sizeof(Conv2DGeneralJumpParams), 5);
|
||||
compute_encoder.set_bytes(conv_params, 3);
|
||||
compute_encoder.set_bytes(gemm_params, 4);
|
||||
compute_encoder.set_bytes(jump_params, 5);
|
||||
|
||||
compute_encoder->setBytes(
|
||||
base_h.data(), sizeof(Conv2DGeneralBaseInfo) * base_h.size(), 6);
|
||||
compute_encoder->setBytes(
|
||||
base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7);
|
||||
compute_encoder.set_vector_bytes(base_h, 6);
|
||||
compute_encoder.set_vector_bytes(base_w, 7);
|
||||
|
||||
// Launch kernel
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void winograd_conv_2D_gpu(
|
||||
@@ -622,18 +622,18 @@ void winograd_conv_2D_gpu(
|
||||
<< bc;
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(wt, 0);
|
||||
compute_encoder.set_output_array(filt_wg, 1);
|
||||
|
||||
compute_encoder->setBytes(&C_c, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&O_c, sizeof(int), 3);
|
||||
compute_encoder.set_bytes(C_c, 2);
|
||||
compute_encoder.set_bytes(O_c, 3);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(32, bo, 1);
|
||||
MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
// Do input transform
|
||||
@@ -650,18 +650,17 @@ void winograd_conv_2D_gpu(
|
||||
<< bc;
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in_padded, 0);
|
||||
compute_encoder.set_output_array(inp_wg, 1);
|
||||
|
||||
compute_encoder->setBytes(
|
||||
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
|
||||
compute_encoder.set_bytes(conv_params_updated, 2);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
// Do batched gemm
|
||||
@@ -698,18 +697,17 @@ void winograd_conv_2D_gpu(
|
||||
<< bc;
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(out_wg, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
compute_encoder->setBytes(
|
||||
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
|
||||
compute_encoder.set_bytes(conv_params_updated, 2);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -752,10 +750,6 @@ void conv_2D_gpu(
|
||||
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
|
||||
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
|
||||
|
||||
bool inp_large = (conv_params.in_strides[0] >= 1ul << 18);
|
||||
bool channels_large = (conv_params.C + conv_params.O) >= 512;
|
||||
bool channels_med = (conv_params.C + conv_params.O) >= 256;
|
||||
|
||||
if (groups > 1) {
|
||||
const int C_per_group = conv_params.C / groups;
|
||||
const int O_per_group = conv_params.O / groups;
|
||||
@@ -769,10 +763,13 @@ void conv_2D_gpu(
|
||||
}
|
||||
|
||||
// Direct to winograd conv
|
||||
bool inp_large =
|
||||
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
|
||||
bool channels_large = (conv_params.C + conv_params.O) >= 256;
|
||||
if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
|
||||
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
|
||||
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 &&
|
||||
(channels_large || (channels_med && inp_large))) {
|
||||
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
|
||||
channels_large) {
|
||||
return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
|
||||
}
|
||||
|
||||
|
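The conv_2D_gpu hunk above moves the input-size and channel-count heuristics next to the Winograd check and tightens them: Winograd is now taken only for 3x3, stride-one, non-flipped, non-dilated filters whose channel counts are multiples of 32, when the input is reasonably large and the combined channel count reaches 256. A standalone restatement of that predicate, using a stand-in struct that only mirrors the MLXConvParams fields referenced in the diff:

    // Sketch of the updated Winograd routing predicate from the hunk above.
    // ConvParamsSketch is not the real MLXConvParams header.
    #include <cstddef>

    struct ConvParamsSketch {
      int N;       // batch size
      int C, O;    // input / output channels
      int iS[2];   // input spatial size
      int wS[2];   // filter spatial size
    };

    bool use_winograd(const ConvParamsSketch& p,
                      bool flip,
                      bool is_stride_one,
                      bool is_kdil_one,
                      bool is_idil_one) {
      bool inp_large =
          (static_cast<size_t>(p.N) * p.iS[0] * p.iS[1]) >= (1ul << 12);
      bool channels_large = (p.C + p.O) >= 256;
      return !flip && is_stride_one && is_kdil_one && is_idil_one &&
          p.wS[0] == 3 && p.wS[1] == 3 && p.C % 32 == 0 && p.O % 32 == 0 &&
          inp_large && channels_large;
    }
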
@@ -74,44 +74,46 @@ void copy_gpu_inplace(
|
||||
};
|
||||
auto [shape, strides_in_, strides_out_] = maybe_collapse();
|
||||
int ndim = shape.size();
|
||||
|
||||
bool use_2d = out.data_size() > UINT32_MAX;
|
||||
bool large;
|
||||
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
||||
// Allow for negative strides
|
||||
large = out.data_size() > INT32_MAX;
|
||||
} else {
|
||||
large = out.data_size() > UINT32_MAX;
|
||||
}
|
||||
auto& d = metal::device(s.device);
|
||||
int work_per_thread = 1;
|
||||
std::string kernel_name;
|
||||
{
|
||||
std::ostringstream kname;
|
||||
switch (ctype) {
|
||||
case CopyType::Scalar:
|
||||
kname << (use_2d ? "s2" : "s");
|
||||
break;
|
||||
case CopyType::Vector:
|
||||
kname << (use_2d ? "v2" : "v");
|
||||
break;
|
||||
case CopyType::General:
|
||||
kname << "g";
|
||||
break;
|
||||
case CopyType::GeneralGeneral:
|
||||
kname << "gg";
|
||||
break;
|
||||
}
|
||||
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
||||
if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
|
||||
kname << shape.size();
|
||||
} else {
|
||||
work_per_thread = 4;
|
||||
kname << "n4";
|
||||
}
|
||||
}
|
||||
kname << "_copy";
|
||||
kname << type_to_name(in) << type_to_name(out);
|
||||
kernel_name = kname.str();
|
||||
switch (ctype) {
|
||||
case CopyType::Scalar:
|
||||
kernel_name = (large ? "s2" : "s");
|
||||
break;
|
||||
case CopyType::Vector:
|
||||
kernel_name = (large ? "v2" : "v");
|
||||
break;
|
||||
case CopyType::General:
|
||||
kernel_name = "g";
|
||||
break;
|
||||
case CopyType::GeneralGeneral:
|
||||
kernel_name = "gg";
|
||||
break;
|
||||
}
|
||||
|
||||
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
||||
if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
|
||||
kernel_name += std::to_string(shape.size());
|
||||
} else {
|
||||
work_per_thread = large ? 4 : 2;
|
||||
concatenate(kernel_name, "n", std::to_string(work_per_thread));
|
||||
}
|
||||
if (large) {
|
||||
kernel_name += "large";
|
||||
}
|
||||
}
|
||||
concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
|
||||
auto kernel = get_copy_kernel(d, kernel_name, in, out);
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
bool donate_in = in.data_shared_ptr() == nullptr;
|
||||
|
||||
inp_offset *= size_of(in.dtype());
|
||||
@@ -120,15 +122,16 @@ void copy_gpu_inplace(
|
||||
compute_encoder.set_input_array(donate_in ? out : in, 0, inp_offset);
|
||||
compute_encoder.set_output_array(out, 1, out_offset);
|
||||
|
||||
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
||||
std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
|
||||
std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
|
||||
if (ndim > 3) {
|
||||
set_vector_bytes(compute_encoder, shape, ndim, 2);
|
||||
compute_encoder.set_vector_bytes(shape, ndim, 2);
|
||||
}
|
||||
set_vector_bytes(compute_encoder, strides_in, ndim, 3);
|
||||
compute_encoder.set_vector_bytes(strides_in, ndim, 3);
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
set_vector_bytes(compute_encoder, strides_out, ndim, 4);
|
||||
compute_encoder.set_vector_bytes(strides_out, ndim, 4);
|
||||
}
|
||||
|
||||
int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
@@ -140,29 +143,27 @@ void copy_gpu_inplace(
|
||||
int rest = data_size / (dim0 * dim1);
|
||||
|
||||
if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 5);
|
||||
compute_encoder.set_bytes(ndim, 5);
|
||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||
}
|
||||
|
||||
// NB assuming thread_group_size is a power of 2 larger than 32 x 32
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size != 1024) {
|
||||
throw std::runtime_error("[Metal::copy] Must use 1024 sized block");
|
||||
}
|
||||
|
||||
auto group_dims = get_block_dims(dim0, dim1, rest);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
} else {
|
||||
size_t nthreads = out.data_size();
|
||||
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -194,26 +195,26 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
|
||||
return;
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
bool use_2d = out.data_size() > UINT32_MAX;
|
||||
bool large = out.data_size() > UINT32_MAX;
|
||||
auto& d = metal::device(s.device);
|
||||
std::string kernel_name = std::string(use_2d ? "s2" : "s") + "_copy" +
|
||||
std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +
|
||||
type_to_name(val) + type_to_name(out);
|
||||
auto kernel = get_copy_kernel(d, kernel_name, val, out);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(val, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
size_t nthreads = out.data_size();
|
||||
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
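The copy_gpu_inplace hunk above also changes how kernel names are assembled: a base name (s or v for contiguous copies, g or gg for general ones), then either the dimensionality or an n<work_per_thread> specialization, an optional large suffix when 64-bit indexing is required, and finally _copy plus the input and output type names. A small sketch of the general-general case; the MAX_COPY_SPECIALIZED_DIMS value and the type strings are assumptions, not values taken from the headers:

    // Sketch of the copy kernel-name scheme for the general-general path.
    #include <iostream>
    #include <string>

    std::string general_copy_name(int ndim, bool large,
                                  const std::string& in_t,
                                  const std::string& out_t) {
      constexpr int MAX_COPY_SPECIALIZED_DIMS = 3;  // assumed value
      std::string name = "gg";
      if (ndim <= MAX_COPY_SPECIALIZED_DIMS) {
        name += std::to_string(ndim);
      } else {
        name += "n" + std::to_string(large ? 4 : 2);
      }
      if (large) {
        name += "large";
      }
      return name + "_copy" + in_t + out_t;
    }

    int main() {
      // Prints gg2_copyfloat32float32
      std::cout << general_copy_name(2, false, "float32", "float32") << "\n";
      // Prints ggn4large_copyfloat16float16
      std::cout << general_copy_name(5, true, "float16", "float16") << "\n";
      return 0;
    }
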
@@ -43,7 +43,7 @@ void CustomKernel::eval_gpu(
d.get_library(lib_name, [this] { return metal::utils() + source_; });
auto kernel = d.get_kernel(name_, lib);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
int index = 0;
for (int i = 0; i < checked_inputs.size(); i++) {
const array& in = checked_inputs[i];
@@ -53,15 +53,15 @@ void CustomKernel::eval_gpu(
if (in.ndim() > 0) {
int ndim = in.ndim();
if (shape_info.shape) {
set_vector_bytes(compute_encoder, in.shape(), ndim, index);
compute_encoder.set_vector_bytes(in.shape(), ndim, index);
index++;
}
if (shape_info.strides) {
set_vector_bytes(compute_encoder, in.strides(), ndim, index);
compute_encoder.set_vector_bytes(in.strides(), ndim, index);
index++;
}
if (shape_info.ndim) {
compute_encoder->setBytes(&ndim, sizeof(int), index);
compute_encoder.set_bytes(ndim, index);
index++;
}
}
@@ -72,10 +72,11 @@ void CustomKernel::eval_gpu(
}

const auto [tx, ty, tz] = threadgroup_;
MTL::Size group_dims = MTL::Size(tx, ty, tz);
const auto [gx, gy, gz] = grid_;
MTL::Size group_dims =
MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
MTL::Size grid_dims = MTL::Size(gx, gy, gz);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);

d.add_temporaries(std::move(copies), s.index);
}

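In the CustomKernel hunk above, the user-supplied threadgroup is now clamped component-wise to the grid, so dispatch_threads never launches a threadgroup larger than the grid in any dimension. A plain C++ stand-in for the MTL::Size arithmetic:

    // Sketch of the component-wise threadgroup clamping added above.
    #include <algorithm>
    #include <array>

    std::array<unsigned, 3> clamp_group_to_grid(std::array<unsigned, 3> group,
                                                std::array<unsigned, 3> grid) {
      return {std::min(group[0], grid[0]),
              std::min(group[1], grid[1]),
              std::min(group[2], grid[2])};
    }
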
@@ -23,14 +23,18 @@ constexpr int MAX_BUFFERS_PER_QUEUE = 12;

constexpr const char* default_mtllib_path = METAL_PATH;

constexpr auto get_metal_version() {
#if (MLX_METAL_VERSION >= 320)
return MTL::LanguageVersion3_2;
#elif (MLX_METAL_VERSION >= 310)
return MTL::LanguageVersion3_1;
#else
return MTL::LanguageVersion3_0;
#endif
auto get_metal_version() {
auto get_metal_version_ = []() {
if (__builtin_available(macOS 15, iOS 18, tvOS 18, visionOS 2, *)) {
return MTL::LanguageVersion3_2;
} else if (__builtin_available(macOS 14, iOS 17, tvOS 17, visionOS 1, *)) {
return MTL::LanguageVersion3_1;
} else {
return MTL::LanguageVersion3_0;
}
};
static auto metal_version_ = get_metal_version_();
return metal_version_;
}

auto load_device() {
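The hunk above replaces the compile-time MLX_METAL_VERSION switch with a run-time __builtin_available check whose result is cached in a function-local static, so the probe runs once per process. A portable sketch of the same initialize-once idiom (the probe itself is a stand-in, not the Metal check):

    // Initialize-once pattern: a lambda runs on the first call and its
    // result is stored in a function-local static (thread-safe since C++11).
    #include <iostream>

    int expensive_runtime_probe() {
      std::cout << "probe runs once\n";
      return 42;  // stand-in for the detected language version
    }

    int cached_probe() {
      static int value = []() { return expensive_runtime_probe(); }();
      return value;
    }

    int main() {
      std::cout << cached_probe() << "\n";
      std::cout << cached_probe() << "\n";  // reuses the cached value
      return 0;
    }
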
@@ -136,13 +140,8 @@ void CommandEncoder::set_input_array(
|
||||
int64_t offset /* = 0 */) {
|
||||
all_inputs_.insert(a.buffer().ptr());
|
||||
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
|
||||
if (auto it = outputs_.find(r_buf); it != outputs_.end()) {
|
||||
// Insert a barrier
|
||||
enc_->memoryBarrier(&r_buf, 1);
|
||||
|
||||
// Remove the output
|
||||
outputs_.erase(it);
|
||||
}
|
||||
needs_barrier_ =
|
||||
needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto base_offset = a.data<char>() -
|
||||
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||
@@ -161,19 +160,32 @@ void CommandEncoder::set_output_array(
|
||||
if (concurrent_) {
|
||||
concurrent_outputs_.insert(buf);
|
||||
} else {
|
||||
outputs_.insert(buf);
|
||||
next_outputs_.insert(buf);
|
||||
}
|
||||
}
|
||||
|
||||
void CommandEncoder::dispatchThreadgroups(
|
||||
void CommandEncoder::maybeInsertBarrier() {
|
||||
if (needs_barrier_) {
|
||||
enc_->memoryBarrier(MTL::BarrierScopeBuffers);
|
||||
needs_barrier_ = false;
|
||||
prev_outputs_ = std::move(next_outputs_);
|
||||
} else {
|
||||
prev_outputs_.insert(next_outputs_.begin(), next_outputs_.end());
|
||||
}
|
||||
next_outputs_.clear();
|
||||
}
|
||||
|
||||
void CommandEncoder::dispatch_threadgroups(
|
||||
MTL::Size grid_dims,
|
||||
MTL::Size group_dims) {
|
||||
maybeInsertBarrier();
|
||||
enc_->dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void CommandEncoder::dispatchThreads(
|
||||
void CommandEncoder::dispatch_threads(
|
||||
MTL::Size grid_dims,
|
||||
MTL::Size group_dims) {
|
||||
maybeInsertBarrier();
|
||||
enc_->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
@@ -181,6 +193,7 @@ Device::Device() {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
device_ = load_device();
|
||||
library_map_ = {{"mlx", load_library(device_)}};
|
||||
arch_ = std::string(device_->architecture()->name()->utf8String());
|
||||
}
|
||||
|
||||
Device::~Device() {
|
||||
@@ -289,7 +302,7 @@ void Device::end_encoding(int index) {
|
||||
if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
|
||||
// If we've already waited on a fence, don't wait on it again.
|
||||
if (waiting_on.find(it->second) == waiting_on.end()) {
|
||||
enc->waitForFence(it->second->fence);
|
||||
enc.wait_for_fence(it->second->fence);
|
||||
waiting_on.insert(it->second);
|
||||
}
|
||||
}
|
||||
@@ -298,7 +311,7 @@ void Device::end_encoding(int index) {
|
||||
stream.outputs[out] = stream.fence;
|
||||
}
|
||||
}
|
||||
enc->updateFence(stream.fence->fence);
|
||||
enc.update_fence(stream.fence->fence);
|
||||
stream.buffer->addCompletedHandler(
|
||||
[&stream,
|
||||
waiting_on = std::move(waiting_on),
|
||||
|
@@ -49,7 +49,7 @@ struct CommandEncoder {
}
~ConcurrentContext() {
enc.concurrent_ = false;
enc.outputs_.insert(
enc.prev_outputs_.insert(
enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());
enc.concurrent_outputs_.clear();
}
@@ -58,14 +58,42 @@ struct CommandEncoder {
CommandEncoder& enc;
};

MTL::ComputeCommandEncoder* operator->() {
return enc_;
}

void set_input_array(const array& a, int idx, int64_t offset = 0);
void set_output_array(array& a, int idx, int64_t offset = 0);
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims);
void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims);
void maybeInsertBarrier();

void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) {
enc_->setComputePipelineState(kernel);
}

void wait_for_fence(MTL::Fence* fence) {
enc_->waitForFence(fence);
}

void update_fence(MTL::Fence* fence) {
enc_->updateFence(fence);
}

template <typename T>
void set_vector_bytes(const std::vector<T>& vec, size_t nelems, int idx) {
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
}
template <typename T>
void set_vector_bytes(const std::vector<T>& vec, int idx) {
return set_vector_bytes(vec, vec.size(), idx);
}

template <typename T>
void set_bytes(const T* v, int n, int idx) {
return enc_->setBytes(v, n * sizeof(T), idx);
}

template <typename T>
void set_bytes(const T& v, int idx) {
return enc_->setBytes(&v, sizeof(T), idx);
}

ConcurrentContext start_concurrent() {
return ConcurrentContext(*this);
@@ -84,8 +112,10 @@ struct CommandEncoder {

private:
MTL::ComputeCommandEncoder* enc_;
bool needs_barrier_{false};
bool concurrent_{false};
std::unordered_set<MTL::Resource*> outputs_;
std::unordered_set<MTL::Resource*> prev_outputs_;
std::unordered_set<MTL::Resource*> next_outputs_;
std::unordered_set<MTL::Resource*> concurrent_outputs_;
std::unordered_set<const void*> all_inputs_;
std::unordered_set<const void*> all_outputs_;
@@ -136,6 +166,10 @@ class Device {
return device_;
};

const std::string& get_architecture() {
return arch_;
}

void new_queue(int index);
MTL::CommandBuffer* get_command_buffer(int index);
int get_command_buffer_ops(int index);
@@ -228,6 +262,7 @@ class Device {
std::shared_mutex library_mtx_;
std::unordered_map<std::string, MTL::Library*> library_map_;
const MTL::ResidencySet* residency_set_{nullptr};
std::string arch_;
};

Device& device(mlx::core::Device);
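The CommandEncoder declarations above are what the call sites in the other files migrate to: set_bytes and set_vector_bytes compute the byte count from the element type, and the dispatch_* wrappers insert a memory barrier when a buffer written by an earlier dispatch is read again. The sketch below only illustrates the call-site effect of the typed setters; Recorder is a stand-in that prints byte counts rather than talking to Metal:

    // Stand-in showing how typed setters remove sizeof(...) from call sites.
    #include <cstddef>
    #include <iostream>
    #include <vector>

    struct Recorder {
      template <typename T>
      void set_bytes(const T& v, int idx) {
        record(&v, sizeof(T), idx);
      }
      template <typename T>
      void set_vector_bytes(const std::vector<T>& vec, int idx) {
        record(vec.data(), vec.size() * sizeof(T), idx);
      }
      void record(const void*, size_t nbytes, int idx) {
        std::cout << "buffer " << idx << ": " << nbytes << " bytes\n";
      }
    };

    int main() {
      Recorder enc;
      int ndim = 3;
      std::vector<size_t> strides = {12, 4, 1};
      // Previously: enc->setBytes(&ndim, sizeof(int), 5);
      enc.set_bytes(ndim, 5);
      // Previously: enc->setBytes(strides.data(), strides.size() * sizeof(size_t), 3);
      enc.set_vector_bytes(strides, 3);
      return 0;
    }
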
@@ -699,7 +699,7 @@ void fft_op(
|
||||
auto kernel =
|
||||
get_fft_kernel(d, base_name, hash_name, func_consts, template_def);
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(in_contiguous, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
@@ -711,9 +711,9 @@ void fft_op(
|
||||
|
||||
compute_encoder.set_input_array(w_q, 2); // w_q
|
||||
compute_encoder.set_input_array(w_k, 3); // w_k
|
||||
compute_encoder->setBytes(&n, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&plan.bluestein_n, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
|
||||
compute_encoder.set_bytes(n, 4);
|
||||
compute_encoder.set_bytes(plan.bluestein_n, 5);
|
||||
compute_encoder.set_bytes(total_batch_size, 6);
|
||||
} else if (plan.rader_n > 1) {
|
||||
auto [b_q, g_q, g_minus_q] = compute_raders_constants(plan.rader_n, s);
|
||||
copies.push_back(b_q);
|
||||
@@ -723,22 +723,22 @@ void fft_op(
|
||||
compute_encoder.set_input_array(b_q, 2);
|
||||
compute_encoder.set_input_array(g_q, 3);
|
||||
compute_encoder.set_input_array(g_minus_q, 4);
|
||||
compute_encoder->setBytes(&n, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&plan.rader_n, sizeof(int), 7);
|
||||
compute_encoder.set_bytes(n, 5);
|
||||
compute_encoder.set_bytes(total_batch_size, 6);
|
||||
compute_encoder.set_bytes(plan.rader_n, 7);
|
||||
} else if (four_step_params.required) {
|
||||
compute_encoder->setBytes(&four_step_params.n1, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&four_step_params.n2, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 4);
|
||||
compute_encoder.set_bytes(four_step_params.n1, 2);
|
||||
compute_encoder.set_bytes(four_step_params.n2, 3);
|
||||
compute_encoder.set_bytes(total_batch_size, 4);
|
||||
} else {
|
||||
compute_encoder->setBytes(&n, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 3);
|
||||
compute_encoder.set_bytes(n, 2);
|
||||
compute_encoder.set_bytes(total_batch_size, 3);
|
||||
}
|
||||
|
||||
auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft);
|
||||
auto grid_dims =
|
||||
MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
|
@@ -137,14 +137,14 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = d.get_kernel(kernel_name, lib);
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());

compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&scale, sizeof(float), 2);
compute_encoder.set_bytes(scale, 2);

MTL::Size group_dims = MTL::Size(1, threads_per, 1);
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
};

if (m > 1) {
@@ -53,27 +53,31 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
int idx_ndim = nidx ? inputs[1].ndim() : 0;
|
||||
size_t ndim = src.ndim();
|
||||
|
||||
std::string lib_name;
|
||||
std::string kernel_name;
|
||||
bool large_index = nidx && inputs[1].size() > UINT32_MAX;
|
||||
bool large_src = src.size() > UINT32_MAX;
|
||||
bool large_out = out.size() > UINT32_MAX;
|
||||
bool large = large_index || large_src || large_out;
|
||||
|
||||
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
|
||||
{
|
||||
std::ostringstream kname;
|
||||
kname << "gather" << type_to_name(out) << idx_type_name << "_" << nidx
|
||||
<< "_" << idx_ndim;
|
||||
lib_name = kname.str();
|
||||
kernel_name = lib_name;
|
||||
}
|
||||
std::string kernel_name = fmt::format(
|
||||
"gather{0}{1}_{2}_{3}_{4}",
|
||||
type_to_name(out),
|
||||
idx_type_name,
|
||||
nidx,
|
||||
idx_ndim,
|
||||
large ? "size_t" : "uint");
|
||||
std::string lib_name = kernel_name;
|
||||
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::gather();
|
||||
std::string kernel_source = metal::utils();
|
||||
kernel_source += metal::gather();
|
||||
std::string out_type_str = get_type_string(out.dtype());
|
||||
std::string idx_type_str =
|
||||
nidx ? get_type_string(inputs[1].dtype()) : "bool";
|
||||
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
|
||||
|
||||
// Index dimension specializations
|
||||
kernel_source << fmt::format(
|
||||
kernel_source += fmt::format(
|
||||
gather_kernels,
|
||||
type_to_name(out) + idx_type_name,
|
||||
out_type_str,
|
||||
@@ -81,13 +85,14 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nidx,
|
||||
idx_args,
|
||||
idx_arr,
|
||||
idx_ndim);
|
||||
return kernel_source.str();
|
||||
idx_ndim,
|
||||
large ? "size_t" : "uint");
|
||||
return kernel_source;
|
||||
});
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
size_t slice_size = 1;
|
||||
for (auto s : slice_sizes_) {
|
||||
@@ -113,17 +118,17 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Collect all idx shapes and strides into one place
|
||||
std::vector<int> idx_shapes;
|
||||
std::vector<size_t> idx_strides;
|
||||
|
||||
std::vector<char> idx_contigs;
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
idx_shapes.insert(
|
||||
idx_shapes.end(),
|
||||
inputs[i + 1].shape().begin(),
|
||||
inputs[i + 1].shape().end());
|
||||
|
||||
idx_strides.insert(
|
||||
idx_strides.end(),
|
||||
inputs[i + 1].strides().begin(),
|
||||
inputs[i + 1].strides().end());
|
||||
idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);
|
||||
}
|
||||
|
||||
// Set all the buffers
|
||||
@@ -131,21 +136,20 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
// Set source info
|
||||
compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2);
|
||||
compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(&ndim, sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 5);
|
||||
compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 6);
|
||||
compute_encoder.set_vector_bytes(src.shape(), 2);
|
||||
compute_encoder.set_vector_bytes(src.strides(), 3);
|
||||
compute_encoder.set_bytes(ndim, 4);
|
||||
compute_encoder.set_vector_bytes(slice_sizes_, 5);
|
||||
compute_encoder.set_vector_bytes(axes_, 6);
|
||||
|
||||
// Set index info
|
||||
//
|
||||
// We don't need to check for empty idx_shapes because gather has a
|
||||
// idx_ndim == 0 specialization
|
||||
compute_encoder->setBytes(
|
||||
idx_shapes.data(), idx_shapes.size() * sizeof(int), 7);
|
||||
compute_encoder->setBytes(
|
||||
idx_strides.data(), idx_strides.size() * sizeof(size_t), 8);
|
||||
compute_encoder->setBytes(&idx_ndim, sizeof(int), 9);
|
||||
compute_encoder.set_vector_bytes(idx_shapes, 7);
|
||||
compute_encoder.set_vector_bytes(idx_strides, 8);
|
||||
compute_encoder.set_vector_bytes(idx_contigs, 9);
|
||||
compute_encoder.set_bytes(idx_ndim, 10);
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
@@ -153,7 +157,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
// Launch grid
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -172,12 +176,20 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
// Copy src into out
|
||||
auto copy_type =
|
||||
inputs[0].data_size() == 1 ? CopyType::Scalar : CopyType::General;
|
||||
CopyType copy_type;
|
||||
if (inputs[0].data_size() == 1) {
|
||||
copy_type = CopyType::Scalar;
|
||||
} else if (inputs[0].flags().row_contiguous) {
|
||||
copy_type = CopyType::Vector;
|
||||
} else {
|
||||
copy_type = CopyType::General;
|
||||
}
|
||||
copy_gpu(inputs[0], out, copy_type);
|
||||
|
||||
auto& upd = inputs.back();
|
||||
|
||||
// Empty update
|
||||
if (inputs.back().size() == 0) {
|
||||
if (upd.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -186,23 +198,22 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
int idx_ndim = nidx ? inputs[1].ndim() : 0;
|
||||
bool index_nd1_specialization = (idx_ndim == 1);
|
||||
size_t idx_size = nidx ? inputs[1].size() : 1;
|
||||
|
||||
// Bail from fast path (1d index specialization) if scatter dims aren't
|
||||
// the outermost dims and contiguous since update access won't be raster
|
||||
// order.
|
||||
for (auto i = 0; i < axes_.size() && index_nd1_specialization; i++) {
|
||||
index_nd1_specialization &= (axes_[i] == i);
|
||||
auto idx_to_out = idx_size / out.size();
|
||||
int nwork;
|
||||
if (idx_ndim <= 1 || idx_to_out < 1) {
|
||||
nwork = 1;
|
||||
} else if (idx_to_out <= 4) {
|
||||
nwork = 4;
|
||||
} else if (idx_to_out < 16) {
|
||||
nwork = 8;
|
||||
} else if (idx_to_out < 32) {
|
||||
nwork = 16;
|
||||
} else {
|
||||
nwork = 32;
|
||||
}
|
||||
|
||||
// Bail from fast path (1d index specialization) if any of the dims are
|
||||
// broadcasted, since we can't rely on linear indexing in that case.
|
||||
for (int i = 1; i < inputs.size() && index_nd1_specialization; i++) {
|
||||
index_nd1_specialization &= inputs[i].flags().row_contiguous;
|
||||
}
|
||||
|
||||
std::string lib_name;
|
||||
std::string kernel_name;
|
||||
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
|
||||
std::string op_name;
|
||||
switch (reduce_type_) {
|
||||
@@ -222,23 +233,25 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
op_name = "min";
|
||||
break;
|
||||
}
|
||||
|
||||
{
|
||||
std::ostringstream kname;
|
||||
if (index_nd1_specialization) {
|
||||
kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
|
||||
} else {
|
||||
kname << "scatter" << type_to_name(out) << idx_type_name;
|
||||
}
|
||||
kname << "_" << op_name << "_" << nidx;
|
||||
lib_name = kname.str();
|
||||
kernel_name = kname.str();
|
||||
}
|
||||
auto upd_contig = upd.flags().row_contiguous;
|
||||
bool large_out = out.size() > UINT32_MAX;
|
||||
bool large_idx = nidx && (inputs[1].size() > UINT32_MAX);
|
||||
bool large_upd = upd.size() > UINT32_MAX;
|
||||
bool large = large_out || large_idx || large_upd;
|
||||
std::string kernel_name = fmt::format(
|
||||
"scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}",
|
||||
type_to_name(out),
|
||||
idx_type_name,
|
||||
op_name,
|
||||
nidx,
|
||||
upd_contig ? "updc_true" : "updc_false",
|
||||
nwork,
|
||||
large ? "size_t" : "uint");
|
||||
std::string lib_name = kernel_name;
|
||||
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::reduce_utils()
|
||||
<< metal::scatter();
|
||||
std::string kernel_source = metal::utils();
|
||||
concatenate(kernel_source, metal::reduce_utils(), metal::scatter());
|
||||
|
||||
std::string out_type_str = get_type_string(out.dtype());
|
||||
std::string idx_type_str =
|
||||
@@ -266,7 +279,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
|
||||
|
||||
kernel_source << fmt::format(
|
||||
kernel_source += fmt::format(
|
||||
scatter_kernels,
|
||||
type_to_name(out) + idx_type_name + "_" + op_name,
|
||||
out_type_str,
|
||||
@@ -274,126 +287,105 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
op_type,
|
||||
nidx,
|
||||
idx_args,
|
||||
idx_arr);
|
||||
return kernel_source.str();
|
||||
idx_arr,
|
||||
upd_contig,
|
||||
nwork,
|
||||
large ? "size_t" : "uint");
|
||||
return kernel_source;
|
||||
});
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
|
||||
auto& upd = inputs.back();
|
||||
size_t nthreads = upd.size();
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Set all the buffers
|
||||
compute_encoder.set_input_array(upd, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Set update info
|
||||
uint upd_ndim = upd.ndim();
|
||||
size_t upd_ndim = upd.ndim();
|
||||
size_t upd_size = 1;
|
||||
for (int i = idx_ndim; i < upd.ndim(); ++i) {
|
||||
upd_size *= upd.shape(i);
|
||||
}
|
||||
if (index_nd1_specialization) {
|
||||
compute_encoder->setBytes(
|
||||
out.shape().data(), out.shape().size() * sizeof(int), 3);
|
||||
compute_encoder->setBytes(
|
||||
out.strides().data(), out.strides().size() * sizeof(size_t), 4);
|
||||
|
||||
size_t out_ndim = out.ndim();
|
||||
compute_encoder->setBytes(&out_ndim, sizeof(out_ndim), 5);
|
||||
if (upd_ndim <= 1) {
|
||||
// Placeholder so Metal doesn't compalain
|
||||
int shape_ = 0;
|
||||
compute_encoder->setBytes(&shape_, sizeof(int), 6);
|
||||
} else {
|
||||
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 6);
|
||||
}
|
||||
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 7);
|
||||
compute_encoder->setBytes(&upd_size, sizeof(size_t), 8);
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
|
||||
}
|
||||
|
||||
// Launch grid
|
||||
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
|
||||
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
} else {
|
||||
// Collect all idx shapes and strides into one place
|
||||
std::vector<int> idx_shapes;
|
||||
std::vector<size_t> idx_strides;
|
||||
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
idx_shapes.insert(
|
||||
idx_shapes.end(),
|
||||
inputs[i + 1].shape().begin(),
|
||||
inputs[i + 1].shape().end());
|
||||
|
||||
idx_strides.insert(
|
||||
idx_strides.end(),
|
||||
inputs[i + 1].strides().begin(),
|
||||
inputs[i + 1].strides().end());
|
||||
}
|
||||
|
||||
if (upd_ndim == 0) {
|
||||
// Need placeholders so Metal doesn't compalain
|
||||
int shape_ = 0;
|
||||
size_t stride_ = 0;
|
||||
compute_encoder->setBytes(&shape_, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
|
||||
} else {
|
||||
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
|
||||
compute_encoder->setBytes(
|
||||
upd.strides().data(), upd_ndim * sizeof(size_t), 4);
|
||||
}
|
||||
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
|
||||
|
||||
// Set output info
|
||||
size_t out_ndim = out.ndim();
|
||||
if (out_ndim == 0) {
|
||||
// Need placeholders so Metal doesn't compalain
|
||||
int shape_ = 0;
|
||||
size_t stride_ = 0;
|
||||
compute_encoder->setBytes(&shape_, sizeof(int), 7);
|
||||
compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
|
||||
} else {
|
||||
compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
|
||||
compute_encoder->setBytes(
|
||||
out.strides().data(), out_ndim * sizeof(size_t), 8);
|
||||
}
|
||||
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
|
||||
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
|
||||
|
||||
// Set index info
|
||||
if (idx_ndim == 0) {
|
||||
// Add a 0 in idx_shapes and strides to avoid the missing buffer binding
|
||||
// error in the metal API.
|
||||
idx_shapes.push_back(0);
|
||||
idx_strides.push_back(0);
|
||||
}
|
||||
compute_encoder->setBytes(
|
||||
idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
|
||||
compute_encoder->setBytes(
|
||||
idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
|
||||
compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
|
||||
}
|
||||
|
||||
// Launch grid
|
||||
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
|
||||
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
// Collect all idx shapes and strides into one place
|
||||
std::vector<int> idx_shapes;
|
||||
std::vector<size_t> idx_strides;
|
||||
// To access .data() use char instead of bool
|
||||
// bool is 1 byte in Metal so this is safe
|
||||
std::vector<char> idx_contigs;
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
idx_shapes.insert(
|
||||
idx_shapes.end(),
|
||||
inputs[i + 1].shape().begin(),
|
||||
inputs[i + 1].shape().end());
|
||||
idx_strides.insert(
|
||||
idx_strides.end(),
|
||||
inputs[i + 1].strides().begin(),
|
||||
inputs[i + 1].strides().end());
|
||||
idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);
|
||||
}
|
||||
|
||||
if (upd_ndim == 0) {
// Need placeholders so Metal doesn't complain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder.set_bytes(shape_, 3);
compute_encoder.set_bytes(stride_, 4);
} else {
compute_encoder.set_vector_bytes(upd.shape(), 3);
compute_encoder.set_vector_bytes(upd.strides(), 4);
}
|
||||
compute_encoder.set_bytes(upd_ndim, 5);
|
||||
compute_encoder.set_bytes(upd_size, 6);
|
||||
|
||||
// Set output info
|
||||
size_t out_ndim = out.ndim();
|
||||
if (out_ndim == 0) {
// Need placeholders so Metal doesn't complain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder.set_bytes(shape_, 7);
compute_encoder.set_bytes(stride_, 8);
} else {
compute_encoder.set_vector_bytes(out.shape(), 7);
compute_encoder.set_vector_bytes(out.strides(), 8);
}
|
||||
compute_encoder.set_bytes(out_ndim, 9);
|
||||
compute_encoder.set_vector_bytes(axes_, 10);
|
||||
|
||||
// Set index info
|
||||
if (idx_ndim == 0) {
|
||||
// Add a 0 in idx_shapes and strides to avoid the missing buffer binding
|
||||
// error in the metal API.
|
||||
idx_shapes.push_back(0);
|
||||
idx_strides.push_back(0);
|
||||
idx_contigs.push_back(false);
|
||||
}
|
||||
compute_encoder.set_vector_bytes(idx_shapes, 11);
|
||||
compute_encoder.set_vector_bytes(idx_strides, 12);
|
||||
compute_encoder.set_vector_bytes(idx_contigs, 13);
|
||||
compute_encoder.set_bytes(idx_ndim, 14);
|
||||
compute_encoder.set_bytes(idx_size, 15);
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
|
||||
}
|
||||
|
||||
// Launch grid
|
||||
auto grid_y = (nthreads / upd_size);
|
||||
grid_y = (grid_y + nwork - 1) / nwork;
|
||||
MTL::Size grid_dims = MTL::Size(upd_size, grid_y, 1);
|
||||
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size != 1024) {
|
||||
throw std::runtime_error("[Scatter::eval_gpu] Invalid number of threads");
|
||||
}
|
||||
MTL::Size group_dims = get_block_dims(upd_size, grid_y, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
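A minimal host-side sketch of the launch-size arithmetic used above, assuming only what the diff shows: each y-thread handles `nwork` updates, so the y extent is a ceiling division. The helper and type names below are illustrative, not part of the diff.

#include <cstddef>

struct GridExtent {
  size_t x;
  size_t y;
};

// Each y-thread processes `nwork` updates, so the y extent is
// ceil((nthreads / upd_size) / nwork); x spans one update's elements.
GridExtent scatter_grid(size_t nthreads, size_t upd_size, size_t nwork) {
  size_t rows = nthreads / upd_size;
  size_t grid_y = (rows + nwork - 1) / nwork;
  return {upd_size, grid_y};
}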
@@ -1,7 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view gather_kernels = R"(
|
||||
[[kernel]] void gather{0}_{3}_{6}(
|
||||
[[kernel]] void gather{0}_{3}_{6}_{7}(
|
||||
const device {1}* src [[buffer(0)]],
|
||||
device {1}* out [[buffer(1)]],
|
||||
const constant int* src_shape [[buffer(2)]],
|
||||
@@ -11,14 +11,15 @@ constexpr std::string_view gather_kernels = R"(
|
||||
const constant int* axes [[buffer(6)]],
|
||||
const constant int* idx_shapes [[buffer(7)]],
|
||||
const constant size_t* idx_strides [[buffer(8)]],
|
||||
const constant int& idx_ndim [[buffer(9)]],
|
||||
const constant bool* idx_contigs [[buffer(9)]],
|
||||
const constant int& idx_ndim [[buffer(10)]],
|
||||
{4}
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {{
|
||||
Indices<{2}, {3}> idxs{{
|
||||
{{ {5} }}, idx_shapes, idx_strides, idx_ndim}};
|
||||
{{ {5} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
|
||||
|
||||
return gather_impl<{1}, {2}, {3}, {6}>(
|
||||
return gather_impl<{1}, {2}, {3}, {6}, {7}>(
|
||||
src,
|
||||
out,
|
||||
src_shape,
|
||||
@@ -33,32 +34,7 @@ constexpr std::string_view gather_kernels = R"(
|
||||
)";
|
||||
|
||||
constexpr std::string_view scatter_kernels = R"(
|
||||
[[kernel]] void scatter_1d_index{0}_{4}(
|
||||
const device {1}* updates [[buffer(1)]],
|
||||
device mlx_atomic<{1}>* out [[buffer(2)]],
|
||||
const constant int* out_shape [[buffer(3)]],
|
||||
const constant size_t* out_strides [[buffer(4)]],
|
||||
const constant size_t& out_ndim [[buffer(5)]],
|
||||
const constant int* upd_shape [[buffer(6)]],
|
||||
const constant size_t& upd_ndim [[buffer(7)]],
|
||||
const constant size_t& upd_size [[buffer(8)]],
|
||||
{5}
|
||||
uint2 gid [[thread_position_in_grid]]) {{
|
||||
const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
|
||||
return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
|
||||
updates,
|
||||
out,
|
||||
out_shape,
|
||||
out_strides,
|
||||
out_ndim,
|
||||
upd_shape,
|
||||
upd_ndim,
|
||||
upd_size,
|
||||
idx_buffers,
|
||||
gid);
|
||||
}}
|
||||
|
||||
[[kernel]] void scatter{0}_{4}(
|
||||
[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}_{9}(
|
||||
const device {1}* updates [[buffer(1)]],
|
||||
device mlx_atomic<{1}>* out [[buffer(2)]],
|
||||
const constant int* upd_shape [[buffer(3)]],
|
||||
@@ -71,12 +47,14 @@ constexpr std::string_view scatter_kernels = R"(
|
||||
const constant int* axes [[buffer(10)]],
|
||||
const constant int* idx_shapes [[buffer(11)]],
|
||||
const constant size_t* idx_strides [[buffer(12)]],
|
||||
const constant int& idx_ndim [[buffer(13)]],
|
||||
const constant bool* idx_contigs [[buffer(13)]],
|
||||
const constant int& idx_ndim [[buffer(14)]],
|
||||
const constant size_t& idx_size [[buffer(15)]],
|
||||
{5}
|
||||
uint2 gid [[thread_position_in_grid]]) {{
|
||||
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_ndim}};
|
||||
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
|
||||
|
||||
return scatter_impl<{1}, {2}, {3}, {4}>(
|
||||
return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}, {9}>(
|
||||
updates,
|
||||
out,
|
||||
upd_shape,
|
||||
@@ -87,6 +65,7 @@ constexpr std::string_view scatter_kernels = R"(
|
||||
out_strides,
|
||||
out_ndim,
|
||||
axes,
|
||||
idx_size,
|
||||
idxs,
|
||||
gid);
|
||||
}}
|
||||
|
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/metal/jit/arange.h"
|
||||
#include "mlx/backend/metal/jit/gemv_masked.h"
|
||||
@@ -46,25 +45,27 @@ MTL::ComputePipelineState* get_unary_kernel(
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
auto in_t = get_type_string(in_type);
|
||||
auto out_t = get_type_string(out_type);
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::unary_ops() << metal::unary();
|
||||
kernel_source << get_template_definition(
|
||||
"v_" + lib_name, "unary_v", in_t, out_t, op);
|
||||
kernel_source << get_template_definition(
|
||||
"v2_" + lib_name, "unary_v2", in_t, out_t, op);
|
||||
kernel_source << get_template_definition(
|
||||
"gn4_" + lib_name, "unary_g", in_t, out_t, op, 4);
|
||||
return kernel_source.str();
|
||||
std::string kernel_source = metal::utils();
|
||||
concatenate(kernel_source, metal::unary_ops(), metal::unary());
|
||||
kernel_source +=
|
||||
get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op);
|
||||
kernel_source +=
|
||||
get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op);
|
||||
kernel_source += get_template_definition(
|
||||
"gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "uint");
|
||||
kernel_source += get_template_definition(
|
||||
"gn4large_" + lib_name, "unary_g", in_t, out_t, op, 4);
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
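The refactor above swaps std::ostringstream for plain std::string appends through a `concatenate` helper. A minimal sketch of such a helper, assuming a simple variadic append (not necessarily MLX's actual implementation):

#include <string>

// Append every piece to dst in order; a left fold over operator+= avoids
// any intermediate stream or temporary string.
// e.g. concatenate(kernel_source, metal::binary_ops(), metal::binary());
template <typename... Parts>
void concatenate(std::string& dst, Parts&&... parts) {
  (dst += ... += parts);
}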
void add_binary_kernels(
|
||||
void append_binary_kernels(
|
||||
const std::string lib_name,
|
||||
Dtype in_type,
|
||||
Dtype out_type,
|
||||
const std::string op,
|
||||
std::ostringstream& kernel_source) {
|
||||
std::string& kernel_source) {
|
||||
const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{
|
||||
{"ss", "binary_ss"},
|
||||
{"vs", "binary_vs"},
|
||||
@@ -74,26 +75,24 @@ void add_binary_kernels(
|
||||
{"sv2", "binary_sv2"},
|
||||
{"vv2", "binary_vv2"},
|
||||
{"g1", "binary_g_nd1"},
|
||||
{"g2", "binary_g_nd2"},
|
||||
{"g3", "binary_g_nd3"},
|
||||
{"g2large", "binary_g_nd2"},
|
||||
{"g3large", "binary_g_nd3"},
|
||||
}};
|
||||
auto in_t = get_type_string(in_type);
|
||||
auto out_t = get_type_string(out_type);
|
||||
|
||||
for (auto& [name, func] : kernel_types) {
|
||||
std::string template_def;
|
||||
template_def = get_template_definition(
|
||||
name + "_" + lib_name,
|
||||
func,
|
||||
get_type_string(in_type),
|
||||
get_type_string(out_type),
|
||||
op);
|
||||
kernel_source << template_def;
|
||||
kernel_source +=
|
||||
get_template_definition(name + "_" + lib_name, func, in_t, out_t, op);
|
||||
}
|
||||
kernel_source << get_template_definition(
|
||||
"gn4_" + lib_name,
|
||||
"binary_g",
|
||||
get_type_string(in_type),
|
||||
get_type_string(out_type),
|
||||
op,
|
||||
4);
|
||||
kernel_source += get_template_definition(
|
||||
"g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "uint");
|
||||
kernel_source += get_template_definition(
|
||||
"g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "uint");
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "uint");
|
||||
kernel_source += get_template_definition(
|
||||
"gn4large_" + lib_name, "binary_g", in_t, out_t, op, 4);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_binary_kernel(
|
||||
@@ -104,10 +103,11 @@ MTL::ComputePipelineState* get_binary_kernel(
|
||||
const std::string op) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::binary_ops() << metal::binary();
|
||||
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
|
||||
return kernel_source.str();
|
||||
std::string kernel_source;
|
||||
kernel_source = metal::utils();
|
||||
concatenate(kernel_source, metal::binary_ops(), metal::binary());
|
||||
append_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
@@ -120,11 +120,10 @@ MTL::ComputePipelineState* get_binary_two_kernel(
|
||||
const std::string op) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::binary_ops()
|
||||
<< metal::binary_two();
|
||||
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
|
||||
return kernel_source.str();
|
||||
std::string kernel_source = metal::utils();
|
||||
concatenate(kernel_source, metal::binary_ops(), metal::binary_two());
|
||||
append_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
@@ -136,24 +135,29 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
||||
const std::string op) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
auto t_str = get_type_string(type);
|
||||
std::string kernel_source = metal::utils();
|
||||
concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
|
||||
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
|
||||
{"v", "ternary_v"},
|
||||
{"v2", "ternary_v2"},
|
||||
{"g1", "ternary_g_nd1"},
|
||||
{"g2", "ternary_g_nd2"},
|
||||
{"g3", "ternary_g_nd3"},
|
||||
{"g2large", "ternary_g_nd2"},
|
||||
{"g3large", "ternary_g_nd3"},
|
||||
}};
|
||||
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary();
|
||||
for (auto& [name, func] : kernel_types) {
|
||||
std::string template_def;
|
||||
template_def = get_template_definition(
|
||||
name + "_" + lib_name, func, get_type_string(type), op);
|
||||
kernel_source << template_def;
|
||||
kernel_source +=
|
||||
get_template_definition(name + "_" + lib_name, func, t_str, op);
|
||||
}
|
||||
kernel_source << get_template_definition(
|
||||
"gn4_" + lib_name, "ternary_g", get_type_string(type), op, 4);
|
||||
return kernel_source.str();
|
||||
kernel_source += get_template_definition(
|
||||
"g2_" + lib_name, "ternary_g_nd2", t_str, op, "uint");
|
||||
kernel_source += get_template_definition(
|
||||
"g3_" + lib_name, "ternary_g_nd3", t_str, op, "uint");
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "ternary_g", t_str, op, 2, "uint");
|
||||
kernel_source += get_template_definition(
|
||||
"gn4large_" + lib_name, "ternary_g", t_str, op, 4);
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
@@ -165,31 +169,43 @@ MTL::ComputePipelineState* get_copy_kernel(
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
std::string kernel_source = metal::utils();
|
||||
kernel_source += metal::copy();
|
||||
auto in_type = get_type_string(in.dtype());
|
||||
auto out_type = get_type_string(out.dtype());
|
||||
kernel_source << metal::utils() << metal::copy()
|
||||
<< get_template_definition(
|
||||
"s_" + lib_name, "copy_s", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"v_" + lib_name, "copy_v", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"g1_" + lib_name, "copy_g_nd1", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"g2_" + lib_name, "copy_g_nd2", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"g3_" + lib_name, "copy_g_nd3", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"gn4_" + lib_name, "copy_g", in_type, out_type, 4)
|
||||
<< get_template_definition(
|
||||
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
|
||||
return kernel_source.str();
|
||||
kernel_source +=
|
||||
get_template_definition("s_" + lib_name, "copy_s", in_type, out_type);
|
||||
kernel_source +=
|
||||
get_template_definition("v_" + lib_name, "copy_v", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"g1_" + lib_name, "copy_g_nd1", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"g2_" + lib_name, "copy_g_nd2", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"g3_" + lib_name, "copy_g_nd3", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "copy_g", in_type, out_type, 2, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"ggn2_" + lib_name, "copy_gg", in_type, out_type, 2, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"g2large_" + lib_name, "copy_g_nd2", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"g3large_" + lib_name, "copy_g_nd3", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"gn4large_" + lib_name, "copy_g", in_type, out_type, 4);
|
||||
kernel_source += get_template_definition(
|
||||
"gg2large_" + lib_name, "copy_gg_nd2", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"gg3large_" + lib_name, "copy_gg_nd3", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"ggn4large_" + lib_name, "copy_gg", in_type, out_type, 4);
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
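For context on the get_template_definition calls used throughout these builders: they emit an explicit-instantiation line for a templated Metal kernel under a fixed host name. A purely illustrative version (the real helper's exact output format may differ) could look like:

#include <sstream>
#include <string>

// Build "template [[host_name("name")]] [[kernel]] decltype(func<args>) func<args>;"
// from a host name, a function name, and its template arguments.
template <typename... Args>
std::string template_definition(const std::string& host_name,
                                const std::string& func,
                                Args&&... args) {
  std::ostringstream s;
  s << func << "<";
  bool first = true;
  ((s << (first ? "" : ", ") << args, first = false), ...);
  s << ">";
  std::string sig = s.str();
  return "template [[host_name(\"" + host_name + "\")]] [[kernel]] decltype(" +
      sig + ") " + sig + ";\n";
}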
@@ -319,17 +335,19 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
|
||||
MTL::ComputePipelineState* get_reduce_init_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out) {
|
||||
const std::string& func_name,
|
||||
const std::string& op_name,
|
||||
const Dtype& out_type) {
|
||||
auto lib = d.get_library(kernel_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
std::string op_type = op_name(out);
|
||||
op_type[0] = std::toupper(op_name(out)[0]);
|
||||
auto out_type = get_type_string(out.dtype());
|
||||
std::string op = op_type + "<" + out_type + ">";
|
||||
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
|
||||
kernel_source << get_template_definition(
|
||||
kernel_name, "init_reduce", out_type, op);
|
||||
return kernel_source.str();
|
||||
std::string op_type = op_name;
|
||||
op_type[0] = std::toupper(op_name[0]);
|
||||
auto out_t = get_type_string(out_type);
|
||||
std::string op = op_type + "<" + out_t + ">";
|
||||
std::string kernel_source = metal::utils();
|
||||
kernel_source += metal::reduce_utils();
|
||||
kernel_source += metal::reduce();
|
||||
kernel_source += get_template_definition(kernel_name, func_name, out_t, op);
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
@@ -339,30 +357,31 @@ MTL::ComputePipelineState* get_reduce_kernel(
|
||||
const std::string& kernel_name,
|
||||
const std::string& func_name,
|
||||
const std::string& op_name,
|
||||
const array& in,
|
||||
const array& out,
|
||||
const Dtype& in_type,
|
||||
const Dtype& out_type,
|
||||
const std::string& idx_t,
|
||||
int ndim /* = -1 */,
|
||||
int bm /* = -1 */,
|
||||
int bn /* = -1 */) {
|
||||
auto lib = d.get_library(kernel_name, [&]() {
|
||||
std::string op_type = op_name;
|
||||
op_type[0] = std::toupper(op_name[0]);
|
||||
std::ostringstream kernel_source;
|
||||
auto in_type = get_type_string(in.dtype());
|
||||
auto out_type = get_type_string(out.dtype());
|
||||
std::string op = op_type + "<" + out_type + ">";
|
||||
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
|
||||
auto in_t = get_type_string(in_type);
|
||||
auto out_t = get_type_string(out_type);
|
||||
std::string op = op_type + "<" + out_t + ">";
|
||||
std::string kernel_source = metal::utils();
|
||||
concatenate(kernel_source, metal::reduce_utils(), metal::reduce());
|
||||
if (bm >= 0) {
|
||||
kernel_source << get_template_definition(
|
||||
kernel_name, func_name, in_type, out_type, op, ndim, bm, bn);
|
||||
kernel_source += get_template_definition(
|
||||
kernel_name, func_name, in_t, out_t, op, idx_t, ndim, bm, bn);
|
||||
} else if (ndim >= 0) {
|
||||
kernel_source << get_template_definition(
|
||||
kernel_name, func_name, in_type, out_type, op, ndim);
|
||||
kernel_source += get_template_definition(
|
||||
kernel_name, func_name, in_t, out_t, op, idx_t, ndim);
|
||||
} else {
|
||||
kernel_source << get_template_definition(
|
||||
kernel_name, func_name, in_type, out_type, op);
|
||||
kernel_source += get_template_definition(
|
||||
kernel_name, func_name, in_t, out_t, op, idx_t);
|
||||
}
|
||||
return kernel_source.str();
|
||||
return kernel_source;
|
||||
});
|
||||
auto st = d.get_kernel(kernel_name, lib);
|
||||
return st;
|
||||
|
@@ -79,15 +79,18 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
|
||||
MTL::ComputePipelineState* get_reduce_init_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out);
|
||||
const std::string& func_name,
|
||||
const std::string& op_name,
|
||||
const Dtype& out_type);
|
||||
|
||||
MTL::ComputePipelineState* get_reduce_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& func_name,
|
||||
const std::string& op_name,
|
||||
const array& in,
|
||||
const array& out,
|
||||
const Dtype& in_type,
|
||||
const Dtype& out_type,
|
||||
const std::string& idx_t,
|
||||
int ndim = -1,
|
||||
int bm = -1,
|
||||
int bn = -1);
|
||||
|
@@ -1,13 +1,27 @@
|
||||
set(BASE_HEADERS bf16.h bf16_math.h complex.h defines.h expm1f.h utils.h)
|
||||
set(BASE_HEADERS
|
||||
metal_3_1/bf16.h
|
||||
metal_3_0/bf16.h
|
||||
bf16_math.h
|
||||
complex.h
|
||||
defines.h
|
||||
expm1f.h
|
||||
utils.h)
|
||||
|
||||
function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
|
||||
if(MLX_METAL_DEBUG)
|
||||
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
|
||||
endif()
|
||||
if(MLX_METAL_VERSION GREATER_EQUAL 310)
|
||||
set(VERSION_INCLUDES
|
||||
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_1)
|
||||
else()
|
||||
set(VERSION_INCLUDES
|
||||
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_0)
|
||||
endif()
|
||||
add_custom_command(
|
||||
COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
|
||||
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
|
||||
-I${PROJECT_SOURCE_DIR} -I${VERSION_INCLUDES} -o ${TARGET}.air
|
||||
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
|
||||
OUTPUT ${TARGET}.air
|
||||
COMMENT "Building ${TARGET}.air"
|
||||
@@ -30,9 +44,7 @@ build_kernel(layer_norm)
|
||||
build_kernel(random)
|
||||
build_kernel(rms_norm)
|
||||
build_kernel(rope)
|
||||
build_kernel(
|
||||
scaled_dot_product_attention scaled_dot_product_attention_params.h
|
||||
sdpa_vector.h steel/defines.h steel/gemm/transforms.h steel/utils.h)
|
||||
build_kernel(scaled_dot_product_attention sdpa_vector.h)
|
||||
|
||||
set(STEEL_HEADERS
|
||||
steel/defines.h
|
||||
@@ -50,7 +62,27 @@ set(STEEL_HEADERS
|
||||
steel/gemm/transforms.h
|
||||
steel/gemm/kernels/steel_gemm_fused.h
|
||||
steel/gemm/kernels/steel_gemm_masked.h
|
||||
steel/gemm/kernels/steel_gemm_splitk.h)
|
||||
steel/gemm/kernels/steel_gemm_splitk.h
|
||||
steel/utils/type_traits.h
|
||||
steel/utils/integral_constant.h)
|
||||
|
||||
set(STEEL_ATTN_HEADERS
|
||||
steel/defines.h
|
||||
steel/utils.h
|
||||
steel/gemm/gemm.h
|
||||
steel/gemm/mma.h
|
||||
steel/gemm/loader.h
|
||||
steel/gemm/transforms.h
|
||||
steel/utils/type_traits.h
|
||||
steel/utils/integral_constant.h
|
||||
steel/attn/attn.h
|
||||
steel/attn/loader.h
|
||||
steel/attn/mma.h
|
||||
steel/attn/params.h
|
||||
steel/attn/transforms.h
|
||||
steel/attn/kernels/steel_attention.h)
|
||||
|
||||
build_kernel(steel/attn/kernels/steel_attention ${STEEL_ATTN_HEADERS})
|
||||
|
||||
if(NOT MLX_METAL_JIT)
|
||||
build_kernel(arange arange.h)
|
||||
|
@@ -1,7 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/arange.h"
|
||||
|
||||
#define instantiate_arange(tname, type) \
|
||||
|
@@ -2,8 +2,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Metal math for bfloat16
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -369,18 +367,6 @@ instantiate_metal_math_funcs(
|
||||
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
|
||||
}
|
||||
|
||||
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
|
||||
|
||||
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
|
||||
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
|
||||
|
||||
#else
|
||||
|
||||
#define bfloat16_to_uint16(x) x.bits_
|
||||
#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())
|
||||
|
||||
#endif
|
||||
|
||||
namespace metal {
|
||||
|
||||
instantiate_metal_simd_comm_funcs(
|
||||
|
@@ -77,12 +77,12 @@ template <typename T, typename U, typename Op>
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto a_idx = elem_to_loc_1(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1(index, b_stride);
|
||||
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_stride);
|
||||
c[index] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, typename IdxT = size_t>
|
||||
[[kernel]] void binary_g_nd2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
@@ -91,13 +91,13 @@ template <typename T, typename U, typename Op>
|
||||
constant const size_t b_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_2(index, a_strides);
|
||||
auto b_idx = elem_to_loc_2(index, b_strides);
|
||||
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
|
||||
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
|
||||
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
|
||||
c[out_idx] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, typename IdxT = size_t>
|
||||
[[kernel]] void binary_g_nd3(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
@@ -106,14 +106,18 @@ template <typename T, typename U, typename Op>
|
||||
constant const size_t b_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_3(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3(index, b_strides);
|
||||
size_t out_idx =
|
||||
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
|
||||
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
|
||||
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
c[out_idx] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int N = 1>
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int N = 1,
|
||||
typename IdxT = size_t>
|
||||
[[kernel]] void binary_g(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
@@ -124,13 +128,12 @@ template <typename T, typename U, typename Op, int N = 1>
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd(
|
||||
auto idx = elem_to_loc_2_nd<size_t, IdxT>(
|
||||
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
|
||||
auto xshape = shape[ndim - 1];
|
||||
size_t out_idx =
|
||||
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
|
||||
auto a_xstride = a_strides[ndim - 1];
|
||||
auto b_xstride = b_strides[ndim - 1];
|
||||
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
IdxT a_xstride = a_strides[ndim - 1];
|
||||
IdxT b_xstride = b_strides[ndim - 1];
|
||||
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
|
||||
c[out_idx++] = Op()(a[idx.x], b[idx.y]);
|
||||
idx.x += a_xstride;
|
||||
|
@@ -9,18 +9,21 @@
|
||||
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/binary.h"
|
||||
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
|
||||
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
|
||||
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
|
||||
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
|
|
||||
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
|
||||
|
@@ -99,14 +99,14 @@ template <typename T, typename U, typename Op>
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto a_idx = elem_to_loc_1(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1(index, b_stride);
|
||||
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_stride);
|
||||
auto out = Op()(a[a_idx], b[b_idx]);
|
||||
c[index] = out[0];
|
||||
d[index] = out[1];
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, typename IdxT = size_t>
|
||||
[[kernel]] void binary_g_nd2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
@@ -116,15 +116,15 @@ template <typename T, typename U, typename Op>
|
||||
constant const size_t b_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_2(index, a_strides);
|
||||
auto b_idx = elem_to_loc_2(index, b_strides);
|
||||
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
|
||||
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
|
||||
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
|
||||
auto out = Op()(a[a_idx], b[b_idx]);
|
||||
c[out_idx] = out[0];
|
||||
d[out_idx] = out[1];
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, typename IdxT = size_t>
|
||||
[[kernel]] void binary_g_nd3(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
@@ -134,16 +134,20 @@ template <typename T, typename U, typename Op>
|
||||
constant const size_t b_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_3(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3(index, b_strides);
|
||||
size_t out_idx =
|
||||
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
|
||||
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
|
||||
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
auto out = Op()(a[a_idx], b[b_idx]);
|
||||
c[out_idx] = out[0];
|
||||
d[out_idx] = out[1];
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int N = 1>
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int N = 1,
|
||||
typename IdxT = size_t>
|
||||
[[kernel]] void binary_g(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
@@ -155,13 +159,12 @@ template <typename T, typename U, typename Op, int N = 1>
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd(
|
||||
auto idx = elem_to_loc_2_nd<size_t, IdxT>(
|
||||
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
|
||||
auto xshape = shape[ndim - 1];
|
||||
size_t out_idx =
|
||||
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
|
||||
auto a_xstride = a_strides[ndim - 1];
|
||||
auto b_xstride = b_strides[ndim - 1];
|
||||
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
IdxT a_xstride = a_strides[ndim - 1];
|
||||
IdxT b_xstride = b_strides[ndim - 1];
|
||||
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
|
||||
auto out = Op()(a[idx.x], b[idx.y]);
|
||||
c[out_idx] = out[0];
|
||||
|
@@ -7,18 +7,21 @@
|
||||
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/binary_two.h"
|
||||
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
|
||||
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
|
||||
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
|
||||
|
||||
#define instantiate_binary_float(op) \
|
||||
instantiate_binary_all(op, float16, half, half) \
|
||||
|
@@ -4,8 +4,8 @@
|
||||
#include <metal_simdgroup_matrix>
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
#define MLX_MTL_CONST static constant constexpr const
|
||||
|
||||
|
@@ -42,36 +42,36 @@ template <typename T, typename U>
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_1(index, src_stride);
|
||||
auto src_idx = elem_to_loc_1<int64_t, int>(index, src_stride);
|
||||
dst[index] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, typename IdxT = int64_t>
|
||||
[[kernel]] void copy_g_nd2(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc_2(index, src_strides);
|
||||
int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
|
||||
auto src_idx = elem_to_loc_2<int64_t, IdxT>(index, src_strides);
|
||||
IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y;
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, typename IdxT = int64_t>
|
||||
[[kernel]] void copy_g_nd3(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc_3(index, src_strides);
|
||||
int64_t dst_idx =
|
||||
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
||||
auto src_idx = elem_to_loc_3<int64_t, IdxT>(index, src_strides);
|
||||
IdxT dst_idx =
|
||||
index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, int N = 1>
|
||||
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
|
||||
[[kernel]] void copy_g(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
@@ -80,17 +80,16 @@ template <typename T, typename U, int N = 1>
|
||||
constant const int& ndim [[buffer(5)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc(
|
||||
auto src_idx = elem_to_loc<int64_t, IdxT>(
|
||||
{N * index.x, index.y, index.z}, src_shape, src_strides, ndim);
|
||||
if (N == 1) {
|
||||
int64_t dst_idx =
|
||||
index.x + grid_dim.x * (index.y + int64_t(grid_dim.y) * index.z);
|
||||
IdxT dst_idx =
|
||||
index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
return;
|
||||
}
|
||||
auto xshape = src_shape[ndim - 1];
|
||||
int64_t dst_idx =
|
||||
N * index.x + xshape * (index.y + int64_t(grid_dim.y) * index.z);
|
||||
IdxT dst_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
auto src_xstride = src_strides[ndim - 1];
|
||||
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
|
||||
dst[dst_idx + i] = static_cast<U>(src[src_idx]);
|
||||
@@ -105,36 +104,36 @@ template <typename T, typename U>
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
constant const int64_t& dst_stride [[buffer(4)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_1(index, src_stride);
|
||||
auto dst_idx = elem_to_loc_1(index, dst_stride);
|
||||
auto src_idx = elem_to_loc_1<int64_t, int>(index, src_stride);
|
||||
auto dst_idx = elem_to_loc_1<int64_t, int>(index, dst_stride);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, typename IdxT = int64_t>
|
||||
[[kernel]] void copy_gg_nd2(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint2 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_2(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_2(index, dst_strides);
|
||||
auto src_idx = elem_to_loc_2<int64_t, IdxT>(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_2<int64_t, IdxT>(index, dst_strides);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, typename IdxT = int64_t>
|
||||
[[kernel]] void copy_gg_nd3(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint3 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_3(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_3(index, dst_strides);
|
||||
auto src_idx = elem_to_loc_3<int64_t, IdxT>(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_3<int64_t, IdxT>(index, dst_strides);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, int N = 1>
|
||||
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
|
||||
[[kernel]] void copy_gg(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
@@ -143,7 +142,7 @@ template <typename T, typename U, int N = 1>
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
constant const int& ndim [[buffer(5)]],
|
||||
uint3 index [[thread_position_in_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd(
|
||||
auto idx = elem_to_loc_2_nd<int64_t, IdxT>(
|
||||
{N * index.x, index.y, index.z},
|
||||
src_shape,
|
||||
src_strides,
|
||||
@@ -153,8 +152,8 @@ template <typename T, typename U, int N = 1>
|
||||
dst[idx.y] = static_cast<U>(src[idx.x]);
|
||||
return;
|
||||
}
|
||||
auto src_xstride = src_strides[ndim - 1];
|
||||
auto dst_xstride = dst_strides[ndim - 1];
|
||||
IdxT src_xstride = src_strides[ndim - 1];
|
||||
IdxT dst_xstride = dst_strides[ndim - 1];
|
||||
auto xshape = src_shape[ndim - 1];
|
||||
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
|
||||
dst[idx.y] = static_cast<U>(src[idx.x]);
|
||||
|
@@ -2,22 +2,27 @@
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/copy.h"
|
||||
|
||||
#define instantiate_copy_all(tname, itype, otype) \
|
||||
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
|
||||
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
|
||||
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
|
||||
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
|
||||
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \
|
||||
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype) \
|
||||
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype) \
|
||||
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
|
||||
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
|
||||
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
|
||||
instantiate_kernel("gn4_copy" #tname, copy_g, itype, otype, 4) \
|
||||
instantiate_kernel("ggn4_copy" #tname, copy_gg, itype, otype, 4)
|
||||
#define instantiate_copy_all(tname, itype, otype) \
|
||||
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
|
||||
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
|
||||
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
|
||||
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
|
||||
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \
|
||||
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \
|
||||
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \
|
||||
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
|
||||
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype, int) \
|
||||
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype, int) \
|
||||
instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \
|
||||
instantiate_kernel("ggn2_copy" #tname, copy_gg, itype, otype, 2, int) \
|
||||
instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \
|
||||
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
|
||||
instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, itype, otype) \
|
||||
instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, itype, otype) \
|
||||
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) \
|
||||
instantiate_kernel("ggn4large_copy" #tname, copy_gg, itype, otype, 4)
|
||||
|
||||
#define instantiate_copy_itype(itname, itype) \
|
||||
instantiate_copy_all(itname ##bool_, itype, bool) \
|
||||
|
@@ -4,7 +4,7 @@
|
||||
|
||||
#include "mlx/backend/metal/kernels/indexing.h"
|
||||
|
||||
template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
|
||||
template <typename T, typename IdxT, int NIDX, int IDX_NDIM, typename LocT>
|
||||
METAL_FUNC void gather_impl(
|
||||
const device T* src [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
@@ -16,34 +16,36 @@ METAL_FUNC void gather_impl(
|
||||
const thread Indices<IdxT, NIDX>& indices,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
size_t src_idx = 0;
|
||||
LocT src_idx = 0;
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
size_t idx_loc;
|
||||
LocT idx_loc;
|
||||
if (IDX_NDIM == 0) {
|
||||
idx_loc = 0;
|
||||
} else if (IDX_NDIM == 1) {
|
||||
idx_loc = index.x * indices.strides[indices.ndim * i];
|
||||
idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
|
||||
} else {
|
||||
idx_loc = index.x * indices.strides[indices.ndim * i];
|
||||
idx_loc += elem_to_loc(
|
||||
index.y,
|
||||
&indices.shapes[indices.ndim * i + 1],
|
||||
&indices.strides[indices.ndim * i + 1],
|
||||
indices.ndim - 1);
|
||||
idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
|
||||
idx_loc += indices.row_contiguous[i]
|
||||
? index.y
|
||||
: elem_to_loc<size_t, LocT>(
|
||||
index.y,
|
||||
&indices.shapes[indices.ndim * i + 1],
|
||||
&indices.strides[indices.ndim * i + 1],
|
||||
indices.ndim - 1);
|
||||
}
|
||||
auto ax = axes[i];
|
||||
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
|
||||
src_idx += idx_val * src_strides[ax];
|
||||
src_idx += static_cast<LocT>(idx_val) * static_cast<LocT>(src_strides[ax]);
|
||||
}
|
||||
|
||||
auto src_offset = elem_to_loc(index.z, slice_sizes, src_strides, src_ndim);
|
||||
auto src_offset =
|
||||
elem_to_loc<size_t, LocT>(index.z, slice_sizes, src_strides, src_ndim);
|
||||
|
||||
size_t out_idx = index.z;
|
||||
LocT out_idx = index.z;
|
||||
if (IDX_NDIM == 1) {
|
||||
out_idx += static_cast<size_t>(grid_dim.z) * index.x;
|
||||
out_idx += static_cast<LocT>(grid_dim.z) * index.x;
|
||||
} else if (IDX_NDIM >= 2) {
|
||||
out_idx +=
|
||||
grid_dim.z * (index.x * static_cast<size_t>(grid_dim.y) + index.y);
|
||||
out_idx += grid_dim.z * (index.x * static_cast<LocT>(grid_dim.y) + index.y);
|
||||
}
|
||||
out[out_idx] = src[src_offset + src_idx];
|
||||
}
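The row_contiguous fast path added above skips the general shape/stride walk when an index array is laid out contiguously. A plain C++ sketch of the idea (assumed semantics, not the Metal implementation):

#include <cstddef>

// General strided lookup: peel coordinates off the flat element index,
// last dimension fastest, and accumulate stride offsets.
size_t elem_to_loc(size_t elem, const int* shape, const size_t* strides,
                   int ndim) {
  size_t loc = 0;
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

// For a row-contiguous array the strided location of element `elem` is just
// `elem`, so the walk can be skipped entirely.
size_t index_offset(size_t elem, const int* shape, const size_t* strides,
                    int ndim, bool row_contiguous) {
  return row_contiguous ? elem : elem_to_loc(elem, shape, strides, ndim);
}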
@@ -3,8 +3,6 @@
|
||||
#include <metal_simdgroup>
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/utils.h"
|
||||
@@ -912,4 +910,4 @@ template <
|
||||
// clang-format off
|
||||
instantiate_gemv_t_bs_blocks(float32, float);
|
||||
instantiate_gemv_t_bs_blocks(float16, half);
|
||||
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on
|
||||
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on
|
||||
|
@@ -4,8 +4,6 @@
|
||||
#include <metal_simdgroup>
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/gemv_masked.h"
|
||||
|
@@ -9,11 +9,12 @@ struct Indices {
|
||||
const array<const device IdxT*, NIDX> buffers;
|
||||
const constant int* shapes;
|
||||
const constant size_t* strides;
|
||||
const constant bool* row_contiguous;
|
||||
const int ndim;
|
||||
};
|
||||
|
||||
template <typename IdxT>
|
||||
METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) {
|
||||
METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) {
|
||||
if (is_unsigned_v<IdxT>) {
|
||||
return idx;
|
||||
} else {
|
||||
|
mlx/backend/metal/kernels/jit/bf16.h (new file, 16 lines)
@@ -0,0 +1,16 @@
// Copyright © 2024 Apple Inc.

// clang-format off
#define jit_if #if
#define jit_else #else
#define jit_endif #endif

jit_if (__METAL_VERSION__ >= 310)

#include "mlx/backend/metal/kernels/metal_3_1/bf16.h"

jit_else

#include "mlx/backend/metal/kernels/metal_3_0/bf16.h"

jit_endif // clang-format on
|
@@ -3,8 +3,6 @@
|
||||
#include <metal_common>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
@@ -6,12 +6,6 @@
|
||||
|
||||
using namespace metal;
|
||||
|
||||
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
|
||||
|
||||
typedef bfloat bfloat16_t;
|
||||
|
||||
#else
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Helpers
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
@@ -311,7 +305,10 @@ METAL_FUNC bool isnan(_MLX_BFloat16 x) {
|
||||
} // namespace metal
|
||||
|
||||
#pragma METAL internals : disable
|
||||
inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
|
||||
return x.bits_;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16_math.h"
|
||||
inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
|
||||
return _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
mlx/backend/metal/kernels/metal_3_1/bf16.h (new file, 16 lines)
@@ -0,0 +1,16 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include <metal_stdlib>

using namespace metal;

typedef bfloat bfloat16_t;
inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
  return as_type<uint16_t>(x);
}

inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
  return as_type<bfloat16_t>(x);
}
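Both bf16 headers expose the same bfloat16 <-> uint16 bit-cast pair; on older Metal the struct's raw bits are used, on Metal 3.1 it is an as_type cast. The host-side sketch below (an assumption, not part of the diff) shows the bit-level relationship: bfloat16 is the high 16 bits of an IEEE-754 float32, so the uint16 round trip is a truncating reinterpretation.

#include <cstdint>
#include <cstring>

// Truncate a float32 to bfloat16 bits by keeping the high 16 bits
// (sign, exponent, top 7 mantissa bits). No rounding is applied here.
uint16_t float_to_bfloat16_bits(float x) {
  uint32_t u;
  std::memcpy(&u, &x, sizeof(u));
  return static_cast<uint16_t>(u >> 16);
}

// Rebuild a float32 whose high 16 bits are the bfloat16 pattern.
float bfloat16_bits_to_float(uint16_t b) {
  uint32_t u = static_cast<uint32_t>(b) << 16;
  float x;
  std::memcpy(&x, &u, sizeof(x));
  return x;
}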
@@ -564,18 +564,21 @@ METAL_FUNC void qmv_impl(
|
||||
static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
|
||||
0,
|
||||
values_per_thread);
|
||||
U sum =
|
||||
load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
|
||||
if (remaining > 0) {
|
||||
U sum = load_vector_safe<T, U, values_per_thread, bits>(
|
||||
x, x_thread, remaining);
|
||||
|
||||
for (int row = 0; out_row + row < out_vec_size; row++) {
|
||||
const device uint8_t* wl =
|
||||
(const device uint8_t*)(w + row * in_vec_size_w);
|
||||
const device T* sl = scales + row * in_vec_size_g;
|
||||
const device T* bl = biases + row * in_vec_size_g;
|
||||
for (int row = 0; out_row + row < out_vec_size; row++) {
|
||||
const device uint8_t* wl =
|
||||
(const device uint8_t*)(w + row * in_vec_size_w);
|
||||
const device T* sl = scales + row * in_vec_size_g;
|
||||
const device T* bl = biases + row * in_vec_size_g;
|
||||
|
||||
U s = sl[0];
|
||||
U b = bl[0];
|
||||
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
|
||||
U s = sl[0];
|
||||
U b = bl[0];
|
||||
result[row] +=
|
||||
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
|
||||
}
|
||||
}
|
||||
|
||||
for (int row = 0; out_row + row < out_vec_size; row++) {
|
||||
@@ -619,21 +622,22 @@ METAL_FUNC void qmv_impl(
|
||||
static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
|
||||
0,
|
||||
values_per_thread);
|
||||
U sum =
|
||||
load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
|
||||
if (remaining > 0) {
|
||||
U sum = load_vector_safe<T, U, values_per_thread, bits>(
|
||||
x, x_thread, remaining);
|
||||
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
const device uint8_t* wl =
|
||||
(const device uint8_t*)(w + row * in_vec_size_w);
|
||||
const device T* sl = scales + row * in_vec_size_g;
|
||||
const device T* bl = biases + row * in_vec_size_g;
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
const device uint8_t* wl =
|
||||
(const device uint8_t*)(w + row * in_vec_size_w);
|
||||
const device T* sl = scales + row * in_vec_size_g;
|
||||
const device T* bl = biases + row * in_vec_size_g;
|
||||
|
||||
U s = sl[0];
|
||||
U b = bl[0];
|
||||
result[row] += qdot_safe<U, values_per_thread, bits>(
|
||||
wl, x_thread, s, b, sum, remaining);
|
||||
U s = sl[0];
|
||||
U b = bl[0];
|
||||
result[row] += qdot_safe<U, values_per_thread, bits>(
|
||||
wl, x_thread, s, b, sum, remaining);
|
||||
}
|
||||
}
|
||||
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
result[row] = simd_sum(result[row]);
|
||||
if (simd_lid == 0) {
|
||||
@@ -650,8 +654,8 @@ METAL_FUNC void qvm_impl(
|
||||
const device T* biases,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
const constant int& in_vec_size,
|
||||
const constant int& out_vec_size,
|
||||
const int in_vec_size,
|
||||
const int out_vec_size,
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
@@ -1298,6 +1302,61 @@ template <typename T, const int group_size, const int bits, bool batched>
|
||||
simd_lid);
|
||||
}
|
||||
|
||||
template <typename T, const int group_size, const int bits, int split_k = 32>
|
||||
[[kernel]] void qvm_split_k(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
const device T* biases [[buffer(2)]],
|
||||
const device T* x [[buffer(3)]],
|
||||
device T* y [[buffer(4)]],
|
||||
const constant int& in_vec_size [[buffer(5)]],
|
||||
const constant int& out_vec_size [[buffer(6)]],
|
||||
const constant int& x_batch_ndims [[buffer(7)]],
|
||||
const constant int* x_shape [[buffer(8)]],
|
||||
const constant size_t* x_strides [[buffer(9)]],
|
||||
const constant int& w_batch_ndims [[buffer(10)]],
|
||||
const constant int* w_shape [[buffer(11)]],
|
||||
const constant size_t* w_strides [[buffer(12)]],
|
||||
const constant size_t* s_strides [[buffer(13)]],
|
||||
const constant size_t* b_strides [[buffer(14)]],
|
||||
const constant int& final_block_size [[buffer(15)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
adjust_matrix_offsets<T>(
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
y,
|
||||
out_vec_size,
|
||||
x_batch_ndims,
|
||||
x_shape,
|
||||
x_strides,
|
||||
w_batch_ndims,
|
||||
w_shape,
|
||||
w_strides,
|
||||
s_strides,
|
||||
b_strides,
|
||||
tid);
|
||||
|
||||
// When (in_vec_size % split_k != 0) the final block needs to be smaller
|
||||
int in_vec_size_adj =
|
||||
tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
|
||||
|
||||
qvm_impl<T, group_size, bits>(
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
x,
|
||||
y,
|
||||
in_vec_size_adj,
|
||||
out_vec_size,
|
||||
tid,
|
||||
simd_gid,
|
||||
simd_lid);
|
||||
}
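The qvm_split_k kernel above shortens the last K slice to final_block_size when in_vec_size is not a multiple of the per-slice size. A host-side sketch of one way such a plan could be computed (an assumption for illustration; the actual dispatch may size slices differently):

#include <cassert>

struct SplitKPlan {
  int block_size;        // K elements per slice for all slices but the last
  int final_block_size;  // K elements left for the last slice
};

SplitKPlan plan_split_k(int K, int split_k) {
  assert(split_k > 0 && K > 0);
  int block = (K + split_k - 1) / split_k;      // ceil(K / split_k)
  int final_block = K - block * (split_k - 1);  // remainder for the last slice
  // A real dispatcher would pick split_k so that final_block stays positive.
  return {block, final_block};
}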
template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
|
@@ -51,6 +51,15 @@
       D, \
       batched)
 
+#define instantiate_quantized_split_k(name, type, group_size, bits, split_k) \
+  instantiate_kernel(                                                        \
+      #name "_" #type "_gs_" #group_size "_b_" #bits "_spk_" #split_k,       \
+      name,                                                                  \
+      type,                                                                  \
+      group_size,                                                            \
+      bits,                                                                  \
+      split_k)
+
 #define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
   instantiate_quantized_batched(name, type, group_size, bits, 1)         \
   instantiate_quantized_batched(name, type, group_size, bits, 0)
@@ -84,11 +93,16 @@
   instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 1) \
   instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 0)
 
+#define instantiate_quantized_all_splitk(type, group_size, bits)        \
+  instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
+  instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
+
 #define instantiate_quantized_funcs(type, group_size, bits) \
   instantiate_quantized_all_single(type, group_size, bits)  \
   instantiate_quantized_all_batched(type, group_size, bits) \
   instantiate_quantized_all_aligned(type, group_size, bits) \
-  instantiate_quantized_all_quad(type, group_size, bits)
+  instantiate_quantized_all_quad(type, group_size, bits)    \
+  instantiate_quantized_all_splitk(type, group_size, bits)
 
 #define instantiate_quantized_types(group_size, bits)  \
   instantiate_quantized_funcs(float, group_size, bits) \
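The stringified macro above fixes a kernel naming scheme of the form "<name>_<type>_gs_<G>_b_<B>_spk_<S>". As a hedged illustration only, a host would have to rebuild the same string to look the specialized pipeline up; the helper below demonstrates the naming convention and is not the MLX host code.

#include <iostream>
#include <sstream>
#include <string>

// Rebuild the name produced by instantiate_quantized_split_k for a given
// specialization. Function and argument names are illustrative.
std::string split_k_kernel_name(const std::string& base,
                                const std::string& type,
                                int group_size,
                                int bits,
                                int split_k) {
  std::ostringstream os;
  os << base << "_" << type << "_gs_" << group_size << "_b_" << bits
     << "_spk_" << split_k;
  return os.str();
}

int main() {
  // Prints e.g. "qvm_split_k_float16_gs_64_b_4_spk_32"
  std::cout << split_k_kernel_name("qvm_split_k", "float16", 64, 4, 32) << "\n";
}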
@@ -34,8 +34,8 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
 [[kernel]] void rbitsc(
     device const uint32_t* keys,
     device char* out,
-    device const bool& odd,
-    device const uint& bytes_per_key,
+    constant const bool& odd,
+    constant const uint& bytes_per_key,
     uint2 grid_dim [[threads_per_grid]],
     uint2 index [[thread_position_in_grid]]) {
   auto kidx = 2 * index.x;
@@ -67,8 +67,8 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
 [[kernel]] void rbits(
     device const uint32_t* keys,
     device char* out,
-    device const bool& odd,
-    device const uint& bytes_per_key,
+    constant const bool& odd,
+    constant const uint& bytes_per_key,
     constant const int& ndim,
     constant const int* key_shape,
     constant const size_t* key_strides,
@@ -10,178 +10,156 @@
|
||||
#include "mlx/backend/metal/kernels/reduction/ops.h"
|
||||
#include "mlx/backend/metal/kernels/reduce.h"
|
||||
|
||||
#define instantiate_reduce_helper_floats(inst_f, name, op) \
|
||||
inst_f(name, float16, half, op) \
|
||||
inst_f(name, float32, float, op) \
|
||||
inst_f(name, bfloat16, bfloat16_t, op)
|
||||
#define instantiate_init_reduce(name, tname, type, op) \
|
||||
instantiate_kernel("init_reduce_" #name #tname, init_reduce, type, op<type>)
|
||||
|
||||
#define instantiate_reduce_helper_uints(inst_f, name, op) \
|
||||
inst_f(name, uint8, uint8_t, op) \
|
||||
inst_f(name, uint16, uint16_t, op) \
|
||||
inst_f(name, uint32, uint32_t, op)
|
||||
instantiate_init_reduce(and, bool_, bool, And)
|
||||
instantiate_init_reduce(or, bool_, bool, Or)
|
||||
|
||||
#define instantiate_reduce_helper_ints(inst_f, name, op) \
|
||||
inst_f(name, int8, int8_t, op) \
|
||||
inst_f(name, int16, int16_t, op) \
|
||||
inst_f(name, int32, int32_t, op)
|
||||
#define instantiate_init_sum_prod(name, op) \
|
||||
instantiate_init_reduce(name, int32, int32_t, op) \
|
||||
instantiate_init_reduce(name, int64, int64_t, op) \
|
||||
instantiate_init_reduce(name, float16, float16_t, op) \
|
||||
instantiate_init_reduce(name, bfloat16, bfloat16_t, op) \
|
||||
instantiate_init_reduce(name, float32, float, op) \
|
||||
instantiate_init_reduce(name, complex64, complex64_t, op)
|
||||
|
||||
#define instantiate_reduce_helper_64b(inst_f, name, op) \
|
||||
inst_f(name, int64, int64_t, op) \
|
||||
inst_f(name, uint64, uint64_t, op) \
|
||||
inst_f(name, complex64, complex64_t, op)
|
||||
instantiate_init_sum_prod(sum, Sum)
|
||||
instantiate_init_sum_prod(prod, Prod)
|
||||
|
||||
#define instantiate_reduce_helper_types(inst_f, name, op) \
|
||||
instantiate_reduce_helper_floats(inst_f, name, op) \
|
||||
instantiate_reduce_helper_uints(inst_f, name, op) \
|
||||
instantiate_reduce_helper_ints(inst_f, name, op)
|
||||
#define instantiate_init_min_max(name, op) \
|
||||
instantiate_init_reduce(name, bool_, bool, op) \
|
||||
instantiate_init_reduce(name, int8, int8_t, op) \
|
||||
instantiate_init_reduce(name, int16, int16_t, op) \
|
||||
instantiate_init_reduce(name, int32, int32_t, op) \
|
||||
instantiate_init_reduce(name, int64, int64_t, op) \
|
||||
instantiate_init_reduce(name, uint8, uint8_t, op) \
|
||||
instantiate_init_reduce(name, uint16, uint16_t, op) \
|
||||
instantiate_init_reduce(name, uint32, uint32_t, op) \
|
||||
instantiate_init_reduce(name, uint64, uint64_t, op) \
|
||||
instantiate_init_reduce(name, float16, float16_t, op) \
|
||||
instantiate_init_reduce(name, bfloat16, bfloat16_t, op) \
|
||||
instantiate_init_reduce(name, float32, float, op) \
|
||||
instantiate_init_reduce(name, complex64, complex64_t, op)
|
||||
|
||||
#define instantiate_reduce_ops(inst_f, type_f) \
|
||||
type_f(inst_f, sum, Sum) \
|
||||
type_f(inst_f, prod, Prod) \
|
||||
type_f(inst_f, min, Min) \
|
||||
type_f(inst_f, max, Max)
|
||||
|
||||
// Special case for bool reductions
|
||||
#define instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, tname, itype, otype, op) \
|
||||
inst_f(name##tname, itype, otype, op)
|
||||
|
||||
#define instantiate_reduce_from_types(inst_f, name, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, bool_, bool, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, uint8, uint8_t, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, uint16, uint16_t, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, uint32, uint32_t, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, uint64, uint64_t, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, int8, int8_t, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, int16, int16_t, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, int32, int32_t, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, int64, int64_t, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, float16, half, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, \
|
||||
name, \
|
||||
float32, \
|
||||
float, \
|
||||
otype, \
|
||||
op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, \
|
||||
name, \
|
||||
bfloat16, \
|
||||
bfloat16_t, \
|
||||
otype, \
|
||||
op)
|
||||
|
||||
#define instantiate_init_reduce(name, otype, op) \
|
||||
instantiate_kernel("init_reduce_" #name, \
|
||||
init_reduce, \
|
||||
otype, op)
|
||||
|
||||
#define instantiate_init_reduce_helper(name, tname, type, op) \
|
||||
instantiate_init_reduce(name##tname, type, op<type>)
|
||||
|
||||
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
|
||||
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)
|
||||
|
||||
instantiate_init_reduce(andbool_, bool, And<bool>)
|
||||
instantiate_init_reduce(orbool_, bool, Or<bool>)
|
||||
instantiate_init_min_max(min, Min)
|
||||
instantiate_init_min_max(max, Max)
|
||||
|
||||
#define instantiate_all_reduce(name, itype, otype, op) \
|
||||
instantiate_kernel("all_reduce_" #name, \
|
||||
all_reduce, \
|
||||
itype, otype, op)
|
||||
|
||||
#define instantiate_same_all_reduce_helper(name, tname, type, op) \
|
||||
instantiate_all_reduce(name##tname, type, type, op<type>)
|
||||
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
|
||||
col_reduce_small, \
|
||||
itype, otype, op, uint, dim) \
|
||||
instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
|
||||
col_reduce_longcolumn, \
|
||||
itype, otype, op, uint, dim) \
|
||||
instantiate_kernel("col_reduce_small_large_" #dim "_reduce_" #name, \
|
||||
col_reduce_small, \
|
||||
itype, otype, op, size_t, dim) \
|
||||
instantiate_kernel("col_reduce_longcolumn_large_" #dim "_reduce_" #name, \
|
||||
col_reduce_longcolumn, \
|
||||
itype, otype, op, size_t, dim)
|
||||
|
||||
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
|
||||
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_64b)
|
||||
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
|
||||
instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
|
||||
col_reduce_looped, \
|
||||
itype, otype, op, uint, dim, bm, bn) \
|
||||
instantiate_kernel("col_reduce_looped_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
|
||||
col_reduce_looped, \
|
||||
itype, otype, op, size_t, dim, bm, bn)
|
||||
|
||||
instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And<bool>)
|
||||
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)
|
||||
|
||||
// special case bool with larger output type
|
||||
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
|
||||
|
||||
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
|
||||
col_reduce_small, \
|
||||
itype, otype, op, dim)
|
||||
|
||||
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
|
||||
instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
|
||||
col_reduce_looped, \
|
||||
itype, otype, op, dim, bm, bn)
|
||||
#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
|
||||
instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
|
||||
col_reduce_2pass, \
|
||||
itype, otype, op, uint, dim, bm, bn) \
|
||||
instantiate_kernel("col_reduce_2pass_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
|
||||
col_reduce_2pass, \
|
||||
itype, otype, op, size_t, dim, bm, bn)
|
||||
|
||||
#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
|
||||
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 8, 128) \
|
||||
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32)
|
||||
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) \
|
||||
instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, 32, 32)
|
||||
|
||||
#define instantiate_col_reduce_general(name, itype, otype, op) \
|
||||
instantiate_col_reduce_small(name, itype, otype, op, 0) \
|
||||
instantiate_col_reduce_small(name, itype, otype, op, 1) \
|
||||
instantiate_col_reduce_small(name, itype, otype, op, 2) \
|
||||
instantiate_col_reduce_small(name, itype, otype, op, 3) \
|
||||
instantiate_col_reduce_small(name, itype, otype, op, 4) \
|
||||
instantiate_col_reduce_looped(name, itype, otype, op, 0) \
|
||||
instantiate_col_reduce_small(name, itype, otype, op, 5) \
|
||||
instantiate_col_reduce_looped(name, itype, otype, op, 1) \
|
||||
instantiate_col_reduce_looped(name, itype, otype, op, 2) \
|
||||
instantiate_col_reduce_looped(name, itype, otype, op, 3) \
|
||||
instantiate_col_reduce_looped(name, itype, otype, op, 4)
|
||||
instantiate_col_reduce_looped(name, itype, otype, op, 5)
|
||||
|
||||
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
|
||||
instantiate_col_reduce_general(name##tname, type, type, op<type>)
|
||||
#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \
|
||||
row_reduce_small, \
|
||||
itype, otype, op, uint, dim) \
|
||||
instantiate_kernel("row_reduce_small_large_" #dim "_reduce_" #name, \
|
||||
row_reduce_small, \
|
||||
itype, otype, op, size_t, dim)
|
||||
|
||||
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
|
||||
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_64b)
|
||||
|
||||
instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
|
||||
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And<bool>)
|
||||
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or<bool>)
|
||||
|
||||
#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \
|
||||
row_reduce_small, \
|
||||
itype, otype, op, dim)
|
||||
|
||||
#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \
|
||||
row_reduce_looped, \
|
||||
itype, otype, op, dim)
|
||||
#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \
|
||||
row_reduce_looped, \
|
||||
itype, otype, op, uint, dim) \
|
||||
instantiate_kernel("row_reduce_looped_large_" #dim "_reduce_" #name, \
|
||||
row_reduce_looped, \
|
||||
itype, otype, op, size_t, dim)
|
||||
|
||||
#define instantiate_row_reduce_general(name, itype, otype, op) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 0) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 1) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 2) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 3) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 4) \
|
||||
instantiate_row_reduce_looped(name, itype, otype, op, 0) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 5) \
|
||||
instantiate_row_reduce_looped(name, itype, otype, op, 1) \
|
||||
instantiate_row_reduce_looped(name, itype, otype, op, 2) \
|
||||
instantiate_row_reduce_looped(name, itype, otype, op, 3) \
|
||||
instantiate_row_reduce_looped(name, itype, otype, op, 4) \
|
||||
instantiate_row_reduce_looped(name, itype, otype, op, 5) \
|
||||
instantiate_kernel("row_reduce_simple_" #name, \
|
||||
row_reduce_simple, \
|
||||
itype, otype, op)
|
||||
|
||||
#define instantiate_same_row_reduce_helper(name, tname, type, op) \
|
||||
instantiate_row_reduce_general(name##tname, type, type, op<type>)
|
||||
#define instantiate_reduce_functions(name, tname, itype, otype, op) \
|
||||
instantiate_all_reduce(name##tname, itype, otype, op<otype>) \
|
||||
instantiate_row_reduce_general(name##tname, itype, otype, op<otype>) \
|
||||
instantiate_col_reduce_general(name##tname, itype, otype, op<otype>)
|
||||
|
||||
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
|
||||
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_64b)
|
||||
#define instantiate_and_or(name, op) \
|
||||
instantiate_reduce_functions(name, bool_, bool, bool, op) \
|
||||
instantiate_reduce_functions(name, int16, int16_t, bool, op) \
|
||||
instantiate_reduce_functions(name, int32, int32_t, bool, op) \
|
||||
instantiate_reduce_functions(name, int64, int64_t, bool, op)
|
||||
|
||||
instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And<bool>)
|
||||
instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or<bool>)
|
||||
instantiate_and_or(and, And)
|
||||
instantiate_and_or(or, Or)
|
||||
|
||||
instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
|
||||
#define instantiate_sum_prod(name, op) \
|
||||
instantiate_reduce_functions(name, int8, int8_t, int32_t, op) \
|
||||
instantiate_reduce_functions(name, int16, int16_t, int32_t, op) \
|
||||
instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \
|
||||
instantiate_reduce_functions(name, int64, int64_t, int64_t, op) \
|
||||
instantiate_reduce_functions(name, float16, float16_t, float16_t, op) \
|
||||
instantiate_reduce_functions(name, bfloat16, bfloat16_t, bfloat16_t, op) \
|
||||
instantiate_reduce_functions(name, float32, float, float, op) \
|
||||
instantiate_reduce_functions(name, complex64, complex64_t, complex64_t, op)
|
||||
|
||||
instantiate_sum_prod(sum, Sum)
|
||||
instantiate_sum_prod(prod, Prod)
|
||||
|
||||
#define instantiate_min_max(name, op) \
|
||||
instantiate_reduce_functions(name, int8, int8_t, int8_t, op) \
|
||||
instantiate_reduce_functions(name, int16, int16_t, int16_t, op) \
|
||||
instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \
|
||||
instantiate_reduce_functions(name, int64, int64_t, int64_t, op) \
|
||||
instantiate_reduce_functions(name, uint8, uint8_t, uint8_t, op) \
|
||||
instantiate_reduce_functions(name, uint16, uint16_t, uint16_t, op) \
|
||||
instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op) \
|
||||
instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op) \
|
||||
instantiate_reduce_functions(name, float16, float16_t, float16_t, op) \
|
||||
instantiate_reduce_functions(name, bfloat16, bfloat16_t, bfloat16_t, op) \
|
||||
instantiate_reduce_functions(name, float32, float, float, op) \
|
||||
instantiate_reduce_functions(name, complex64, complex64_t, complex64_t, op)
|
||||
|
||||
instantiate_min_max(min, Min)
|
||||
instantiate_min_max(max, Max)
|
||||
// clang-format on
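The instantiation lists above encode a type policy: and/or reductions always produce bool, sum/prod promote 8- and 16-bit integers to 32-bit accumulators, and min/max keep the input type. A small, purely illustrative trait in plain C++ captures the same mapping (the tag and trait names below are not MLX types):

#include <cstdint>
#include <type_traits>

// Operation tags standing in for the kernel ops.
struct And {};
struct Or {};
struct Sum {};
struct Prod {};
struct Min {};
struct Max {};

// Default: accumulate in the input type (min/max and wide types).
template <typename Op, typename T>
struct ReduceResult {
  using type = T;
};

// Logical reductions collapse to bool regardless of input type.
template <typename T>
struct ReduceResult<And, T> {
  using type = bool;
};
template <typename T>
struct ReduceResult<Or, T> {
  using type = bool;
};

// Sum/prod of narrow integers accumulate in int32_t.
template <>
struct ReduceResult<Sum, int8_t> {
  using type = int32_t;
};
template <>
struct ReduceResult<Sum, int16_t> {
  using type = int32_t;
};
template <>
struct ReduceResult<Prod, int8_t> {
  using type = int32_t;
};
template <>
struct ReduceResult<Prod, int16_t> {
  using type = int32_t;
};

static_assert(std::is_same<ReduceResult<Or, int64_t>::type, bool>::value, "");
static_assert(std::is_same<ReduceResult<Sum, int16_t>::type, int32_t>::value, "");
static_assert(std::is_same<ReduceResult<Max, float>::type, float>::value, "");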
|
||||
|
@@ -1,6 +1,11 @@
 // Copyright © 2023-2024 Apple Inc.
 
-template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
+template <
+    typename T,
+    typename U,
+    typename Op,
+    typename IdxT = int64_t,
+    int N_READS = REDUCE_N_READS>
 [[kernel]] void all_reduce(
     const device T* in [[buffer(0)]],
     device U* out [[buffer(1)]],
@@ -16,10 +21,10 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
   threadgroup U shared_vals[simd_size];
 
   U total = Op::init;
-  int64_t start_idx = gid.y * row_size;
-  int64_t actual_row =
+  IdxT start_idx = gid.y * IdxT(row_size);
+  IdxT actual_row =
       (start_idx + row_size <= in_size) ? row_size : in_size - start_idx;
-  int64_t blocks = actual_row / (lsize.x * N_READS);
+  IdxT blocks = actual_row / (lsize.x * N_READS);
   int extra = actual_row - blocks * (lsize.x * N_READS);
   extra -= lid.x * N_READS;
   start_idx += lid.x * N_READS;
@@ -30,7 +35,7 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
     extra = 0;
   }
 
-  for (int64_t b = 0; b < blocks; b++) {
+  for (IdxT b = 0; b < blocks; b++) {
     for (int i = 0; i < N_READS; i++) {
       total = op(static_cast<U>(in[i]), total);
     }
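The hunk above threads an IdxT index type through all_reduce so large inputs keep 64-bit indexing while smaller ones can use narrower offsets. A serial CPU analogue of the blocks/extra decomposition, with the same templated index type, looks roughly like this (illustrative only, not the Metal kernel):

#include <cstdint>
#include <vector>

// Reduce `in` by walking full N_READS-wide blocks first, then the tail.
// IdxT plays the same role as in the kernel: it is the type used for the
// block count and element offsets.
template <typename T, typename U, typename IdxT = int64_t, int N_READS = 4>
U all_reduce_ref(const std::vector<T>& in) {
  U total = U(0);
  IdxT n = static_cast<IdxT>(in.size());
  IdxT blocks = n / N_READS;                           // full blocks
  int extra = static_cast<int>(n - blocks * N_READS);  // tail elements
  IdxT idx = 0;
  for (IdxT b = 0; b < blocks; ++b) {
    for (int i = 0; i < N_READS; ++i) {
      total += static_cast<U>(in[idx + i]);
    }
    idx += N_READS;
  }
  for (int i = 0; i < extra; ++i) {
    total += static_cast<U>(in[idx + i]);
  }
  return total;
}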
|
@@ -1,11 +1,6 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIMS,
|
||||
int N_READS = REDUCE_N_READS>
|
||||
template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
|
||||
[[kernel]] void col_reduce_small(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
@@ -20,170 +15,129 @@ template <
|
||||
const constant size_t& non_col_reductions [[buffer(10)]],
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
uint3 gsize [[threadgroups_per_grid]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[thread_position_in_grid]],
|
||||
uint3 tsize [[threads_per_grid]]) {
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 lsize [[threads_per_threadgroup]]) {
|
||||
constexpr int n_reads = 4;
|
||||
Op op;
|
||||
looped_elem_to_loc<NDIMS> loop;
|
||||
LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
|
||||
const device T* row;
|
||||
|
||||
// Case 1: Small row small column
|
||||
if (reduction_size * non_col_reductions < 64 && reduction_stride < 32) {
|
||||
U totals[31];
|
||||
for (int i = 0; i < 31; i++) {
|
||||
totals[i] = Op::init;
|
||||
U totals[n_reads];
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
totals[i] = Op::init;
|
||||
}
|
||||
|
||||
IdxT column = IdxT(gid.x) * lsize.x * n_reads + lid.x * n_reads;
|
||||
if (column >= reduction_stride) {
|
||||
return;
|
||||
}
|
||||
bool safe = column + n_reads <= reduction_stride;
|
||||
|
||||
IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
|
||||
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
|
||||
in += in_idx + column;
|
||||
|
||||
IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size);
|
||||
loop.next(lid.y, reduce_shape, reduce_strides);
|
||||
for (IdxT r = lid.y; r < total_rows; r += lsize.y) {
|
||||
row = in + loop.location();
|
||||
if (safe) {
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
totals[i] = op(static_cast<U>(row[i]), totals[i]);
|
||||
}
|
||||
} else {
|
||||
U vals[n_reads];
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
vals[i] =
|
||||
(column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
|
||||
}
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
totals[i] = op(vals[i], totals[i]);
|
||||
}
|
||||
}
|
||||
loop.next(lsize.y, reduce_shape, reduce_strides);
|
||||
}
|
||||
|
||||
short stride = reduction_stride;
|
||||
short size = reduction_size;
|
||||
short blocks = stride / N_READS;
|
||||
short extra = stride - blocks * N_READS;
|
||||
|
||||
size_t out_idx = tid.x + tsize.y * size_t(tid.y);
|
||||
in += elem_to_loc(out_idx, shape, strides, ndim);
|
||||
|
||||
for (uint r = 0; r < non_col_reductions; r++) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
|
||||
for (short i = 0; i < size; i++) {
|
||||
for (short j = 0; j < blocks; j++) {
|
||||
for (short k = 0; k < N_READS; k++) {
|
||||
totals[j * N_READS + k] =
|
||||
op(totals[j * N_READS + k],
|
||||
static_cast<U>(row[i * stride + j * N_READS + k]));
|
||||
}
|
||||
}
|
||||
for (short k = 0; k < extra; k++) {
|
||||
totals[blocks * N_READS + k] =
|
||||
op(totals[blocks * N_READS + k],
|
||||
static_cast<U>(row[i * stride + blocks * N_READS + k]));
|
||||
if (lsize.y > 1) {
|
||||
// lsize.y should be <= 8
|
||||
threadgroup U shared_vals[32 * 8 * n_reads];
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
shared_vals[lid.y * lsize.x * n_reads + lid.x * n_reads + i] = totals[i];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (lid.y == 0) {
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
totals[i] = shared_vals[lid.x * n_reads + i];
|
||||
}
|
||||
for (uint j = 1; j < lsize.y; j++) {
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
totals[i] =
|
||||
op(shared_vals[j * lsize.x * n_reads + lid.x * n_reads + i],
|
||||
totals[i]);
|
||||
}
|
||||
}
|
||||
|
||||
loop.next(reduce_shape, reduce_strides);
|
||||
}
|
||||
out += out_idx * reduction_stride;
|
||||
for (short j = 0; j < stride; j++) {
|
||||
out[j] = totals[j];
|
||||
}
|
||||
}
|
||||
|
||||
// Case 2: Long row small column
|
||||
else if (reduction_size * non_col_reductions < 32) {
|
||||
U totals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = Op::init;
|
||||
}
|
||||
|
||||
short size = reduction_size;
|
||||
size_t offset = size_t(tid.x) * N_READS;
|
||||
bool safe = offset + N_READS <= reduction_stride;
|
||||
short extra = reduction_stride - offset;
|
||||
|
||||
size_t out_idx = tid.y + tsize.z * size_t(tid.z);
|
||||
in += elem_to_loc(out_idx, shape, strides, ndim) + offset;
|
||||
|
||||
for (uint r = 0; r < non_col_reductions; r++) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
|
||||
if (safe) {
|
||||
for (short i = 0; i < size; i++) {
|
||||
for (short j = 0; j < N_READS; j++) {
|
||||
totals[j] =
|
||||
op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (short i = 0; i < size; i++) {
|
||||
for (short j = 0; j < extra; j++) {
|
||||
totals[j] =
|
||||
op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
loop.next(reduce_shape, reduce_strides);
|
||||
}
|
||||
out += out_idx * reduction_stride + offset;
|
||||
if (lid.y == 0) {
|
||||
out += out_idx * IdxT(reduction_stride) + column;
|
||||
if (safe) {
|
||||
for (short i = 0; i < N_READS; i++) {
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
} else {
|
||||
for (short i = 0; i < extra; i++) {
|
||||
for (int i = 0; column + i < reduction_stride; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Case 3: Long row medium column
|
||||
else {
|
||||
threadgroup U shared_vals[1024];
|
||||
U totals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = Op::init;
|
||||
}
|
||||
|
||||
short stride = reduction_stride;
|
||||
short lid = simd_group_id * simd_size + simd_lane_id;
|
||||
short2 tile((stride + N_READS - 1) / N_READS, 32);
|
||||
short2 offset((lid % tile.x) * N_READS, lid / tile.x);
|
||||
short sm_stride = tile.x * N_READS;
|
||||
bool safe = offset.x + N_READS <= stride;
|
||||
|
||||
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
|
||||
in += elem_to_loc(out_idx, shape, strides, ndim) + offset.x;
|
||||
|
||||
// Read cooperatively and contiguously and aggregate the partial results.
|
||||
size_t total = non_col_reductions * reduction_size;
|
||||
loop.next(offset.y, reduce_shape, reduce_strides);
|
||||
for (size_t r = offset.y; r < total; r += simd_size) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
|
||||
if (safe) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(static_cast<U>(row[i]), totals[i]);
|
||||
}
|
||||
} else {
|
||||
U vals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = (offset.x + i < stride) ? static_cast<U>(row[i]) : op.init;
|
||||
}
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(vals[i], totals[i]);
|
||||
}
|
||||
}
|
||||
|
||||
loop.next(simd_size, reduce_shape, reduce_strides);
|
||||
}
|
||||
|
||||
// Each thread holds N_READS partial results but the simdgroups are not
|
||||
// aligned to do the reduction across the simdgroup so we write our results
|
||||
// in the shared memory and read them back according to the simdgroup.
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
shared_vals[offset.y * sm_stride + offset.x + i] = totals[i];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op.simd_reduce(
|
||||
shared_vals[simd_lane_id * sm_stride + simd_group_id * N_READS + i]);
|
||||
}
|
||||
|
||||
// Write the output.
|
||||
if (simd_lane_id == 0) {
|
||||
short column = simd_group_id * N_READS;
|
||||
out += out_idx * reduction_stride + column;
|
||||
if (column + N_READS <= stride) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; column + i < stride; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
}
|
||||
template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
|
||||
[[kernel]] void col_reduce_longcolumn(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant int* shape [[buffer(4)]],
|
||||
const constant size_t* strides [[buffer(5)]],
|
||||
const constant int& ndim [[buffer(6)]],
|
||||
const constant int* reduce_shape [[buffer(7)]],
|
||||
const constant size_t* reduce_strides [[buffer(8)]],
|
||||
const constant int& reduce_ndim [[buffer(9)]],
|
||||
const constant size_t& non_col_reductions [[buffer(10)]],
|
||||
const constant size_t& out_size [[buffer(11)]],
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
uint3 gsize [[threadgroups_per_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 lsize [[threads_per_threadgroup]]) {
|
||||
Op op;
|
||||
LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
|
||||
const device T* row;
|
||||
|
||||
IdxT out_idx = gid.x + gsize.x * IdxT(gid.y);
|
||||
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
|
||||
in += in_idx + lid.x;
|
||||
|
||||
U total = Op::init;
|
||||
IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size);
|
||||
loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides);
|
||||
for (IdxT r = gid.z * lsize.y + lid.y; r < total_rows;
|
||||
r += lsize.y * gsize.z) {
|
||||
row = in + loop.location();
|
||||
total = op(static_cast<U>(*row), total);
|
||||
loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides);
|
||||
}
|
||||
|
||||
threadgroup U shared_vals[32 * 32];
|
||||
shared_vals[lid.y * lsize.x + lid.x] = total;
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (lid.y == 0) {
|
||||
for (uint i = 1; i < lsize.y; i++) {
|
||||
total = op(total, shared_vals[i * lsize.x + lid.x]);
|
||||
}
|
||||
out[gid.z * IdxT(out_size) + out_idx * IdxT(reduction_stride) + lid.x] =
|
||||
total;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -198,7 +152,14 @@ template <
|
||||
* totals with a loop.
|
||||
* 7. Write them to the output
|
||||
*/
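The numbered steps in the comment above describe a BM x BN tiled column reduction. A rough serial equivalent, with the simdgroup shuffle replaced by plain accumulation, is sketched below; the tile sizes and helper name are illustrative, not the kernel's actual interface.

#include <algorithm>
#include <vector>

// Reduce a rows x cols row-major matrix along the rows, one BN-wide column
// tile at a time, consuming the rows in strips of BM.
template <int BM = 32, int BN = 32>
std::vector<float> col_reduce_ref(const std::vector<float>& in,
                                  int rows,
                                  int cols) {
  std::vector<float> out(cols, 0.0f);
  for (int col0 = 0; col0 < cols; col0 += BN) {    // tile over columns
    int bn = std::min(BN, cols - col0);
    float totals[BN] = {0.0f};                     // running column totals
    for (int row0 = 0; row0 < rows; row0 += BM) {  // strips of BM rows
      int bm = std::min(BM, rows - row0);
      for (int r = 0; r < bm; ++r) {
        for (int c = 0; c < bn; ++c) {
          totals[c] += in[(row0 + r) * static_cast<size_t>(cols) + col0 + c];
        }
      }
    }
    for (int c = 0; c < bn; ++c) {
      out[col0 + c] = totals[c];
    }
  }
  return out;
}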
|
||||
template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
typename IdxT,
|
||||
int NDIMS,
|
||||
int BM,
|
||||
int BN>
|
||||
[[kernel]] void col_reduce_looped(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
@@ -216,14 +177,14 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
Op op;
|
||||
constexpr int n_simdgroups = 4;
|
||||
constexpr int n_simdgroups = 8;
|
||||
constexpr short tgp_size = n_simdgroups * simd_size;
|
||||
constexpr short n_reads = (BM * BN) / tgp_size;
|
||||
constexpr short n_read_blocks = BN / n_reads;
|
||||
|
||||
threadgroup U shared_vals[BN * BM];
|
||||
U totals[n_reads];
|
||||
looped_elem_to_loc<NDIMS> loop;
|
||||
LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
|
||||
const device T* row;
|
||||
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
@@ -232,17 +193,17 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
|
||||
|
||||
short lid = simd_group_id * simd_size + simd_lane_id;
|
||||
short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
|
||||
size_t column = BN * gid.x + offset.x;
|
||||
IdxT column = BN * gid.x + offset.x;
|
||||
bool safe = column + n_reads <= reduction_stride;
|
||||
|
||||
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
|
||||
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
|
||||
IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
|
||||
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
|
||||
in += in_idx + column;
|
||||
|
||||
size_t total = non_col_reductions * reduction_size;
|
||||
IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);
|
||||
loop.next(offset.y, reduce_shape, reduce_strides);
|
||||
for (size_t r = offset.y; r < total; r += BM) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
for (IdxT r = offset.y; r < total; r += BM) {
|
||||
row = in + loop.location();
|
||||
|
||||
if (safe) {
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
@@ -282,8 +243,8 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
|
||||
|
||||
// Write the output.
|
||||
if (simd_lane_id == 0) {
|
||||
size_t out_column = BN * gid.x + out_offset.x;
|
||||
out += out_idx * reduction_stride + out_column;
|
||||
IdxT out_column = BN * gid.x + out_offset.x;
|
||||
out += out_idx * IdxT(reduction_stride) + out_column;
|
||||
if (out_column + n_outputs <= reduction_stride) {
|
||||
for (int i = 0; i < n_outputs; i++) {
|
||||
out[i] = totals[i];
|
||||
@@ -316,7 +277,7 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
|
||||
|
||||
// Write the output.
|
||||
if (offset.y == 0) {
|
||||
out += out_idx * reduction_stride + column;
|
||||
out += out_idx * IdxT(reduction_stride) + column;
|
||||
if (safe) {
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
out[i] = totals[i];
|
||||
@@ -329,3 +290,109 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
typename IdxT,
|
||||
int NDIMS,
|
||||
int BM,
|
||||
int BN>
|
||||
[[kernel]] void col_reduce_2pass(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant int* shape [[buffer(4)]],
|
||||
const constant size_t* strides [[buffer(5)]],
|
||||
const constant int& ndim [[buffer(6)]],
|
||||
const constant int* reduce_shape [[buffer(7)]],
|
||||
const constant size_t* reduce_strides [[buffer(8)]],
|
||||
const constant int& reduce_ndim [[buffer(9)]],
|
||||
const constant size_t& non_col_reductions [[buffer(10)]],
|
||||
const constant size_t& out_size [[buffer(11)]],
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
uint3 gsize [[threadgroups_per_grid]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
Op op;
|
||||
constexpr int n_simdgroups = 8;
|
||||
constexpr short tgp_size = n_simdgroups * simd_size;
|
||||
constexpr short n_reads = (BM * BN) / tgp_size;
|
||||
constexpr short n_read_blocks = BN / n_reads;
|
||||
constexpr int n_outputs = BN / n_simdgroups;
|
||||
constexpr short outer_blocks = 32;
|
||||
static_assert(BM == 32, "BM should be equal to 32");
|
||||
|
||||
threadgroup U shared_vals[BN * BM];
|
||||
U totals[n_reads];
|
||||
LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
|
||||
const device T* row;
|
||||
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
totals[i] = Op::init;
|
||||
}
|
||||
|
||||
short lid = simd_group_id * simd_size + simd_lane_id;
|
||||
short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
|
||||
IdxT column = BN * gid.x + offset.x;
|
||||
bool safe = column + n_reads <= reduction_stride;
|
||||
|
||||
IdxT full_idx = gid.y + gsize.y * IdxT(gid.z);
|
||||
IdxT block_idx = full_idx / IdxT(out_size);
|
||||
IdxT out_idx = full_idx % IdxT(out_size);
|
||||
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
|
||||
in += in_idx + column;
|
||||
|
||||
IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);
|
||||
loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides);
|
||||
for (IdxT r = offset.y + block_idx * BM; r < total; r += outer_blocks * BM) {
|
||||
row = in + loop.location();
|
||||
|
||||
if (safe) {
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
totals[i] = op(static_cast<U>(row[i]), totals[i]);
|
||||
}
|
||||
} else {
|
||||
U vals[n_reads];
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
vals[i] =
|
||||
(column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
|
||||
}
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
totals[i] = op(vals[i], totals[i]);
|
||||
}
|
||||
}
|
||||
|
||||
loop.next(outer_blocks * BM, reduce_shape, reduce_strides);
|
||||
}
|
||||
|
||||
// We can use a simd reduction to accumulate across BM so each thread writes
|
||||
// the partial output to SM and then each simdgroup does BN / n_simdgroups
|
||||
// accumulations.
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
shared_vals[offset.y * BN + offset.x + i] = totals[i];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
|
||||
for (int i = 0; i < n_outputs; i++) {
|
||||
totals[i] =
|
||||
op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
|
||||
}
|
||||
|
||||
// Write the output.
|
||||
if (simd_lane_id == 0) {
|
||||
IdxT out_column = BN * gid.x + out_offset.x;
|
||||
out += full_idx * IdxT(reduction_stride) + out_column;
|
||||
if (out_column + n_outputs <= reduction_stride) {
|
||||
for (int i = 0; i < n_outputs; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; out_column + i < reduction_stride; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -193,6 +193,7 @@ template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
typename IdxT,
|
||||
int NDIMS,
|
||||
int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void row_reduce_small(
|
||||
@@ -214,20 +215,20 @@ template <
|
||||
Op op;
|
||||
|
||||
U total_val = Op::init;
|
||||
looped_elem_to_loc<NDIMS> loop;
|
||||
LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
|
||||
|
||||
// Precompute some row reduction numbers
|
||||
const device T* row;
|
||||
int blocks = row_size / N_READS;
|
||||
int extra = row_size % N_READS;
|
||||
int blocks = IdxT(row_size) / N_READS;
|
||||
int extra = IdxT(row_size) % N_READS;
|
||||
|
||||
if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {
|
||||
// Simple loop over non_row_reductions and reduce the row in the thread.
|
||||
size_t out_idx = tid.x + tsize.y * size_t(tid.y);
|
||||
in += elem_to_loc(out_idx, shape, strides, ndim);
|
||||
IdxT out_idx = tid.x + tsize.y * IdxT(tid.y);
|
||||
in += elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
|
||||
|
||||
for (uint r = 0; r < non_row_reductions; r++) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
row = in + loop.location();
|
||||
thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
|
||||
loop.next(reduce_shape, reduce_strides);
|
||||
}
|
||||
@@ -236,13 +237,13 @@ template <
|
||||
} else {
|
||||
// Collaboratively reduce over non_row_reductions in the simdgroup. Each
|
||||
// thread reduces every 32nd row and then a simple simd reduce.
|
||||
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
|
||||
in += elem_to_loc(out_idx, shape, strides, ndim);
|
||||
IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
|
||||
in += elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
|
||||
|
||||
loop.next(simd_lane_id, reduce_shape, reduce_strides);
|
||||
|
||||
for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
row = in + loop.location();
|
||||
thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
|
||||
loop.next(simd_size, reduce_shape, reduce_strides);
|
||||
}
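In the collaborative branch above each simd lane reduces every 32nd row and the per-lane totals are then folded with a simd reduction. A plain serial stand-in for the same work assignment (the lane count and names are illustrative):

#include <vector>

// Each of `lanes` logical threads reduces rows lane, lane+lanes, lane+2*lanes,
// ... into its own total; the lane totals are then combined, which is what the
// kernel does with a simd reduce.
float strided_row_reduce(const std::vector<std::vector<float>>& rows,
                         int lanes = 32) {
  std::vector<float> lane_total(lanes, 0.0f);
  for (int lane = 0; lane < lanes; ++lane) {
    for (size_t r = lane; r < rows.size(); r += lanes) {
      for (float v : rows[r]) {
        lane_total[lane] += v;  // per-row reduction (thread_reduce analogue)
      }
    }
  }
  float total = 0.0f;
  for (float t : lane_total) {
    total += t;  // combine lane totals (simd_sum analogue)
  }
  return total;
}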
|
||||
@@ -259,6 +260,7 @@ template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
typename IdxT = size_t,
|
||||
int N_READS = REDUCE_N_READS,
|
||||
int N_WRITES = REDUCE_N_WRITES>
|
||||
[[kernel]] void row_reduce_simple(
|
||||
@@ -277,15 +279,15 @@ template <
|
||||
U totals[N_WRITES];
|
||||
|
||||
// Move to the row
|
||||
size_t out_idx = N_WRITES * (gid.y + gsize.y * size_t(gid.z));
|
||||
IdxT out_idx = N_WRITES * (gid.y + gsize.y * IdxT(gid.z));
|
||||
if (out_idx + N_WRITES > out_size) {
|
||||
out_idx = out_size - N_WRITES;
|
||||
}
|
||||
in += out_idx * reduction_size;
|
||||
in += out_idx * IdxT(reduction_size);
|
||||
out += out_idx;
|
||||
|
||||
// Each thread reduces across the row
|
||||
int blocks = reduction_size / (lsize.x * N_READS);
|
||||
int blocks = IdxT(reduction_size) / (lsize.x * N_READS);
|
||||
int extra = reduction_size - blocks * (lsize.x * N_READS);
|
||||
per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
|
||||
totals, in, reduction_size, blocks, extra, lsize.x, lid.x);
|
||||
@@ -306,6 +308,7 @@ template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
typename IdxT,
|
||||
int NDIMS,
|
||||
int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void row_reduce_looped(
|
||||
@@ -330,19 +333,20 @@ template <
|
||||
threadgroup U shared_vals[simd_size];
|
||||
U total = Op::init;
|
||||
|
||||
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
|
||||
IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
|
||||
|
||||
// lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it
|
||||
// needs a small refactor.
|
||||
in += elem_to_loc(out_idx, shape, strides, ndim) + lid.x * N_READS;
|
||||
in += elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim) +
|
||||
lid.x * N_READS;
|
||||
|
||||
looped_elem_to_loc<NDIMS> loop;
|
||||
LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
|
||||
const device T* row;
|
||||
int blocks = row_size / (lsize.x * N_READS);
|
||||
int blocks = IdxT(row_size) / (lsize.x * N_READS);
|
||||
int extra = row_size - blocks * (lsize.x * N_READS);
|
||||
|
||||
for (size_t i = 0; i < non_row_reductions; i++) {
|
||||
row = in + loop.location(i, reduce_shape, reduce_strides, reduce_ndim);
|
||||
for (IdxT i = 0; i < non_row_reductions; i++) {
|
||||
row = in + loop.location();
|
||||
|
||||
// Each thread reduces across the row
|
||||
U row_total;
|
||||
|
@@ -3,8 +3,6 @@
|
||||
#include <metal_common>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
@@ -17,12 +15,15 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
constant float& eps,
|
||||
constant uint& axis_size,
|
||||
constant uint& w_stride,
|
||||
threadgroup float* local_inv_mean [[threadgroup(0)]],
|
||||
threadgroup float* local_sums [[threadgroup(1)]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
|
||||
threadgroup float local_inv_mean[1];
|
||||
threadgroup float local_sums[SIMD_SIZE];
|
||||
|
||||
float acc = 0;
|
||||
x += gid * size_t(axis_size) + lid * N_READS;
|
||||
w += w_stride * lid * N_READS;
|
||||
@@ -84,13 +85,15 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
constant float& eps,
|
||||
constant uint& axis_size,
|
||||
constant uint& w_stride,
|
||||
threadgroup float* local_inv_mean [[threadgroup(0)]],
|
||||
threadgroup float* local_sums [[threadgroup(1)]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
threadgroup float local_inv_mean[1];
|
||||
threadgroup float local_sums[SIMD_SIZE];
|
||||
|
||||
float acc = 0;
|
||||
x += gid * size_t(axis_size) + lid * N_READS;
|
||||
w += w_stride * lid * N_READS;
|
||||
@@ -376,8 +379,6 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
threadgroup float* local_inv_mean [[threadgroup(0)]], \
|
||||
threadgroup float* local_sums [[threadgroup(1)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
@@ -407,8 +408,6 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
threadgroup float* local_inv_mean [[threadgroup(0)]], \
|
||||
threadgroup float* local_sums [[threadgroup(1)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
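For reference, the RMS norm kernels touched in this hunk compute, per row, out = w * x / sqrt(mean(x^2) + eps); the change above only moves the two threadgroup scratch buffers from kernel arguments into local threadgroup arrays. A scalar sketch of the math, not the Metal kernel:

#include <cmath>
#include <vector>

// One row of RMS norm: accumulate the mean of squares, take the inverse RMS,
// then scale by the weight vector. eps matches the kernel's epsilon argument.
std::vector<float> rms_norm_row(const std::vector<float>& x,
                                const std::vector<float>& w,
                                float eps) {
  float acc = 0.0f;
  for (float v : x) {
    acc += v * v;  // sum of squares (the kernel's acc)
  }
  float inv_mean = 1.0f / std::sqrt(acc / static_cast<float>(x.size()) + eps);
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    out[i] = w[i] * (x[i] * inv_mean);  // scale by weights
  }
  return out;
}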
|
||||
|
@@ -2,7 +2,6 @@
 
 #include <metal_math>
 
-#include "mlx/backend/metal/kernels/bf16.h"
 #include "mlx/backend/metal/kernels/utils.h"
 template <typename T, bool traditional, bool forward>
 void rope_single_impl(
||||
|
@@ -1,945 +1,16 @@
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
|
||||
#include "mlx/backend/metal/kernels/sdpa_vector.h"
|
||||
#include "mlx/backend/metal/kernels/steel/defines.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
using namespace mlx::steel;
|
||||
|
||||
template <
|
||||
typename T,
|
||||
short BROWS,
|
||||
short BCOLS,
|
||||
short dst_ld,
|
||||
short reduction_dim,
|
||||
short tgp_size,
|
||||
short alignment = 1,
|
||||
short n_reads = (BCOLS * BROWS) / (tgp_size),
|
||||
short TCOLS = BCOLS / n_reads,
|
||||
short TROWS = tgp_size / TCOLS>
|
||||
struct BlockLoaderFA {
|
||||
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
|
||||
STEEL_CONST short vec_size = n_reads;
|
||||
|
||||
// Leading dimension for src
|
||||
const int src_ld;
|
||||
const int tile_stride;
|
||||
|
||||
// Thread location indices
|
||||
const short thread_idx;
|
||||
const short bi;
|
||||
const short bj;
|
||||
|
||||
// threadgroup and device memory
|
||||
threadgroup T* dst;
|
||||
const device T* src;
|
||||
|
||||
struct alignas(alignment * sizeof(T)) ReadVector {
|
||||
uint8_t v[sizeof(T) * vec_size];
|
||||
};
|
||||
|
||||
/* Constructor */
|
||||
METAL_FUNC BlockLoaderFA(
|
||||
const device T* src_,
|
||||
const int src_ld_,
|
||||
threadgroup T* dst_,
|
||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(src_ld_),
|
||||
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
|
||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
bi(thread_idx / TCOLS),
|
||||
bj(vec_size * (thread_idx % TCOLS)),
|
||||
dst(dst_ + bi * dst_ld + bj),
|
||||
src(src_ + bi * src_ld + bj) {}
|
||||
|
||||
/* Load from device memory into threadgroup memory - without bound checking */
|
||||
METAL_FUNC void load_unsafe() const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
*((threadgroup ReadVector*)(&dst[i * dst_ld])) =
|
||||
*((const device ReadVector*)(&src[i * src_ld]));
|
||||
}
|
||||
}
|
||||
|
||||
/* Load from device memory into threadgroup memory - with bound checking */
|
||||
METAL_FUNC void load_safe(short2 src_tile_dim) const {
|
||||
src_tile_dim = src_tile_dim - short2(bj, bi);
|
||||
|
||||
// Skip loading if thread has no valid reads
|
||||
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = T(0);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Use fast thread memory for bound checks
|
||||
bool tmp_idx[vec_size];
|
||||
T tmp_val[vec_size];
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
// Make sure tmp_idx only contains valid indices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
|
||||
}
|
||||
|
||||
// Read valid indices into tmp_val
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
|
||||
}
|
||||
|
||||
// Zero out uneeded values
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
|
||||
}
|
||||
|
||||
// Copy values to threadgroup memory
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = tmp_val[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Iteration helper */
|
||||
METAL_FUNC void next() {
|
||||
src += tile_stride;
|
||||
}
|
||||
METAL_FUNC void next(short n) {
|
||||
src += n * tile_stride;
|
||||
}
|
||||
};
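load_safe above keeps the tile loads unconditional by reading a clamped index and then zeroing lanes that fall outside the valid region. A plain C++ sketch of the same bounds-handling idea (tile sizes and names are illustrative, not the BlockLoaderFA interface):

#include <cstddef>

// Copy a BROWS x VEC tile from src into dst, substituting zeros for elements
// outside the valid (valid_rows x valid_cols) region instead of skipping the
// reads, mirroring the tmp_idx / tmp_val trick in load_safe.
template <int BROWS, int VEC>
void load_tile_safe(const float* src, int src_ld,
                    float* dst, int dst_ld,
                    int valid_rows, int valid_cols) {
  for (int i = 0; i < BROWS; ++i) {
    for (int j = 0; j < VEC; ++j) {
      bool in_bounds = (i < valid_rows) && (j < valid_cols);
      // Read a clamped (always valid) index, then zero out-of-bounds lanes.
      float v = src[in_bounds ? static_cast<size_t>(i) * src_ld + j : 0];
      dst[static_cast<size_t>(i) * dst_ld + j] = in_bounds ? v : 0.0f;
    }
  }
}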
|
||||
|
||||
template <bool M_aligned, bool N_aligned, bool K_aligned>
|
||||
struct LoopAlignment {};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
short lda_tgp,
|
||||
short ldb_tgp,
|
||||
typename AccumType = float,
|
||||
typename Epilogue = TransformNone<U, AccumType>>
|
||||
struct BlockMMAFA {
|
||||
// Warp tile simdgroup matrix strides along M
|
||||
STEEL_CONST short TM_stride = 8 * WM;
|
||||
// Warp tile simdgroup matrix strides along M
|
||||
STEEL_CONST short TN_stride = 8 * WN;
|
||||
|
||||
// Warp tile size along M
|
||||
STEEL_CONST short TM = BM / TM_stride;
|
||||
// Warp tile size along N
|
||||
STEEL_CONST short TN = BN / TN_stride;
|
||||
|
||||
// Strides of A, B along reduction axis
|
||||
STEEL_CONST short simd_stride_a = {
|
||||
transpose_a ? TM_stride : TM_stride * lda_tgp};
|
||||
STEEL_CONST short simd_stride_b = {
|
||||
transpose_b ? TN_stride * ldb_tgp : TN_stride};
|
||||
|
||||
// Jump between elements
|
||||
STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1};
|
||||
STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1};
|
||||
|
||||
STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8};
|
||||
STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp};
|
||||
|
||||
// Simdgroup matrices
|
||||
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
|
||||
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
|
||||
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
|
||||
simdgroup_matrix<AccumType, 8, 8>(0)};
|
||||
|
||||
// Offsets within threadgroup
|
||||
const short tm;
|
||||
const short tn;
|
||||
|
||||
short sm;
|
||||
short sn;
|
||||
|
||||
ushort sid;
|
||||
ushort slid;
|
||||
|
||||
short As_offset;
|
||||
short Bs_offset;
|
||||
|
||||
/* Constructor */
|
||||
METAL_FUNC BlockMMAFA(
|
||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
|
||||
// Determine thread position in simdgroup matrix
|
||||
short qid = simd_lane_id / 4;
|
||||
slid = simd_lane_id;
|
||||
sid = simd_group_id;
|
||||
|
||||
sm = (qid & 4) + (simd_lane_id / 2) % 4;
|
||||
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
|
||||
|
||||
// Determine thread and simdgroup offset
|
||||
As_offset =
|
||||
transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp);
|
||||
Bs_offset =
|
||||
transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn));
|
||||
}
|
||||
|
||||
/* (BM, BK) X (BK, BN) multiply accumulate function */
|
||||
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
|
||||
// Adjust for simdgroup and thread location
|
||||
As += As_offset;
|
||||
Bs += Bs_offset;
|
||||
|
||||
// Iterate over BK in blocks of 8
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short kk = 0; kk < BK; kk += 8) {
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Load elements from threadgroup A as simdgroup matrices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
Asimd[i].thread_elements()[0] =
|
||||
static_cast<AccumType>(As[i * simd_stride_a + 0]);
|
||||
Asimd[i].thread_elements()[1] =
|
||||
static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Load elements from threadgroup B as simdgroup matrices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
Bsimd[j].thread_elements()[0] =
|
||||
static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
|
||||
Bsimd[j].thread_elements()[1] =
|
||||
static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Multiply and accumulate into result simdgroup matrices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
short j_serp = (i % 2) ? (TN - 1 - j) : j;
|
||||
|
||||
simdgroup_multiply_accumulate(
|
||||
results[i * TN + j_serp],
|
||||
Asimd[i],
|
||||
Bsimd[j_serp],
|
||||
results[i * TN + j_serp]);
|
||||
}
|
||||
}
|
||||
|
||||
// Progress to next simdgroup tile
|
||||
As += tile_stride_a;
|
||||
Bs += tile_stride_b;
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void rescale_output(const threadgroup float* Corrections) {
|
||||
// Loop over all simdgroup tiles
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
short row = sm + tm + i * TM_stride;
|
||||
float scale_value = Corrections[row];
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread auto& accum = results[i * TN + j].thread_elements();
|
||||
// int offset = (i * TM_stride) * ldc + (j * TN_stride);
|
||||
accum[0] *= scale_value;
|
||||
accum[1] *= scale_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Store results from simdgroup_matrix results into device memory */
|
||||
METAL_FUNC void store_result(device U* C, const int ldc) const {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm + tm) * ldc + tn + sn;
|
||||
|
||||
// Loop over all simdgroup tiles
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
||||
int offset = (i * TM_stride) * ldc + (j * TN_stride);
|
||||
|
||||
// Apply epilogue
|
||||
U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
|
||||
|
||||
// Write out C
|
||||
C[offset] = outs[0];
|
||||
C[offset + 1] = outs[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void store_result_to_tgp_memory(
|
||||
threadgroup U* C,
|
||||
const int ldc,
|
||||
short2 dst_tile_dims) const {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm + tm) * ldc + (tn + sn);
|
||||
dst_tile_dims -= short2(tn + sn, sm + tm);
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int i = 0; i < TM; i++) {
|
||||
if (i * TM_stride < dst_tile_dims.y) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
||||
int offset = (i * TM_stride) * ldc + (j * TN_stride);
|
||||
|
||||
// Apply epilogue and output C
|
||||
if (j * TN_stride < dst_tile_dims.x) {
|
||||
C[offset] = Epilogue::apply(accum[0]);
|
||||
}
|
||||
|
||||
if (j * TN_stride + 1 < dst_tile_dims.x) {
|
||||
C[offset + 1] = Epilogue::apply(accum[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void
|
||||
store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm + tm) * ldc + (tn + sn);
|
||||
dst_tile_dims -= short2(tn + sn, sm + tm);
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int i = 0; i < TM; i++) {
|
||||
if (i * TM_stride < dst_tile_dims.y) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
||||
int offset = (i * TM_stride) * ldc + (j * TN_stride);
|
||||
|
||||
// Apply epilogue and output C
|
||||
if (j * TN_stride < dst_tile_dims.x) {
|
||||
C[offset] = Epilogue::apply(accum[0]);
|
||||
}
|
||||
|
||||
if (j * TN_stride + 1 < dst_tile_dims.x) {
|
||||
C[offset + 1] = Epilogue::apply(accum[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Store results from simdgroup_matrix results into device memory */
|
||||
METAL_FUNC void store_result(
|
||||
device U* D,
|
||||
const int ldd,
|
||||
const device U* C,
|
||||
const int ldc,
|
||||
const int fdc,
|
||||
thread const Epilogue& epilogue_op) const {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm + tm) * ldc + (tn + sn) * fdc;
|
||||
D += (sm + tm) * ldd + tn + sn;
|
||||
|
||||
// Loop over all simdgroup tiles
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
||||
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
||||
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
|
||||
|
||||
// Apply epilogue
|
||||
U outs[2] = {
|
||||
epilogue_op.apply(accum[0], C[offset_c]),
|
||||
epilogue_op.apply(accum[1], C[offset_c + fdc])};
|
||||
|
||||
// Write out D
|
||||
D[offset_d] = outs[0];
|
||||
D[offset_d + 1] = outs[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void store_result_safe(
|
||||
device U* D,
|
||||
const int ldd,
|
||||
const device U* C,
|
||||
const int ldc,
|
||||
const int fdc,
|
||||
short2 dst_tile_dims,
|
||||
thread const Epilogue& epilogue_op) const {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm + tm) * ldc + (tn + sn) * fdc;
|
||||
D += (sm + tm) * ldd + tn + sn;
|
||||
dst_tile_dims -= short2(tn + sn, sm + tm);
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int i = 0; i < TM; i++) {
|
||||
if (i * TM_stride < dst_tile_dims.y) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
||||
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
||||
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
|
||||
|
||||
// Apply epilogue and output C
|
||||
if (j * TN_stride < dst_tile_dims.x) {
|
||||
D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
|
||||
}
|
||||
|
||||
if (j * TN_stride + 1 < dst_tile_dims.x) {
|
||||
D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void clear_results() {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int i = 0; i < TM; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < TN; j++) {
|
||||
results[i * TN + j] = simdgroup_matrix<AccumType, 8, 8>(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_q,
|
||||
bool transpose_k,
|
||||
bool transpose_v,
|
||||
bool MN_aligned,
|
||||
bool K_aligned,
|
||||
typename AccumType = typename AccumHelper<T>::accum_type,
|
||||
typename Epilogue = TransformNone<U, AccumType>>
|
||||
struct FastAttentionKernel {
|
||||
STEEL_CONST short tgp_padding = 16 / sizeof(T);
|
||||
STEEL_CONST short float_padding = 16 / sizeof(float);
|
||||
STEEL_CONST short tgp_mem_size_q =
|
||||
transpose_q ? BK * (BM + tgp_padding) : BM * (BK + tgp_padding);
|
||||
STEEL_CONST short tgp_mem_size_k =
|
||||
transpose_k ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding);
|
||||
STEEL_CONST short tgp_mem_size_v =
|
||||
transpose_v ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding);
|
||||
STEEL_CONST short tgp_mem_size_s = BM * (BN + tgp_padding);
|
||||
|
||||
// maxes, rowsums, rescale
|
||||
STEEL_CONST short tgp_mem_size_corrections =
|
||||
4 * (BM * sizeof(float) + float_padding);
|
||||
|
||||
STEEL_CONST bool share_kv_smem = transpose_k != transpose_v;
|
||||
|
||||
STEEL_CONST short tgp_mem_size = share_kv_smem
|
||||
? tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s +
|
||||
tgp_mem_size_corrections
|
||||
: tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s +
|
||||
tgp_mem_size_corrections + tgp_mem_size_v;
|
||||
|
||||
STEEL_CONST short tgp_size = WM * WN * 32;
|
||||
|
||||
static_assert(transpose_q == false, "Expected Q not transposed.");
|
||||
static_assert(transpose_k == true, "Expected K transposed.");
|
||||
static_assert(transpose_v == false, "Expected V not transposed.");
|
||||
static_assert(tgp_mem_size <= 32768, "Excessive tgp memory requested.");
|
||||
|
||||
using loader_q_t = BlockLoaderFA<
|
||||
T,
|
||||
transpose_q ? BK : BM,
|
||||
transpose_q ? BM : BK,
|
||||
transpose_q ? BM + tgp_padding : BK + tgp_padding,
|
||||
!transpose_q,
|
||||
tgp_size>;
|
||||
|
||||
using loader_k_t = BlockLoaderFA<
|
||||
T,
|
||||
transpose_k ? BN : BK,
|
||||
transpose_k ? BK : BN,
|
||||
transpose_k ? BK + tgp_padding : BN + tgp_padding,
|
||||
transpose_k,
|
||||
tgp_size>;
|
||||
|
||||
using loader_v_t = BlockLoaderFA<
|
||||
T,
|
||||
transpose_v ? BK : BN,
|
||||
transpose_v ? BN : BK,
|
||||
transpose_v ? BN + tgp_padding : BK + tgp_padding,
|
||||
transpose_v,
|
||||
tgp_size>;
|
||||
|
||||
using mma_qk_t = BlockMMAFA<
|
||||
T,
|
||||
U,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_q,
|
||||
transpose_k,
|
||||
transpose_q ? BM + tgp_padding : BK + tgp_padding,
|
||||
transpose_k ? BK + tgp_padding : BN + tgp_padding,
|
||||
AccumType,
|
||||
Epilogue>;
|
||||
|
||||
using mma_sv_t = BlockMMAFA<
|
||||
T,
|
||||
U,
|
||||
BM,
|
||||
BK,
|
||||
BN,
|
||||
WM,
|
||||
WN,
|
||||
false,
|
||||
transpose_v,
|
||||
BN + tgp_padding,
|
||||
BK + tgp_padding,
|
||||
AccumType,
|
||||
Epilogue>;
|
||||
|
||||
/* Main kernel function */
|
||||
template <bool M_aligned, bool N_aligned, bool K_aligned_>
|
||||
static METAL_FUNC void gemm_loop(
|
||||
threadgroup T* As [[threadgroup(0)]],
|
||||
threadgroup T* Bs [[threadgroup(1)]],
|
||||
const int gemm_k_iterations,
|
||||
thread loader_k_t& loader_b,
|
||||
thread mma_qk_t& mma_op,
|
||||
thread const short& tgp_bm,
|
||||
thread const short& tgp_bn,
|
||||
LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
|
||||
// Appease the compiler
|
||||
(void)l;
|
||||
(void)tgp_bm;
|
||||
|
||||
short2 tile_dims_B = transpose_k ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
|
||||
|
||||
// not valid for gemm_k_iterations > 1 (so, BK == d_k)
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (N_aligned) {
|
||||
loader_b.load_unsafe();
|
||||
} else {
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
}
|
||||
|
||||
static METAL_FUNC void initialize_corrections(
|
||||
threadgroup float* C,
|
||||
uint simd_lane_id,
|
||||
uint simd_group_id) {
|
||||
if (simd_group_id == 0) {
|
||||
threadgroup float* maxes = C;
|
||||
threadgroup float* sums = C + (BM + float_padding);
|
||||
threadgroup float* o_rescale = sums + (BM + float_padding);
|
||||
threadgroup float* output_rescale = o_rescale + (BM + float_padding);
|
||||
|
||||
if (simd_lane_id < BM) {
|
||||
maxes[simd_lane_id] = -INFINITY; // m_i
|
||||
sums[simd_lane_id] = 0.f; // l_i
|
||||
o_rescale[simd_lane_id] = 1.f; // li * exp(mi - mi_new)
|
||||
output_rescale[simd_lane_id] = 1.f; // 1.0 / l_i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static METAL_FUNC void rescale_ss(
|
||||
threadgroup T* Ss,
|
||||
threadgroup float* Corrections,
|
||||
uint simd_group_id,
|
||||
uint simd_lane_id,
|
||||
short2 local_blocks,
|
||||
float alpha) {
|
||||
if (simd_group_id == 0) {
|
||||
short row_offset = BM + float_padding;
|
||||
threadgroup float* maxes = Corrections;
|
||||
threadgroup float* sums = Corrections + row_offset;
|
||||
threadgroup float* o_rescale = sums + row_offset;
|
||||
threadgroup float* output_scales = o_rescale + row_offset;
|
||||
|
||||
if (simd_lane_id < uint(local_blocks.y)) {
|
||||
float m_i_old = maxes[simd_lane_id];
|
||||
float l_i_old = sums[simd_lane_id];
|
||||
|
||||
float m_i_new = m_i_old;
|
||||
float l_i_new = l_i_old;
|
||||
|
||||
short offset = simd_lane_id * (BN + tgp_padding);
|
||||
|
||||
float m_ij = -INFINITY;
|
||||
|
||||
for (short j = 0; j < local_blocks.x; j++) {
|
||||
float val = alpha * float(Ss[offset + j]);
|
||||
m_ij = max(m_ij, val);
|
||||
}
|
||||
|
||||
m_i_new = max(m_ij, m_i_new);
|
||||
|
||||
float rowsum = 0.f; // lij
|
||||
|
||||
for (short j = 0; j < local_blocks.x; j++) {
|
||||
float val = alpha * float(Ss[offset + j]);
|
||||
float P_i_j = exp(val - m_ij);
|
||||
rowsum += P_i_j;
|
||||
P_i_j = P_i_j * exp(m_ij - m_i_new);
|
||||
Ss[offset + j] = T(P_i_j);
|
||||
}
|
||||
|
||||
l_i_new =
|
||||
exp(m_i_old - m_i_new) * l_i_old + exp(m_ij - m_i_new) * rowsum;
|
||||
maxes[simd_lane_id] = m_i_new;
|
||||
sums[simd_lane_id] = l_i_new;
|
||||
float rescale = l_i_old * exp(m_i_old - m_i_new);
|
||||
o_rescale[simd_lane_id] = rescale;
|
||||
output_scales[simd_lane_id] = 1.0 / l_i_new;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Main kernel function */
|
||||
static METAL_FUNC void run(
|
||||
const device T* Q [[buffer(0)]],
|
||||
const device T* K [[buffer(1)]],
|
||||
const device T* V [[buffer(2)]],
|
||||
device U* O [[buffer(3)]],
|
||||
const constant MLXFastAttentionParams* params [[buffer(4)]],
|
||||
threadgroup T* Qs [[threadgroup(0)]],
|
||||
threadgroup T* Ks [[threadgroup(1)]],
|
||||
threadgroup T* Ss [[threadgroup(2)]],
|
||||
threadgroup T* Vs [[threadgroup(3)]],
|
||||
threadgroup float* Corrections [[threadgroup(4)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
// Pacifying compiler
|
||||
(void)lid;
|
||||
|
||||
const int tid_y = ((tid.y) << params->swizzle_log) +
|
||||
((tid.x) & ((1 << params->swizzle_log) - 1));
|
||||
const int tid_x = (tid.x) >> params->swizzle_log;
|
||||
|
||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||
return;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Find block in Q, O; and head in K, V.
|
||||
const int c_row = tid_y * BM;
|
||||
|
||||
Q += transpose_q ? c_row : c_row * params->ldq;
|
||||
thread loader_q_t loader_q(Q, params->ldq, Qs, simd_group_id, simd_lane_id);
|
||||
|
||||
short tgp_bm = min(BM, params->M - c_row);
|
||||
short2 tile_dims_Q = transpose_q ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
|
||||
|
||||
loader_q.load_safe(tile_dims_Q);
|
||||
|
||||
initialize_corrections(Corrections, simd_lane_id, simd_group_id);
|
||||
|
||||
O += c_row * params->ldo;
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
thread mma_qk_t mma_qk_op(simd_group_id, simd_lane_id);
|
||||
thread mma_sv_t mma_softmax_sv_op(simd_group_id, simd_lane_id);
|
||||
thread loader_k_t loader_k(K, params->ldk, Ks, simd_group_id, simd_lane_id);
|
||||
thread loader_v_t loader_v(V, params->ldv, Vs, simd_group_id, simd_lane_id);
|
||||
|
||||
for (short n_block = 0; n_block < params->gemm_n_iterations_aligned;
|
||||
n_block++) {
|
||||
short c_col = BN;
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
short gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||
short tgp_bn_qk = min(BN, params->N - c_col * n_block);
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
{ // Loop over K - unaligned case
|
||||
|
||||
if (tgp_bm == BM && tgp_bn_qk == BN) {
|
||||
gemm_loop<true, true, K_aligned>(
|
||||
Qs,
|
||||
Ks,
|
||||
gemm_k_iterations,
|
||||
loader_k,
|
||||
mma_qk_op,
|
||||
tgp_bm,
|
||||
tgp_bn_qk);
|
||||
} else if (tgp_bn_qk == BN) {
|
||||
gemm_loop<false, true, K_aligned>(
|
||||
Qs,
|
||||
Ks,
|
||||
gemm_k_iterations,
|
||||
loader_k,
|
||||
mma_qk_op,
|
||||
tgp_bm,
|
||||
tgp_bn_qk);
|
||||
|
||||
} else if (tgp_bm == BM) {
|
||||
gemm_loop<true, false, K_aligned>(
|
||||
Qs,
|
||||
Ks,
|
||||
gemm_k_iterations,
|
||||
loader_k,
|
||||
mma_qk_op,
|
||||
tgp_bm,
|
||||
tgp_bn_qk);
|
||||
|
||||
} else {
|
||||
gemm_loop<false, false, K_aligned>(
|
||||
Qs,
|
||||
Ks,
|
||||
gemm_k_iterations,
|
||||
loader_k,
|
||||
mma_qk_op,
|
||||
tgp_bm,
|
||||
tgp_bn_qk);
|
||||
}
|
||||
}
|
||||
|
||||
mma_qk_op.store_result_to_tgp_memory(
|
||||
Ss, BN + tgp_padding, short2(BN, BM));
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
rescale_ss(
|
||||
Ss,
|
||||
Corrections,
|
||||
simd_group_id,
|
||||
simd_lane_id,
|
||||
short2(tgp_bn_qk, tgp_bm),
|
||||
params->alpha);
|
||||
|
||||
loader_v.load_safe(short2(BK, tgp_bn_qk));
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
threadgroup float* o_scales = Corrections + 2 * (BM + float_padding);
|
||||
mma_softmax_sv_op.rescale_output(o_scales);
|
||||
|
||||
mma_softmax_sv_op.mma(Ss, Vs);
|
||||
|
||||
threadgroup float* final_output_scales =
|
||||
Corrections + 3 * (BM + float_padding);
|
||||
|
||||
mma_softmax_sv_op.rescale_output(final_output_scales);
|
||||
|
||||
loader_v.next();
|
||||
loader_k.next(BN);
|
||||
|
||||
mma_qk_op.clear_results();
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
mma_softmax_sv_op.store_result_safe(O, params->ldo, short2(BK, tgp_bm));
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_q,
|
||||
bool transpose_k,
|
||||
bool transpose_v,
|
||||
bool MN_aligned,
|
||||
bool K_aligned>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void attention(
|
||||
const device T* Q [[buffer(0)]],
|
||||
const device T* K [[buffer(1)]],
|
||||
const device T* V [[buffer(2)]],
|
||||
device T* O [[buffer(3)]],
|
||||
const constant MLXFastAttentionParams* params [[buffer(4)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant size_t* batch_strides [[buffer(7)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
using attention_kernel = FastAttentionKernel<
|
||||
T,
|
||||
T,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_q,
|
||||
transpose_k,
|
||||
transpose_v,
|
||||
MN_aligned,
|
||||
K_aligned>;
|
||||
|
||||
// Adjust for batch
|
||||
if (params->batch_ndim > 1) {
|
||||
const constant size_t* Q_bstrides = batch_strides;
|
||||
const constant size_t* KV_bstrides = batch_strides + params->batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, Q_bstrides, KV_bstrides, params->batch_ndim);
|
||||
|
||||
Q += batch_offsets.x;
|
||||
K += batch_offsets.y;
|
||||
V += batch_offsets.y;
|
||||
|
||||
} else {
|
||||
Q += params->batch_stride_q * tid.z;
|
||||
K += params->batch_stride_k * tid.z;
|
||||
V += params->batch_stride_v * tid.z;
|
||||
}
|
||||
|
||||
// same shape as input
|
||||
O += params->batch_stride_o * tid.z;
|
||||
threadgroup T Qs[attention_kernel::tgp_mem_size_q];
|
||||
threadgroup T Ss[attention_kernel::tgp_mem_size_s];
|
||||
threadgroup float Corrections[attention_kernel::tgp_mem_size_corrections];
|
||||
|
||||
if (attention_kernel::share_kv_smem) {
|
||||
threadgroup T Ks[attention_kernel::tgp_mem_size_k];
|
||||
threadgroup T* Vs = Ks; //[attention_kernel::tgp_mem_size_v];
|
||||
attention_kernel::run(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
O,
|
||||
params,
|
||||
Qs,
|
||||
Ks,
|
||||
Ss,
|
||||
Vs,
|
||||
Corrections,
|
||||
simd_lane_id,
|
||||
simd_group_id,
|
||||
tid,
|
||||
lid);
|
||||
} else {
|
||||
threadgroup T Ks[attention_kernel::tgp_mem_size_k];
|
||||
threadgroup T Vs[attention_kernel::tgp_mem_size_v];
|
||||
attention_kernel::run(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
O,
|
||||
params,
|
||||
Qs,
|
||||
Ks,
|
||||
Ss,
|
||||
Vs,
|
||||
Corrections,
|
||||
simd_lane_id,
|
||||
simd_group_id,
|
||||
tid,
|
||||
lid);
|
||||
}
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
|
||||
// SDPA full instantiations
|
||||
#define instantiate_fast_inference_self_attention_kernel( \
|
||||
itype, otype, bm, bn, bk, wm, wn) \
|
||||
template [[host_name("steel_gemm_attention_bm_" #bm "_bn_" #bn "_bk_" #bk \
|
||||
"_itype_" #itype)]] [[kernel]] void \
|
||||
attention<itype, bm, bn, bk, wm, wn, false, true, false, false, true>( \
|
||||
const device itype* Q [[buffer(0)]], \
|
||||
const device itype* K [[buffer(1)]], \
|
||||
const device itype* V [[buffer(2)]], \
|
||||
device otype* O [[buffer(3)]], \
|
||||
const constant MLXFastAttentionParams* params [[buffer(4)]], \
|
||||
const constant int* batch_shape [[buffer(6)]], \
|
||||
const constant size_t* batch_strides [[buffer(7)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
instantiate_fast_inference_self_attention_kernel(
|
||||
float,
|
||||
float,
|
||||
16,
|
||||
16,
|
||||
64,
|
||||
2,
|
||||
2);
|
||||
instantiate_fast_inference_self_attention_kernel(
|
||||
float,
|
||||
float,
|
||||
16,
|
||||
16,
|
||||
128,
|
||||
2,
|
||||
2);
|
||||
instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 64, 2, 2);
|
||||
instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2);
|
||||
|
||||
// SDPA vector instantiations
|
||||
#define instantiate_sdpa_vector(type, head_dim) \
|
||||
template [[host_name("sdpa_vector_" #type "_" #head_dim)]] \
|
||||
[[kernel]] void sdpa_vector<type, head_dim>( \
|
||||
const device type* queries [[buffer(0)]], \
|
||||
const device type* keys [[buffer(1)]], \
|
||||
const device type* values [[buffer(2)]], \
|
||||
device type* out [[buffer(3)]], \
|
||||
const constant int& gqa_factor, \
|
||||
const constant int& N, \
|
||||
const constant size_t& k_stride, \
|
||||
const constant float& scale, \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
#define instantiate_sdpa_vector(type, head_dim) \
|
||||
instantiate_kernel("sdpa_vector_" #type "_" #head_dim, sdpa_vector, type, head_dim) \
|
||||
instantiate_kernel("sdpa_vector_2pass_1_" #type "_" #head_dim, sdpa_vector_2pass_1, type, head_dim) \
|
||||
instantiate_kernel("sdpa_vector_2pass_2_" #type "_" #head_dim, sdpa_vector_2pass_2, type, head_dim)
|
||||
|
||||
#define instantiate_sdpa_vector_heads(type) \
|
||||
instantiate_sdpa_vector(type, 64) \
|
||||
|
@@ -1,42 +0,0 @@
|
||||
//
|
||||
// scaled_dot_product_attention_params.h
|
||||
// mlx
|
||||
|
||||
#pragma once
|
||||
|
||||
struct MLXFastAttentionParams {
|
||||
const int M;
|
||||
const int N;
|
||||
const int K;
|
||||
|
||||
const int ldq; // ldq == ldo
|
||||
const int ldk;
|
||||
const int ldv;
|
||||
const int lds;
|
||||
const int ldo;
|
||||
|
||||
const int tiles_n;
|
||||
const int tiles_m;
|
||||
|
||||
const int batch_stride_q;
|
||||
const int batch_stride_k;
|
||||
const int batch_stride_v;
|
||||
const int batch_stride_o;
|
||||
|
||||
const int swizzle_log;
|
||||
const int gemm_n_iterations_aligned;
|
||||
const int gemm_k_iterations_aligned;
|
||||
const int gemm_sv_m_block_iterations;
|
||||
|
||||
const int batch_ndim;
|
||||
const float alpha;
|
||||
};
|
||||
|
||||
struct MLXScaledDotProductAttentionParams {
|
||||
// Associated dimensions & transposition information
|
||||
const uint QUERY_SEQUENCE_LENGTH = 1;
|
||||
const uint N_Q_HEADS = 32;
|
||||
const uint N_KV_HEADS = 32;
|
||||
const uint KV_TILES = 1;
|
||||
const float INV_ALPHA = 0.08838834764831843f;
|
||||
};
|
@@ -4,73 +4,57 @@
|
||||
|
||||
#include "mlx/backend/metal/kernels/indexing.h"
|
||||
|
||||
template <typename T, typename IdxT, typename Op, int NIDX>
|
||||
METAL_FUNC void scatter_1d_index_impl(
|
||||
const device T* updates [[buffer(1)]],
|
||||
device mlx_atomic<T>* out [[buffer(2)]],
|
||||
const constant int* out_shape [[buffer(3)]],
|
||||
const constant size_t* out_strides [[buffer(4)]],
|
||||
const constant size_t& out_ndim [[buffer(5)]],
|
||||
const constant int* upd_shape [[buffer(6)]],
|
||||
const constant size_t& upd_ndim [[buffer(7)]],
|
||||
const constant size_t& upd_size [[buffer(8)]],
|
||||
const thread array<const device IdxT*, NIDX>& idx_buffers,
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
Op op;
|
||||
|
||||
size_t out_idx = 0;
|
||||
for (int i = 0; i < NIDX; i++) {
|
||||
auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]);
|
||||
out_idx += idx_val * out_strides[i];
|
||||
}
|
||||
|
||||
if (upd_ndim > 1) {
|
||||
auto out_offset = elem_to_loc(gid.x, upd_shape + 1, out_strides, out_ndim);
|
||||
out_idx += out_offset;
|
||||
} else {
|
||||
out_idx += gid.x;
|
||||
}
|
||||
|
||||
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx);
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT, typename Op, int NIDX>
|
||||
template <
|
||||
typename T,
|
||||
typename IdxT,
|
||||
typename Op,
|
||||
int NIDX,
|
||||
bool UPD_ROW_CONTIG,
|
||||
int NWORK,
|
||||
typename LocT>
|
||||
METAL_FUNC void scatter_impl(
|
||||
const device T* updates [[buffer(1)]],
|
||||
device mlx_atomic<T>* out [[buffer(2)]],
|
||||
const constant int* upd_shape [[buffer(3)]],
|
||||
const constant size_t* upd_strides [[buffer(4)]],
|
||||
const constant size_t& upd_ndim [[buffer(5)]],
|
||||
const constant size_t& upd_size [[buffer(6)]],
|
||||
const constant int* out_shape [[buffer(7)]],
|
||||
const constant size_t* out_strides [[buffer(8)]],
|
||||
const constant size_t& out_ndim [[buffer(9)]],
|
||||
const constant int* axes [[buffer(10)]],
|
||||
const device T* updates,
|
||||
device mlx_atomic<T>* out,
|
||||
const constant int* upd_shape,
|
||||
const constant size_t* upd_strides,
|
||||
const constant size_t& upd_ndim,
|
||||
const constant size_t& upd_size,
|
||||
const constant int* out_shape,
|
||||
const constant size_t* out_strides,
|
||||
const constant size_t& out_ndim,
|
||||
const constant int* axes,
|
||||
const constant size_t& idx_size,
|
||||
const thread Indices<IdxT, NIDX>& indices,
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
Op op;
|
||||
auto ind_idx = gid.y;
|
||||
auto ind_offset = gid.x;
|
||||
|
||||
size_t out_idx = 0;
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
auto idx_loc = elem_to_loc(
|
||||
ind_idx,
|
||||
&indices.shapes[indices.ndim * i],
|
||||
&indices.strides[indices.ndim * i],
|
||||
indices.ndim);
|
||||
auto ax = axes[i];
|
||||
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
|
||||
out_idx += idx_val * out_strides[ax];
|
||||
}
|
||||
|
||||
auto ind_idx = gid.y * NWORK;
|
||||
LocT out_offset = 0;
|
||||
if (upd_size > 1) {
|
||||
auto out_offset = elem_to_loc(
|
||||
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
|
||||
out_idx += out_offset;
|
||||
out_offset = elem_to_loc<size_t, LocT>(
|
||||
gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
|
||||
}
|
||||
|
||||
auto upd_idx =
|
||||
elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
|
||||
op.atomic_update(out, updates[upd_idx], out_idx);
|
||||
for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) {
|
||||
LocT out_idx = out_offset;
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
auto idx_loc = indices.row_contiguous[i]
|
||||
? ind_idx
|
||||
: elem_to_loc<size_t, LocT>(
|
||||
ind_idx,
|
||||
&indices.shapes[indices.ndim * i],
|
||||
&indices.strides[indices.ndim * i],
|
||||
indices.ndim);
|
||||
auto ax = axes[i];
|
||||
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
|
||||
out_idx +=
|
||||
static_cast<LocT>(idx_val) * static_cast<LocT>(out_strides[ax]);
|
||||
}
|
||||
auto upd_idx = ind_idx * static_cast<LocT>(upd_size) + gid.x;
|
||||
if constexpr (!UPD_ROW_CONTIG) {
|
||||
upd_idx =
|
||||
elem_to_loc<size_t, LocT>(upd_idx, upd_shape, upd_strides, upd_ndim);
|
||||
}
|
||||
op.atomic_update(out, updates[upd_idx], out_idx);
|
||||
}
|
||||
}
|
||||
|
@@ -13,6 +13,7 @@ template <typename T, int D>
|
||||
const constant int& gqa_factor,
|
||||
const constant int& N,
|
||||
const constant size_t& k_stride,
|
||||
const constant size_t& v_stride,
|
||||
const constant float& scale,
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
@@ -20,8 +21,7 @@ template <typename T, int D>
|
||||
constexpr int BN = 32;
|
||||
constexpr int BD = 32;
|
||||
constexpr int elem_per_thread = D / BD;
|
||||
|
||||
const int stride = BN * D;
|
||||
constexpr int stride = BN * D;
|
||||
|
||||
typedef float U;
|
||||
|
||||
@@ -38,7 +38,7 @@ template <typename T, int D>
|
||||
const int kv_head_idx = head_idx / gqa_factor;
|
||||
queries += head_idx * D + simd_lid * elem_per_thread;
|
||||
keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
|
||||
values += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
|
||||
values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
|
||||
out += head_idx * D + simd_gid * elem_per_thread;
|
||||
|
||||
// Read the query and 0 the output accumulator
|
||||
@@ -83,7 +83,6 @@ template <typename T, int D>
|
||||
keys += stride;
|
||||
values += stride;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Each thread has a partial part of the output so we need to combine them.
|
||||
|
||||
@@ -113,3 +112,181 @@ template <typename T, int D>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int D>
|
||||
[[kernel]] void sdpa_vector_2pass_1(
|
||||
const device T* queries [[buffer(0)]],
|
||||
const device T* keys [[buffer(1)]],
|
||||
const device T* values [[buffer(2)]],
|
||||
device float* out [[buffer(3)]],
|
||||
device float* sums [[buffer(4)]],
|
||||
device float* maxs [[buffer(5)]],
|
||||
const constant int& gqa_factor,
|
||||
const constant int& N,
|
||||
const constant size_t& k_stride,
|
||||
const constant size_t& v_stride,
|
||||
const constant float& scale,
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
constexpr int BN = 8;
|
||||
constexpr int BD = 32;
|
||||
constexpr int elem_per_thread = D / BD;
|
||||
constexpr int stride = BN * D;
|
||||
constexpr int blocks = 32;
|
||||
|
||||
typedef float U;
|
||||
|
||||
thread U q[elem_per_thread];
|
||||
thread U k[elem_per_thread];
|
||||
thread U o[elem_per_thread];
|
||||
|
||||
threadgroup U outputs[BN * BD];
|
||||
threadgroup U max_scores[BN];
|
||||
threadgroup U sum_exp_scores[BN];
|
||||
|
||||
// Adjust positions
|
||||
const int block_idx = tid.z;
|
||||
const int head_idx = tid.y;
|
||||
const int kv_head_idx = head_idx / gqa_factor;
|
||||
queries += head_idx * D + simd_lid * elem_per_thread;
|
||||
keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D +
|
||||
simd_lid * elem_per_thread;
|
||||
values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D +
|
||||
simd_lid * elem_per_thread;
|
||||
out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread;
|
||||
sums += head_idx * blocks + block_idx;
|
||||
maxs += head_idx * blocks + block_idx;
|
||||
|
||||
// Read the query and 0 the output accumulator
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
q[i] = static_cast<U>(scale) * queries[i];
|
||||
}
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
o[i] = 0;
|
||||
}
|
||||
|
||||
U max_score = -1e9;
|
||||
U sum_exp_score = 0;
|
||||
|
||||
// For each key
|
||||
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
|
||||
// Read the key
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
k[i] = keys[i];
|
||||
}
|
||||
|
||||
// Compute the i-th score
|
||||
U score = 0;
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
score += q[i] * k[i];
|
||||
}
|
||||
score = simd_sum(score);
|
||||
|
||||
// Update the accumulators
|
||||
U new_max = max(max_score, score);
|
||||
U factor = fast::exp(max_score - new_max);
|
||||
U exp_score = fast::exp(score - new_max);
|
||||
|
||||
max_score = new_max;
|
||||
sum_exp_score = sum_exp_score * factor + exp_score;
|
||||
|
||||
// Update the output accumulator
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
o[i] = o[i] * factor + exp_score * values[i];
|
||||
}
|
||||
|
||||
// Move the pointers to the next kv
|
||||
keys += blocks * stride;
|
||||
values += blocks * stride;
|
||||
}
|
||||
|
||||
// Each thread has a partial part of the output so we need to combine them.
|
||||
|
||||
// First let's communicate the max and sum_exp
|
||||
if (simd_lid == 0) {
|
||||
max_scores[simd_gid] = max_score;
|
||||
sum_exp_scores[simd_gid] = sum_exp_score;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
|
||||
U new_max = simd_max(max_score);
|
||||
U factor = fast::exp(max_score - new_max);
|
||||
sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
|
||||
sum_exp_score = simd_sum(sum_exp_score * factor);
|
||||
|
||||
// Write the sum and new max
|
||||
if (simd_gid == 0) {
|
||||
sums[0] = sum_exp_score;
|
||||
maxs[0] = new_max;
|
||||
}
|
||||
|
||||
// Now we need to aggregate all the outputs
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
outputs[simd_lid * BN + simd_gid] =
|
||||
o[i] * fast::exp(max_scores[simd_gid] - new_max);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// And write the output
|
||||
if (simd_gid == 0) {
|
||||
U output = outputs[simd_lid * BN];
|
||||
for (int j = 1; j < BN; j++) {
|
||||
output += outputs[simd_lid * BN + j];
|
||||
}
|
||||
out[i] = static_cast<T>(output);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int D>
|
||||
[[kernel]] void sdpa_vector_2pass_2(
|
||||
const device float* partials [[buffer(0)]],
|
||||
const device float* sums [[buffer(1)]],
|
||||
const device float* maxs [[buffer(2)]],
|
||||
device T* out [[buffer(3)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
constexpr int BN = 32;
|
||||
constexpr int BD = 32;
|
||||
constexpr int elem_per_thread = D / BD;
|
||||
constexpr int blocks = 32;
|
||||
|
||||
typedef float U;
|
||||
|
||||
thread U o[elem_per_thread];
|
||||
threadgroup U outputs[BN * BD];
|
||||
|
||||
// Adjust positions
|
||||
const int head_idx = tid.y;
|
||||
partials += head_idx * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
|
||||
sums += head_idx * blocks;
|
||||
maxs += head_idx * blocks;
|
||||
out += head_idx * D + simd_gid * elem_per_thread;
|
||||
|
||||
// First everybody reads the max and sum_exp
|
||||
U max_score = maxs[simd_lid];
|
||||
U new_max = simd_max(max_score);
|
||||
U factor = fast::exp(max_score - new_max);
|
||||
U sum_exp_score = simd_sum(sums[simd_lid] * factor);
|
||||
|
||||
// Now read the block into registers and then use shared memory to transpose
|
||||
// it
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
o[i] = partials[i];
|
||||
}
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
outputs[simd_lid * BD + simd_gid] = o[i];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
// And write the output
|
||||
if (simd_lid == 0) {
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
out[i] = static_cast<T>(o[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -6,8 +6,6 @@
|
||||
using namespace metal;
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/softmax.h"
|
||||
|
||||
|
@@ -3,8 +3,6 @@
|
||||
#include <metal_stdlib>
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/sort.h"
|
||||
|
||||
|
296
mlx/backend/metal/kernels/steel/attn/attn.h
Normal file
296
mlx/backend/metal/kernels/steel/attn/attn.h
Normal file
@@ -0,0 +1,296 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/attn/loader.h"
|
||||
#include "mlx/backend/metal/kernels/steel/attn/mma.h"
|
||||
#include "mlx/backend/metal/kernels/steel/attn/params.h"
|
||||
#include "mlx/backend/metal/kernels/steel/attn/transforms.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
|
||||
#include "mlx/backend/metal/kernels/steel/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernel class
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace mlx {
|
||||
namespace steel {
|
||||
|
||||
template <bool M_aligned, bool N_aligned, bool K_aligned>
|
||||
struct LoopAlignment {};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
bool MN_aligned,
|
||||
bool K_aligned,
|
||||
typename AccumType = typename AccumHelper<T>::accum_type,
|
||||
typename Epilogue = TransformNone<U, AccumType>>
|
||||
struct GEMMKernel {
|
||||
STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
|
||||
STEEL_CONST short tgp_padding_b = 16 / sizeof(T);
|
||||
STEEL_CONST short tgp_mem_size_a =
|
||||
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
|
||||
STEEL_CONST short tgp_mem_size_b =
|
||||
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
|
||||
STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
|
||||
|
||||
STEEL_CONST short tgp_size = WM * WN * 32;
|
||||
|
||||
using loader_a_t = BlockLoader<
|
||||
T,
|
||||
transpose_a ? BK : BM,
|
||||
transpose_a ? BM : BK,
|
||||
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
|
||||
!transpose_a,
|
||||
tgp_size>;
|
||||
using loader_b_t = BlockLoader<
|
||||
T,
|
||||
transpose_b ? BN : BK,
|
||||
transpose_b ? BK : BN,
|
||||
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
|
||||
transpose_b,
|
||||
tgp_size>;
|
||||
using mma_t = BlockMMA<
|
||||
T,
|
||||
U,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
|
||||
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
|
||||
AccumType,
|
||||
Epilogue>;
|
||||
|
||||
/* Main kernel function */
|
||||
template <bool M_aligned, bool N_aligned, bool K_aligned_>
|
||||
static METAL_FUNC void gemm_loop(
|
||||
threadgroup T* As [[threadgroup(0)]],
|
||||
threadgroup T* Bs [[threadgroup(1)]],
|
||||
const int gemm_k_iterations,
|
||||
thread loader_a_t& loader_a,
|
||||
thread loader_b_t& loader_b,
|
||||
thread mma_t& mma_op,
|
||||
thread const short& tgp_bm,
|
||||
thread const short& tgp_bn,
|
||||
thread const short& lbk,
|
||||
LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
|
||||
// Appease the compiler
|
||||
(void)l;
|
||||
|
||||
short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
|
||||
|
||||
short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
|
||||
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
if (M_aligned) {
|
||||
loader_a.load_unsafe();
|
||||
} else {
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
}
|
||||
|
||||
if (N_aligned) {
|
||||
loader_b.load_unsafe();
|
||||
} else {
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
if (!K_aligned_) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
short2 tile_dims_A_last =
|
||||
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
|
||||
short2 tile_dims_B_last =
|
||||
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
|
||||
|
||||
loader_a.load_safe(tile_dims_A_last);
|
||||
loader_b.load_safe(tile_dims_B_last);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
}
|
||||
|
||||
/* Main kernel function */
|
||||
static METAL_FUNC void run(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
device U* D [[buffer(2)]],
|
||||
const constant GEMMParams* params [[buffer(3)]],
|
||||
threadgroup T* As [[threadgroup(0)]],
|
||||
threadgroup T* Bs [[threadgroup(1)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
// Pacifying compiler
|
||||
(void)lid;
|
||||
|
||||
const int tid_y = ((tid.y) << params->swizzle_log) +
|
||||
((tid.x) & ((1 << params->swizzle_log) - 1));
|
||||
const int tid_x = (tid.x) >> params->swizzle_log;
|
||||
|
||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||
return;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Find block in A, B, C
|
||||
const int c_row = tid_y * BM;
|
||||
const int c_col = tid_x * BN;
|
||||
const size_t c_row_long = size_t(c_row);
|
||||
const size_t c_col_long = size_t(c_col);
|
||||
|
||||
A += transpose_a ? c_row_long : c_row_long * params->lda;
|
||||
B += transpose_b ? c_col_long * params->ldb : c_col_long;
|
||||
D += c_row_long * params->ldd + c_col_long;
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||
|
||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MNK aligned loop
|
||||
if (MN_aligned) {
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Loop tail
|
||||
if (!K_aligned) {
|
||||
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
|
||||
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
|
||||
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
mma_op.store_result(D, params->ldd);
|
||||
return;
|
||||
|
||||
}
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MN unaligned loop
|
||||
else { // Loop over K - unaligned case
|
||||
short tgp_bm = min(BM, params->M - c_row);
|
||||
short tgp_bn = min(BN, params->N - c_col);
|
||||
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||
|
||||
if (tgp_bm == BM && tgp_bn == BN) {
|
||||
gemm_loop<true, true, K_aligned>(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result(D, params->ldd);
|
||||
return;
|
||||
|
||||
} else if (tgp_bn == BN) {
|
||||
gemm_loop<false, true, K_aligned>(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
return;
|
||||
|
||||
} else if (tgp_bm == BM) {
|
||||
gemm_loop<true, false, K_aligned>(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
return;
|
||||
|
||||
} else {
|
||||
gemm_loop<false, false, K_aligned>(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace steel
|
||||
} // namespace mlx
|
349
mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
Normal file
349
mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
Normal file
@@ -0,0 +1,349 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
using namespace mlx::steel;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
constant bool align_Q [[function_constant(200)]];
|
||||
constant bool align_K [[function_constant(201)]];
|
||||
|
||||
template <typename T>
|
||||
struct TransformScale {
|
||||
T scale;
|
||||
METAL_FUNC TransformScale(T scale_) : scale(scale_) {}
|
||||
|
||||
METAL_FUNC T apply(T x) const {
|
||||
return scale * x;
|
||||
}
|
||||
};
|
||||
|
||||
struct MaxOp {
|
||||
template <typename T>
|
||||
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||
return metal::max(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
struct SumOp {
|
||||
template <typename T>
|
||||
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||
return x + y;
|
||||
}
|
||||
};
|
||||
|
||||
struct MulOp {
|
||||
template <typename T>
|
||||
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||
return x * y;
|
||||
}
|
||||
};
|
||||
|
||||
struct SubOp {
|
||||
template <typename T>
|
||||
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||
return x - y;
|
||||
}
|
||||
};
|
||||
|
||||
struct ExpSubOp {
|
||||
template <typename T>
|
||||
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||
return fast::exp(x - y);
|
||||
}
|
||||
};
|
||||
|
||||
struct DivOp {
|
||||
template <typename T>
|
||||
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||
return x / y;
|
||||
}
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
template <
|
||||
typename T,
|
||||
int BQ,
|
||||
int BK,
|
||||
int BD,
|
||||
int WM,
|
||||
int WN,
|
||||
typename AccumType = float>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention(
|
||||
const device T* Q [[buffer(0)]],
|
||||
const device T* K [[buffer(1)]],
|
||||
const device T* V [[buffer(2)]],
|
||||
device T* O [[buffer(3)]],
|
||||
const constant AttnParams* params [[buffer(4)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
|
||||
|
||||
// Pacifying compiler
|
||||
(void)lid;
|
||||
|
||||
// Move to correct block
|
||||
ulong3 tidl{tid.x, tid.y, tid.z};
|
||||
|
||||
Q += tidl.z * params->Q_strides[0] + // Batch
|
||||
tidl.y * params->Q_strides[1] + // Head
|
||||
tidl.x * BQ * params->Q_strides[2]; // Seqeunce
|
||||
|
||||
ulong kv_head_idx = int(tid.y) / params->gqa_factor;
|
||||
K += tidl.z * params->K_strides[0] + // Batch
|
||||
kv_head_idx * params->K_strides[1]; // Head
|
||||
|
||||
V += tidl.z * params->V_strides[0] + // Batch
|
||||
kv_head_idx * params->V_strides[1]; // Head
|
||||
|
||||
O += tidl.z * params->O_strides[0] + // Batch
|
||||
tidl.y * params->O_strides[1] + // Head
|
||||
tidl.x * BQ * params->O_strides[2]; // Seqeunce
|
||||
|
||||
// Prepare threadgroup memory
|
||||
constexpr short padQ = 0; // 16 / sizeof(T);
|
||||
constexpr short padK = 0; // 16 / sizeof(T);
|
||||
constexpr short padV = 0; // 16 / sizeof(T);
|
||||
|
||||
constexpr short LDQ_tgp = BD + padQ;
|
||||
constexpr short LDK_tgp = BK + padK;
|
||||
constexpr short LDV_tgp = BD + padV;
|
||||
|
||||
threadgroup T Qs[BQ * (BD + padQ)];
|
||||
threadgroup T Ks[(BK + padK) * BD];
|
||||
threadgroup T Vs[BK * (BD + padV)];
|
||||
|
||||
// Prepare block loaders
|
||||
using QBlockLoader = BlockLoaderT<
|
||||
/* typename T = */ T,
|
||||
/* short BROWS = */ BQ,
|
||||
/* short BCOLS = */ BD,
|
||||
/* short kDstStrRow = */ LDQ_tgp,
|
||||
/* short kDstStrCol = */ 1,
|
||||
/* short reduction_dim = */ 1,
|
||||
/* short tgp_size = */ WM * WN * 32>;
|
||||
|
||||
// K is loaded in transposed
|
||||
using KBlockLoader = BlockLoaderT<
|
||||
/* typename T = */ T,
|
||||
/* short BROWS = */ BK,
|
||||
/* short BCOLS = */ BD,
|
||||
/* short kDstStrRow = */ 1,
|
||||
/* short kDstStrCol = */ LDK_tgp,
|
||||
/* short reduction_dim = */ 0,
|
||||
/* short tgp_size = */ WM * WN * 32>;
|
||||
|
||||
using VBlockLoader = BlockLoaderT<
|
||||
/* typename T = */ T,
|
||||
/* short BROWS = */ BK,
|
||||
/* short BCOLS = */ BD,
|
||||
/* short kDstStrRow = */ LDV_tgp,
|
||||
/* short kDstStrCol = */ 1,
|
||||
/* short reduction_dim = */ 0,
|
||||
/* short tgp_size = */ WM * WN * 32>;
|
||||
|
||||
QBlockLoader loader_q(
|
||||
Q, params->Q_strides[2], Qs, simd_group_id, simd_lane_id);
|
||||
KBlockLoader loader_k(
|
||||
K, params->K_strides[2], Ks, simd_group_id, simd_lane_id);
|
||||
VBlockLoader loader_v(
|
||||
V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
|
||||
|
||||
TransformScale<T> ts(static_cast<T>(params->scale));
|
||||
|
||||
// Prepare MMA tiles
|
||||
constexpr short kFragSize = 8; // MMAFrag size
|
||||
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
|
||||
|
||||
constexpr int kNWarps = WM * WN;
|
||||
static_assert(
|
||||
BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0,
|
||||
"Each simdgroup must host atleast 1 simdgroup matrix along Q sequence.");
|
||||
|
||||
// Q seq frags per warp
|
||||
constexpr int TQ = BQ / (kNWarps * kFragSize);
|
||||
// KV sequence frags (all warps load the same frags)
|
||||
constexpr int TK = BK / kFragSize;
|
||||
// HeadDim frags (all warps load the same frags)
|
||||
constexpr int TD = BD / kFragSize;
|
||||
|
||||
static_assert(TQ == 1, "Check TQ");
|
||||
|
||||
MMATile<AccumType, TQ, 1, MMAFrag_acc_t> Qtile;
|
||||
MMATile<AccumType, 1, TK, MMAFrag_acc_t> Ktile;
|
||||
MMATile<AccumType, TQ, TK, MMAFrag_acc_t> Stile;
|
||||
MMATile<AccumType, TK, TD, MMAFrag_acc_t> Vtile;
|
||||
MMATile<AccumType, TQ, TD, MMAFrag_acc_t> Otile;
|
||||
|
||||
Otile.clear();
|
||||
|
||||
// Prepare mma tile offsets
|
||||
const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
|
||||
const short sm = simd_coord.y;
|
||||
const short sn = simd_coord.x;
|
||||
const short tm = kFragSize * TQ * simd_group_id;
|
||||
|
||||
const short Qs_offset = (tm + sm) * LDQ_tgp + sn;
|
||||
const short Ks_offset = sm * LDK_tgp + sn;
|
||||
const short Vs_offset = sm * LDV_tgp + sn;
|
||||
|
||||
constexpr short Qs_tile_stride = kFragSize;
|
||||
constexpr short Ks_tile_stride = kFragSize * LDK_tgp;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load Q blocks apply scale
|
||||
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
|
||||
loader_q.load_safe(short2(BD, params->qL - params->NQ_aligned * BQ));
|
||||
} else {
|
||||
loader_q.load_unsafe();
|
||||
}
|
||||
loader_q.apply_inplace_op(ts);
|
||||
|
||||
// Init row reduction variables
|
||||
constexpr short kRowsPT = decltype(Stile)::kRowsPerThread;
|
||||
|
||||
AccumType max_score[kRowsPT];
|
||||
AccumType sum_score[kRowsPT] = {0};
|
||||
|
||||
// Init to -Inf
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kRowsPT; ++i) {
|
||||
max_score[i] = Limits<AccumType>::min;
|
||||
}
|
||||
|
||||
// Loop over KV seq length
|
||||
for (int kb = 0; kb < params->NK; kb++) {
|
||||
// Load K block and apply scale
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (!align_K && kb == (params->NK_aligned)) {
|
||||
loader_k.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
|
||||
} else {
|
||||
loader_k.load_unsafe();
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Do S = Q @ K.T
|
||||
Stile.clear();
|
||||
|
||||
for (short dd = 0; dd < TD; dd++) {
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
Qtile.template load<T, 1, 1, LDQ_tgp, 1>(
|
||||
&Qs[Qs_offset + dd * Qs_tile_stride]);
|
||||
Ktile.template load<T, 1, 1, LDK_tgp, 1>(
|
||||
&Ks[Ks_offset + dd * Ks_tile_stride]);
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
tile_matmad(Stile, Qtile, Ktile, Stile);
|
||||
}
|
||||
|
||||
// Mask out of length sequence
|
||||
if (!align_K && kb == (params->NK_aligned)) {
|
||||
using stile_t = decltype(Stile);
|
||||
using selem_t = typename stile_t::elem_type;
|
||||
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
|
||||
const short lim = params->kL - params->NK_aligned * BK;
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < stile_t::kTileRows; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < stile_t::kTileCols; j++) {
|
||||
short col_pos = sn + (j * stile_t::kFragCols);
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
|
||||
if ((col_pos + jj) >= lim) {
|
||||
Stile.frag_at(i, j)[jj] = neg_inf;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Load V blocks
|
||||
if (!align_K && kb == (params->NK_aligned)) {
|
||||
loader_v.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
|
||||
} else {
|
||||
loader_v.load_unsafe();
|
||||
}
|
||||
|
||||
// Do softmax
|
||||
|
||||
// Temp variables
|
||||
AccumType new_max[kRowsPT];
|
||||
AccumType factor[kRowsPT];
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kRowsPT; ++i) {
|
||||
new_max[i] = max_score[i];
|
||||
}
|
||||
|
||||
// Row max
|
||||
Stile.template row_reduce<MaxOp>(new_max);
|
||||
|
||||
// exp(Si - rowmax(Si))
|
||||
Stile.template row_bin_op<ExpSubOp>(new_max);
|
||||
|
||||
// Factor exp(rowmax(Si) - rowmax(Si-1))
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kRowsPT; ++i) {
|
||||
factor[i] = fast::exp(max_score[i] - new_max[i]);
|
||||
}
|
||||
|
||||
// Save max for next iteration
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kRowsPT; ++i) {
|
||||
max_score[i] = new_max[i];
|
||||
}
|
||||
|
||||
// Row Sum
|
||||
AccumType sum_score_tmp[kRowsPT] = {0};
|
||||
Stile.template row_reduce<SumOp>(sum_score_tmp);
|
||||
|
||||
// Update norm
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kRowsPT; ++i) {
|
||||
sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i];
|
||||
}
|
||||
|
||||
// Update O
|
||||
Otile.template row_bin_op<MulOp>(factor);
|
||||
|
||||
// Load V into registers
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
Vtile.template load<T, 1, 1, LDV_tgp, 1>(&Vs[Vs_offset]);
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Do O = S @ V
|
||||
tile_matmad(Otile, Stile, Vtile, Otile);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_k.next();
|
||||
loader_v.next();
|
||||
}
|
||||
|
||||
// Normalize output
|
||||
Otile.template row_bin_op<DivOp>(sum_score);
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Store results
|
||||
O += (tm + sm) * params->O_strides[2] + sn;
|
||||
|
||||
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
|
||||
auto dst_tile_dims =
|
||||
short2(BD - sn, params->qL - BQ * params->NQ_aligned - (tm + sm));
|
||||
|
||||
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
|
||||
return;
|
||||
|
||||
Otile.template store_safe<T, 1, 1>(O, params->O_strides[2], dst_tile_dims);
|
||||
} else {
|
||||
Otile.template store<T, 1, 1>(O, params->O_strides[2]);
|
||||
}
|
||||
}
|
@@ -0,0 +1,31 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/attn/attn.h"
|
||||
#include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h"
|
||||
|
||||
#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn) \
|
||||
template [[host_name("steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd "_wm" #wm "_wn" #wn)]] \
|
||||
[[kernel]] void attention<dtype, bq, bk, bd, wm, wn, float>( \
|
||||
const device dtype* Q [[buffer(0)]], \
|
||||
const device dtype* K [[buffer(1)]], \
|
||||
const device dtype* V [[buffer(2)]], \
|
||||
device dtype* O [[buffer(3)]],\
|
||||
const constant AttnParams* params [[buffer(4)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
#define instantiate_attn_shapes_helper(iname, itype) \
|
||||
instantiate_attn(iname, itype, 32, 16, 128, 4, 1) \
|
||||
instantiate_attn(iname, itype, 32, 32, 80, 4, 1) \
|
||||
instantiate_attn(iname, itype, 32, 32, 64, 4, 1)
|
||||
|
||||
instantiate_attn_shapes_helper(float16, half);
|
||||
instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
|
||||
|
||||
instantiate_attn_shapes_helper(float32, float);
|
||||
// clang-format on
|
264
mlx/backend/metal/kernels/steel/attn/loader.h
Normal file
264
mlx/backend/metal/kernels/steel/attn/loader.h
Normal file
@@ -0,0 +1,264 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/defines.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Loading helper
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace mlx {
|
||||
namespace steel {
|
||||
|
||||
template <
|
||||
typename T,
|
||||
short BROWS,
|
||||
short BCOLS,
|
||||
short dst_ld,
|
||||
short reduction_dim,
|
||||
short tgp_size,
|
||||
short alignment = 1,
|
||||
short n_reads = (BCOLS * BROWS) / (tgp_size),
|
||||
short TCOLS = BCOLS / n_reads,
|
||||
short TROWS = tgp_size / TCOLS>
|
||||
struct BlockLoader {
|
||||
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
|
||||
STEEL_CONST short vec_size = n_reads;
|
||||
|
||||
// Leading dimension for src
|
||||
const int src_ld;
|
||||
const int tile_stride;
|
||||
|
||||
// Thread location indices
|
||||
const short thread_idx;
|
||||
const short bi;
|
||||
const short bj;
|
||||
|
||||
// threadgroup and device memory
|
||||
threadgroup T* dst;
|
||||
const device T* src;
|
||||
|
||||
struct alignas(alignment * sizeof(T)) ReadVector {
|
||||
uint8_t v[sizeof(T) * vec_size];
|
||||
};
|
||||
|
||||
/* Constructor */
|
||||
METAL_FUNC BlockLoader(
|
||||
const device T* src_,
|
||||
const int src_ld_,
|
||||
threadgroup T* dst_,
|
||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(src_ld_),
|
||||
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
|
||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
bi(thread_idx / TCOLS),
|
||||
bj(vec_size * (thread_idx % TCOLS)),
|
||||
dst(dst_ + bi * dst_ld + bj),
|
||||
src(src_ + bi * src_ld + bj) {}
|
||||
|
||||
/* Apply operation to threadgroup without bound checking */
|
||||
template <typename UnaryOp>
|
||||
METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Load from device memory into threadgroup memory - without bound checking */
|
||||
METAL_FUNC void load_unsafe() const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
*((threadgroup ReadVector*)(&dst[i * dst_ld])) =
|
||||
*((const device ReadVector*)(&src[i * src_ld]));
|
||||
}
|
||||
}
|
||||
|
||||
/* Load from device memory into threadgroup memory - with bound checking */
|
||||
METAL_FUNC void load_safe(short2 src_tile_dim) const {
|
||||
src_tile_dim = src_tile_dim - short2(bj, bi);
|
||||
|
||||
// Skip loading if thread has no valid reads
|
||||
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = T(0);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Use fast thread memory for bound checks
|
||||
bool tmp_idx[vec_size];
|
||||
T tmp_val[vec_size];
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
// Make sure tmp_idx only contains valid indices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
|
||||
}
|
||||
|
||||
// Read valid indices into tmp_val
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
|
||||
}
|
||||
|
||||
// Zero out unneeded values
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
|
||||
}
|
||||
|
||||
// Copy values to threadgroup memory
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = tmp_val[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Iteration helper */
|
||||
METAL_FUNC void next() {
|
||||
src += tile_stride;
|
||||
}
|
||||
};
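// Illustrative usage sketch (assumptions, not part of this diff): a BlockLoader
// is typically constructed once per threadgroup tile and advanced along the
// reduction dimension each iteration. Names such as BM, BK, kTgpSize, A, lda
// and As below are placeholders for this example.
//
//   threadgroup T As[BM * BK];
//   BlockLoader<T, BM, BK, BK, /*reduction_dim=*/1, kTgpSize> loader_a(
//       A, lda, As, simd_group_id, simd_lane_id);
//   for (int k = 0; k < K; k += BK) {
//     loader_a.load_unsafe();   // stage one (BM x BK) tile into As
//     threadgroup_barrier(mem_flags::mem_threadgroup);
//     // ... consume As ...
//     loader_a.next();          // advance src by tile_stride along K
//   }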
|
||||
|
||||
template <int R, int C>
|
||||
struct CShape {
|
||||
STEEL_CONST int kRows = R;
|
||||
STEEL_CONST int kCols = C;
|
||||
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
short BROWS,
|
||||
short BCOLS,
|
||||
short kDstStrRow,
|
||||
short kDstStrCol,
|
||||
short reduction_dim,
|
||||
short tgp_size,
|
||||
short n_reads = (BCOLS * BROWS) / (tgp_size),
|
||||
short TCOLS = BCOLS / n_reads,
|
||||
short TROWS = tgp_size / TCOLS>
|
||||
struct BlockLoaderT {
|
||||
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
|
||||
STEEL_CONST short vec_size = n_reads;
|
||||
|
||||
// Leading dimension for src
|
||||
const int src_ld;
|
||||
const int tile_stride;
|
||||
|
||||
// Thread location indices
|
||||
const short thread_idx;
|
||||
const short bi;
|
||||
const short bj;
|
||||
|
||||
// threadgroup and device memory
|
||||
threadgroup T* dst;
|
||||
const device T* src;
|
||||
|
||||
/* Constructor */
|
||||
METAL_FUNC BlockLoaderT(
|
||||
const device T* src_,
|
||||
const int src_ld_,
|
||||
threadgroup T* dst_,
|
||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(src_ld_),
|
||||
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
|
||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
bi(thread_idx / TCOLS),
|
||||
bj(vec_size * (thread_idx % TCOLS)),
|
||||
dst(dst_ + bi * kDstStrRow + bj * kDstStrCol),
|
||||
src(src_ + bi * src_ld + bj) {}
|
||||
|
||||
/* Apply operation to threadgroup without bound checking */
|
||||
template <typename UnaryOp>
|
||||
METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * kDstStrRow + j * kDstStrCol] =
|
||||
op.apply(dst[i * kDstStrRow + j * kDstStrCol]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Load from device memory into threadgroup memory - without bound checking */
|
||||
METAL_FUNC void load_unsafe() const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * kDstStrRow + j * kDstStrCol] = src[i * src_ld + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Load from device memory into threadgroup memory - with bound checking */
|
||||
METAL_FUNC void load_safe(short2 src_tile_dim) const {
|
||||
src_tile_dim = src_tile_dim - short2(bj, bi);
|
||||
|
||||
// Skip loading if thread has no valid reads
|
||||
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * kDstStrRow + j * kDstStrCol] = T(0);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Use fast thread memory for bound checks
|
||||
bool tmp_idx[vec_size];
|
||||
T tmp_val[vec_size];
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
// Make sure tmp_idx only contains valid indices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
|
||||
}
|
||||
|
||||
// Read valid indices into tmp_val
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
|
||||
}
|
||||
|
||||
// Zero out unneeded values
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
|
||||
}
|
||||
|
||||
// Copy values to threadgroup memory
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * kDstStrRow + j * kDstStrCol] = tmp_val[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Iteration helper */
|
||||
METAL_FUNC void next() {
|
||||
src += tile_stride;
|
||||
}
|
||||
};
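// Note: BlockLoaderT differs from BlockLoader above only in how it addresses
// threadgroup memory. The destination row/column strides (kDstStrRow,
// kDstStrCol) are explicit template parameters, so the staged tile can be
// written out transposed or with padding, and the vectorized ReadVector path
// is replaced by per-element stores.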
|
||||
|
||||
} // namespace steel
|
||||
} // namespace mlx

mlx/backend/metal/kernels/steel/attn/mma.h (new file, 726 lines)
@@ -0,0 +1,726 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <metal_simdgroup>
|
||||
#include <metal_simdgroup_matrix>
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/attn/transforms.h"
|
||||
#include "mlx/backend/metal/kernels/steel/defines.h"
|
||||
#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MMA helper
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace mlx {
|
||||
namespace steel {
|
||||
|
||||
template <typename RInt, typename CInt>
|
||||
struct Shape2D {
|
||||
RInt r;
|
||||
CInt c;
|
||||
|
||||
Shape2D(RInt r_, CInt c_) : r(r_), c(c_) {}
|
||||
};
|
||||
|
||||
template <typename Shape, typename Layout>
|
||||
struct Layout2D {
|
||||
Shape shape;
|
||||
Layout layout;
|
||||
};
|
||||
|
||||
template <typename T, int kFragRows_, int kFragCols_>
|
||||
struct BaseMMAFrag {
|
||||
static_assert(
|
||||
kFragRows_ == 8,
|
||||
"Only 8 x 8 fragment matrices are currently supported");
|
||||
static_assert(
|
||||
kFragCols_ == 8,
|
||||
"Only 8 x 8 fragment matrices are currently supported");
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct BaseMMAFrag<T, 8, 8> {
|
||||
STEEL_CONST int kFragRows = 8;
|
||||
STEEL_CONST int kFragCols = 8;
|
||||
|
||||
STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32;
|
||||
|
||||
STEEL_CONST int kElemRows = 1;
|
||||
STEEL_CONST int kElemCols = 2;
|
||||
|
||||
static_assert(
|
||||
kElemRows * kElemCols == kElemsPerFrag,
|
||||
"MMAFrag shape is not consistent with MMAFrag size");
|
||||
|
||||
typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
|
||||
typedef metal::vec<T, kElemsPerFrag> frag_type;
|
||||
typedef metal::vec<T, kElemRows> row_frag_type;
|
||||
typedef metal::vec<T, kElemCols> col_frag_type;
|
||||
|
||||
METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
|
||||
[[thread_index_in_simdgroup]]) {
|
||||
const short qid = simd_lane_id / 4;
|
||||
const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
|
||||
const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
|
||||
return short2{fn, fm};
|
||||
}
|
||||
|
||||
template <typename SrcPtrType, typename StrX, typename StrY>
|
||||
METAL_FUNC static constexpr void
|
||||
load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kElemRows; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kElemCols; j++) {
|
||||
dst[i * kElemCols + j] = static_cast<T>(src[i * str_x + j * str_y]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename SrcPtrType,
|
||||
typename StrX,
|
||||
typename StrY,
|
||||
typename LimX,
|
||||
typename LimY,
|
||||
typename OffX,
|
||||
typename OffY>
|
||||
METAL_FUNC static constexpr void load_safe(
|
||||
thread frag_type& dst,
|
||||
SrcPtrType src,
|
||||
StrX str_x,
|
||||
StrY str_y,
|
||||
LimX lim_x,
|
||||
LimY lim_y,
|
||||
OffX off_x = Int<0>{},
|
||||
OffY off_y = Int<0>{}) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kElemRows; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kElemCols; j++) {
|
||||
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
|
||||
dst[i * kElemCols + j] =
|
||||
static_cast<T>(src[(off_x + i) * str_x + (off_y + j) * str_y]);
|
||||
} else {
|
||||
dst[i * kElemCols + j] = T(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DstPtrType, typename StrX, typename StrY>
|
||||
METAL_FUNC static constexpr void
|
||||
store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) {
|
||||
using U = pointer_element_t<DstPtrType>;
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kElemRows; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kElemCols; j++) {
|
||||
dst[i * str_x + j * str_y] = static_cast<U>(src[i * kElemCols + j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename DstPtrType,
|
||||
typename StrX,
|
||||
typename StrY,
|
||||
typename LimX,
|
||||
typename LimY,
|
||||
typename OffX,
|
||||
typename OffY>
|
||||
METAL_FUNC static constexpr void store_safe(
|
||||
const thread frag_type& src,
|
||||
DstPtrType dst,
|
||||
StrX str_x,
|
||||
StrY str_y,
|
||||
LimX lim_x,
|
||||
LimY lim_y,
|
||||
OffX off_x = Int<0>{},
|
||||
OffY off_y = Int<0>{}) {
|
||||
using U = pointer_element_t<DstPtrType>;
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kElemRows; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kElemCols; j++) {
|
||||
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
|
||||
dst[(off_x + i) * str_x + (off_y + j) * str_y] =
|
||||
static_cast<U>(src[i * kElemCols + j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC static constexpr void mma(
|
||||
thread frag_type& D,
|
||||
thread frag_type& A,
|
||||
thread frag_type& B,
|
||||
thread frag_type& C) {
|
||||
mat_type D_mat;
|
||||
mat_type A_mat;
|
||||
mat_type B_mat;
|
||||
mat_type C_mat;
|
||||
|
||||
reinterpret_cast<thread frag_type&>(A_mat.thread_elements()) = A;
|
||||
reinterpret_cast<thread frag_type&>(B_mat.thread_elements()) = B;
|
||||
reinterpret_cast<thread frag_type&>(C_mat.thread_elements()) = C;
|
||||
|
||||
mma(D_mat, A_mat, B_mat, C_mat);
|
||||
|
||||
D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());
|
||||
}
|
||||
|
||||
METAL_FUNC static constexpr void mma(
|
||||
thread mat_type& D,
|
||||
thread mat_type& A,
|
||||
thread mat_type& B,
|
||||
thread mat_type& C) {
|
||||
simdgroup_multiply_accumulate(D, A, B, C);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
METAL_FUNC static constexpr void row_reduce(
|
||||
thread const frag_type& inp_vals,
|
||||
thread T* reduced_vals) {
|
||||
T thr_reduce = Op::apply(inp_vals.x, inp_vals.y);
|
||||
|
||||
T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1));
|
||||
qgr_reduce = Op::apply(thr_reduce, qgr_reduce);
|
||||
|
||||
T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8));
|
||||
sgr_reduce = Op::apply(qgr_reduce, sgr_reduce);
|
||||
|
||||
reduced_vals[0] = Op::apply(reduced_vals[0], sgr_reduce);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
METAL_FUNC static constexpr void row_bin_op(
|
||||
thread frag_type& inp_vals,
|
||||
thread T* row_vals) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kElemRows; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kElemCols; j++) {
|
||||
inp_vals[i * kElemCols + j] =
|
||||
Op::apply(inp_vals[i * kElemCols + j], row_vals[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
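// Layout note: an 8x8 simdgroup fragment holds 64 values across 32 threads, so
// kElemsPerFrag == 2 and each thread owns a 1x2 slice of the tile. get_coord()
// returns that slice's (column, row) origin for a given lane, which is assumed
// here to match the hardware's simdgroup_matrix element-to-thread mapping.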
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int kTileRows_,
|
||||
int kTileCols_,
|
||||
class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
|
||||
struct MMATile {
|
||||
using MMAFrag_t = MMAFrag_;
|
||||
using elem_type = T;
|
||||
STEEL_CONST int kFragRows = MMAFrag_t::kFragRows;
|
||||
STEEL_CONST int kFragCols = MMAFrag_t::kFragCols;
|
||||
STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag;
|
||||
|
||||
STEEL_CONST int kTileRows = kTileRows_;
|
||||
STEEL_CONST int kTileCols = kTileCols_;
|
||||
|
||||
STEEL_CONST int kRows = kTileRows * kFragRows;
|
||||
STEEL_CONST int kCols = kTileCols * kFragCols;
|
||||
|
||||
STEEL_CONST int kNumFrags = kTileRows * kTileCols;
|
||||
STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;
|
||||
|
||||
STEEL_CONST int kRowsPerThread = kTileRows * MMAFrag_t::kElemRows;
|
||||
STEEL_CONST int kColsPerThread = kTileCols * MMAFrag_t::kElemCols;
|
||||
|
||||
typedef typename MMAFrag_t::mat_type mat_type;
|
||||
typedef typename MMAFrag_t::frag_type frag_type;
|
||||
|
||||
frag_type val_frags[kNumFrags] = {frag_type(0)};
|
||||
|
||||
METAL_FUNC MMATile() thread {}
|
||||
|
||||
METAL_FUNC constexpr void clear() {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kNumFrags; ++i) {
|
||||
val_frags[i] = frag_type(0);
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {
|
||||
return val_frags[i * kTileCols + j];
|
||||
}
|
||||
|
||||
METAL_FUNC constexpr const thread frag_type& frag_at(
|
||||
const short i,
|
||||
const short j) const {
|
||||
return val_frags[i * kTileCols + j];
|
||||
}
|
||||
|
||||
METAL_FUNC mat_type mat_at(const short i, const short j) {
|
||||
mat_type val_mat;
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short ii = 0; ii < kElemsPerFrag; ++ii) {
|
||||
val_mat.thread_elements()[ii] = frag_at(i, j)[ii];
|
||||
}
|
||||
return val_mat;
|
||||
}
|
||||
|
||||
METAL_FUNC thread elem_type* elems() {
|
||||
return reinterpret_cast<thread elem_type*>(val_frags);
|
||||
}
|
||||
|
||||
METAL_FUNC const thread elem_type* elems() const {
|
||||
return reinterpret_cast<const thread elem_type*>(val_frags);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::template row_reduce<Op>(
|
||||
frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread]) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::template row_bin_op<Op>(
|
||||
frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int w_x, int w_y, int str_x, int str_y>
|
||||
METAL_FUNC void load(const threadgroup U* src) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::load(
|
||||
frag_at(i, j),
|
||||
&(
|
||||
src[(i * kFragRows) * w_x * str_x +
|
||||
(j * kFragCols) * w_y * str_y]),
|
||||
Int<str_x>{},
|
||||
Int<str_y>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int w_x, int w_y, int str_x, int str_y>
|
||||
METAL_FUNC void store(threadgroup U* dst) const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::store(
|
||||
frag_at(i, j),
|
||||
&(
|
||||
dst[(i * kFragRows) * w_x * str_x +
|
||||
(j * kFragCols) * w_y * str_y]),
|
||||
Int<str_x>{},
|
||||
Int<str_y>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int w_x, int w_y>
|
||||
METAL_FUNC void load(const device U* src, const int ld) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::load(
|
||||
frag_at(i, j),
|
||||
&(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
|
||||
ld,
|
||||
Int<1>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int w_x, int w_y>
|
||||
METAL_FUNC void store(device U* dst, const int ld) const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::store(
|
||||
frag_at(i, j),
|
||||
&(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
|
||||
ld,
|
||||
Int<1>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int w_x, int w_y>
|
||||
METAL_FUNC void
|
||||
load_safe(const device U* src, const int ld, const short2 src_tile_dims) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::load_safe(
|
||||
frag_at(i, j),
|
||||
src,
|
||||
ld,
|
||||
Int<1>{},
|
||||
src_tile_dims.y,
|
||||
src_tile_dims.x,
|
||||
(i * kFragRows) * w_x,
|
||||
(j * kFragCols) * w_y);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int w_x, int w_y>
|
||||
METAL_FUNC void
|
||||
store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::store_safe(
|
||||
frag_at(i, j),
|
||||
dst,
|
||||
ld,
|
||||
Int<1>{},
|
||||
dst_tile_dims.y,
|
||||
dst_tile_dims.x,
|
||||
(i * kFragRows) * w_x,
|
||||
(j * kFragCols) * w_y);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
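// Illustrative sketch (assumed usage, not part of this diff): an MMATile is a
// register-resident grid of 8x8 fragments. A 2x2 tile of float accumulators
// covering a 16x16 block could be used as:
//
//   MMATile<float, 2, 2> Ctile;
//   Ctile.clear();
//   // ... accumulate with tile_matmad(Ctile, Atile, Btile, Ctile) ...
//   Ctile.template store<float, 1, 1>(D, ldd);  // D: device pointer, ldd: leading dim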
|
||||
|
||||
template <typename T, typename U, int M, int N, int K>
|
||||
METAL_FUNC void tile_matmad(
|
||||
thread MMATile<T, M, N>& D,
|
||||
thread MMATile<U, M, K>& A,
|
||||
thread MMATile<U, K, N>& B,
|
||||
thread MMATile<T, M, N>& C) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short k = 0; k < K; ++k) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short m = 0; m < M; ++m) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short n = 0; n < N; ++n) {
|
||||
short n_serp = (m % 2) ? (N - 1 - n) : n;
|
||||
MMATile<T, M, N>::MMAFrag_t::mma(
|
||||
D.frag_at(m, n_serp),
|
||||
A.frag_at(m, k),
|
||||
B.frag_at(k, n_serp),
|
||||
C.frag_at(m, n_serp));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
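// Note: the (m % 2) serpentine traversal over n means the fragment visited
// last for one row is visited first for the next, which presumably helps keep
// B and accumulator fragments live in registers; the arithmetic result is
// unchanged.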
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
short lda_tgp,
|
||||
short ldb_tgp,
|
||||
typename AccumType = float,
|
||||
typename Epilogue = TransformNone<U, AccumType>>
|
||||
struct BlockMMA {
|
||||
// MMAFrag size
|
||||
STEEL_CONST short kFragSize = 8;
|
||||
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
|
||||
|
||||
// Warp tile simdgroup matrix strides along M
|
||||
STEEL_CONST short TM_stride = kFragSize * WM;
|
||||
// Warp tile simdgroup matrix strides along M
|
||||
STEEL_CONST short TN_stride = kFragSize * WN;
|
||||
|
||||
// Warp tile size along M
|
||||
STEEL_CONST short TM = BM / TM_stride;
|
||||
// Warp tile size along N
|
||||
STEEL_CONST short TN = BN / TN_stride;
|
||||
|
||||
// Threadgroup A strides
|
||||
STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
|
||||
STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K
|
||||
|
||||
// Threadgroup B strides
|
||||
STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
|
||||
STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N
|
||||
|
||||
// Threadgroup strides along K
|
||||
STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
|
||||
STEEL_CONST short tile_stride_b = kFragSize * B_str_k;
|
||||
|
||||
// Simdgroup matrices
|
||||
MMATile<AccumType, TM, 1, MMAFrag_acc_t> Atile;
|
||||
MMATile<AccumType, 1, TN, MMAFrag_acc_t> Btile;
|
||||
MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile;
|
||||
|
||||
// Offsets within threadgroup
|
||||
short sm;
|
||||
short sn;
|
||||
|
||||
short As_offset;
|
||||
short Bs_offset;
|
||||
|
||||
/* Constructor */
|
||||
METAL_FUNC BlockMMA(
|
||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
ushort simd_lane_id [[thread_index_in_simdgroup]]) {
|
||||
// Determine thread position in simdgroup matrix
|
||||
short tm = kFragSize * (simd_group_id / WN);
|
||||
short tn = kFragSize * (simd_group_id % WN);
|
||||
|
||||
short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
|
||||
sm = simd_coord.y;
|
||||
sn = simd_coord.x;
|
||||
|
||||
// Determine thread and simdgroup offset
|
||||
As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K
|
||||
Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N
|
||||
|
||||
sm += tm;
|
||||
sn += tn;
|
||||
}
|
||||
|
||||
/* (BM, BK) X (BK, BN) multiply accumulate function */
|
||||
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
|
||||
// Adjust for simdgroup and thread location
|
||||
As += As_offset;
|
||||
Bs += Bs_offset;
|
||||
|
||||
// Iterate over BK in blocks of kFragSize
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short kk = 0; kk < BK; kk += kFragSize) {
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
tile_matmad(Ctile, Atile, Btile, Ctile);
|
||||
|
||||
// Progress to next simdgroup tile
|
||||
As += tile_stride_a;
|
||||
Bs += tile_stride_b;
|
||||
}
|
||||
}
|
||||
|
||||
/* Store results from simdgroup_matrix results into device memory */
|
||||
METAL_FUNC void store_result(device U* D, const int ldd) {
|
||||
// Apply epilogue
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
|
||||
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
|
||||
}
|
||||
|
||||
// Adjust for simdgroup and thread location
|
||||
D += sm * ldd + sn;
|
||||
|
||||
Ctile.template store<U, WM, WN>(D, ldd);
|
||||
}
|
||||
|
||||
METAL_FUNC void
|
||||
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
|
||||
// Apply epilogue
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
|
||||
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
|
||||
}
|
||||
|
||||
// Adjust for simdgroup and thread location
|
||||
D += sm * ldd + sn;
|
||||
dst_tile_dims -= short2(sn, sm);
|
||||
|
||||
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
|
||||
return;
|
||||
|
||||
Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims);
|
||||
}
|
||||
|
||||
/* Apply epilogue */
|
||||
template <typename UnaryEpilogue>
|
||||
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
|
||||
// Loop over all simdgroup tiles
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
|
||||
Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]);
|
||||
}
|
||||
}
|
||||
|
||||
/* Apply epilogue */
|
||||
template <typename BinaryEpilogue>
|
||||
METAL_FUNC void apply_epilogue(
|
||||
const device U* C,
|
||||
const int ldc,
|
||||
const int fdc,
|
||||
thread const BinaryEpilogue& epilogue_op) {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm)*ldc + (sn)*fdc;
|
||||
|
||||
// Loop over all simdgroup tiles
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread auto& accum = Ctile.frag_at(i, j);
|
||||
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
||||
|
||||
// Apply epilogue
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) {
|
||||
accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Apply epilogue */
|
||||
template <typename BinaryEpilogue>
|
||||
METAL_FUNC void apply_epilogue_safe(
|
||||
const device U* C,
|
||||
const int ldc,
|
||||
const int fdc,
|
||||
short2 dst_tile_dims,
|
||||
thread const BinaryEpilogue& epilogue_op) {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm)*ldc + (sn)*fdc;
|
||||
dst_tile_dims -= short2(sn, sm);
|
||||
|
||||
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
|
||||
return;
|
||||
|
||||
// Loop over all simdgroup tiles
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread auto& accum = Ctile.frag_at(i, j);
|
||||
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
||||
|
||||
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
|
||||
|
||||
// Read C
|
||||
U c_elems[kelems] = {0};
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short k = 0; k < kelems; k++) {
|
||||
if ((j * TN_stride + k) < dst_tile_dims.x) {
|
||||
c_elems[k] = C[offset_c + k * fdc];
|
||||
}
|
||||
}
|
||||
|
||||
// Apply epilogue
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short k = 0; k < kelems; k++) {
|
||||
accum[k] = epilogue_op.apply(accum[k], c_elems[k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Store results from simdgroup_matrix results into device memory */
|
||||
METAL_FUNC void store_result(
|
||||
device U* D,
|
||||
const int ldd,
|
||||
const device U* C,
|
||||
const int ldc,
|
||||
const int fdc,
|
||||
thread const Epilogue& epilogue_op) const {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm)*ldc + (sn)*fdc;
|
||||
D += (sm)*ldd + sn;
|
||||
|
||||
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
|
||||
|
||||
// Loop over all simdgroup tiles
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum = Ctile.frag_at(i, j);
|
||||
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
||||
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
|
||||
|
||||
// Apply epilogue
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short k = 0; k < kelems; k++) {
|
||||
D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void store_result_safe(
|
||||
device U* D,
|
||||
const int ldd,
|
||||
const device U* C,
|
||||
const int ldc,
|
||||
const int fdc,
|
||||
short2 dst_tile_dims,
|
||||
thread const Epilogue& epilogue_op) const {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm)*ldc + (sn)*fdc;
|
||||
D += (sm)*ldd + sn;
|
||||
dst_tile_dims -= short2(sn, sm);
|
||||
|
||||
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
|
||||
return;
|
||||
|
||||
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int i = 0; i < TM; i++) {
|
||||
if (i * TM_stride < dst_tile_dims.y) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum = Ctile.frag_at(i, j);
|
||||
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
||||
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
|
||||
|
||||
// Apply epilogue
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short k = 0; k < kelems; k++) {
|
||||
if ((j * TN_stride + k) < dst_tile_dims.x) {
|
||||
D[offset_d + k] =
|
||||
epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
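// Illustrative mainloop sketch (assumptions, not part of this diff): BlockMMA
// consumes threadgroup tiles As (BM x BK) and Bs (BK x BN) staged by the
// loaders above and accumulates into Ctile. Placeholder names: As, Bs, D, ldd,
// gemm_k_iterations.
//
//   BlockMMA<T, U, BM, BN, BK, WM, WN, false, false, BK, BN> mma_op(
//       simd_group_id, simd_lane_id);
//   for (int k = 0; k < gemm_k_iterations; k++) {
//     // ... loaders stage the next As / Bs tiles ...
//     threadgroup_barrier(mem_flags::mem_threadgroup);
//     mma_op.mma(As, Bs);        // accumulate into Ctile
//     threadgroup_barrier(mem_flags::mem_threadgroup);
//   }
//   mma_op.store_result(D, ldd);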
|
||||
|
||||
} // namespace steel
|
||||
} // namespace mlx

mlx/backend/metal/kernels/steel/attn/params.h (new file, 36 lines)
@@ -0,0 +1,36 @@
// Copyright © 2024 Apple Inc.

#pragma once

///////////////////////////////////////////////////////////////////////////////
// Attn param classes
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

struct AttnParams {
  int B; ///< Batch Size
  int H; ///< Heads
  int D; ///< Head Dim

  int qL; ///< Query Sequence Length
  int kL; ///< Key Sequence Length

  int gqa_factor; ///< Group Query factor
  float scale; ///< Attention scale

  int NQ; ///< Number of query blocks
  int NK; ///< Number of key/value blocks

  int NQ_aligned; ///< Number of full query blocks
  int NK_aligned; ///< Number of full key/value blocks

  size_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
  size_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
  size_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
  size_t O_strides[3]; ///< Output strides (B, H, L, D = 1)
};
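// Assumed relationships (not stated in this diff): NQ and NK are the query and
// key/value sequence lengths divided, rounded up, by the kernel's query/key
// block sizes, with NQ_aligned and NK_aligned counting only the blocks that
// need no bounds checking. The stride arrays carry the B, H and L strides; the
// innermost (D) stride is taken to be 1, as the member comments note.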

} // namespace steel
} // namespace mlx

mlx/backend/metal/kernels/steel/attn/transforms.h (new file, 71 lines)
@@ -0,0 +1,71 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/steel/utils.h"

///////////////////////////////////////////////////////////////////////////////
// Transforms and Epilogues
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

template <typename OutT, typename InT>
struct TransformNone {
  static METAL_FUNC OutT apply(InT x) {
    return static_cast<OutT>(x);
  }

  static METAL_FUNC OutT apply(InT x, OutT) {
    return static_cast<OutT>(x);
  }
};

template <typename OutT, typename InT>
struct TransformAdd {
  TransformAdd(const float, const float) {}

  static METAL_FUNC OutT apply(InT x) {
    return static_cast<OutT>(x);
  }

  static METAL_FUNC OutT apply(InT x, OutT c) {
    return static_cast<OutT>(x) + c;
  }
};

template <typename OutT, typename InT>
struct TransformAxpby {
  const float alpha;
  const float beta;

  TransformAxpby(const float alpha_, const float beta_)
      : alpha(alpha_), beta(beta_) {}

  static METAL_FUNC OutT apply(InT x) {
    return static_cast<OutT>(x);
  }

  METAL_FUNC OutT apply(InT x, OutT c) const {
    return static_cast<OutT>(x * alpha + (beta * c));
  }
};
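// Note: TransformAxpby is the alpha/beta epilogue: apply(x, c) returns
// alpha * x + beta * c, so a GEMM that uses it as Epilogue produces
// D = alpha * (A * B) + beta * C (standard axpby/addmm behaviour; assumed
// usage).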

template <typename T>
struct AccumHelper {
  typedef float accum_type;
};

struct BlockSwizzle {
  static METAL_FUNC int2
  swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
    const int tid_x = (tid.x) >> swizzle_log;
    const int tid_y =
        ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));
    return int2(tid_x, tid_y);
  }
};
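// Worked example: with swizzle_log = 1 and tid = (5, 2, 0), tid_x = 5 >> 1 = 2
// and tid_y = (2 << 1) + (5 & 1) = 5, so neighbouring threadgroups along x are
// spread across rows (assumed rationale: better cache reuse of shared tiles).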

} // namespace steel
} // namespace mlx
@@ -5,7 +5,7 @@
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
|
||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||
#include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h"
|
||||
|
@@ -142,8 +142,8 @@ implicit_gemm_conv_2d_general(
|
||||
// Store results to device memory
|
||||
{
|
||||
// Adjust for simdgroup and thread location
|
||||
int offset_m = c_row + mma_op.sm + mma_op.tm;
|
||||
int offset_n = c_col + mma_op.sn + mma_op.tn;
|
||||
int offset_m = c_row + mma_op.sm;
|
||||
int offset_n = c_col + mma_op.sn;
|
||||
C += offset_n;
|
||||
|
||||
if (offset_n >= gemm_params->N)
|
||||
@@ -169,17 +169,17 @@ implicit_gemm_conv_2d_general(
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < mma_t::TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum =
|
||||
mma_op.results[i * mma_t::TN + j].thread_elements();
|
||||
thread const auto& accum = mma_op.Ctile.frag_at(i, j);
|
||||
int offset = offset_cm + (j * mma_t::TN_stride);
|
||||
|
||||
// Apply epilogue and output C
|
||||
if (j * mma_t::TN_stride < diff) {
|
||||
C[offset] = Epilogue::apply(accum[0]);
|
||||
}
|
||||
constexpr short kelems = decltype(mma_op.Ctile)::kElemsPerFrag;
|
||||
|
||||
if (j * mma_t::TN_stride + 1 < diff) {
|
||||
C[offset + 1] = Epilogue::apply(accum[1]);
|
||||
// Apply epilogue and output C
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short k = 0; k < kelems; k++) {
|
||||
if ((j * mma_t::TN_stride + k) < diff) {
|
||||
C[offset + k] = Epilogue::apply(accum[k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -5,7 +5,7 @@
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
|
||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||
#include "mlx/backend/metal/kernels/steel/utils.h"
|
||||
|
@@ -1,7 +1,6 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||
@@ -36,11 +35,11 @@
|
||||
instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
||||
|
||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
|
||||
|
||||
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
||||
|
@@ -1,7 +1,6 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h"
|
||||
|
@@ -1,7 +1,7 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h"
|
||||
|
||||
|
@@ -8,6 +8,7 @@
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/defines.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
|
||||
#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
@@ -18,6 +19,347 @@ using namespace metal;
|
||||
namespace mlx {
|
||||
namespace steel {
|
||||
|
||||
template <typename T, int kFragRows_, int kFragCols_>
|
||||
struct BaseMMAFrag {
|
||||
static_assert(
|
||||
kFragRows_ == 8,
|
||||
"Only 8 x 8 fragment matrices are currently supported");
|
||||
static_assert(
|
||||
kFragCols_ == 8,
|
||||
"Only 8 x 8 fragment matrices are currently supported");
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct BaseMMAFrag<T, 8, 8> {
|
||||
STEEL_CONST int kFragRows = 8;
|
||||
STEEL_CONST int kFragCols = 8;
|
||||
|
||||
STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32;
|
||||
|
||||
STEEL_CONST int kElemRows = 1;
|
||||
STEEL_CONST int kElemCols = 2;
|
||||
|
||||
static_assert(
|
||||
kElemRows * kElemCols == kElemsPerFrag,
|
||||
"MMAFrag shape is not consistent with MMAFrag size");
|
||||
|
||||
typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
|
||||
typedef metal::vec<T, kElemsPerFrag> frag_type;
|
||||
|
||||
METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
|
||||
[[thread_index_in_simdgroup]]) {
|
||||
const short qid = simd_lane_id / 4;
|
||||
const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
|
||||
const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
|
||||
return short2{fn, fm};
|
||||
}
|
||||
|
||||
template <typename SrcPtrType, typename StrX, typename StrY>
|
||||
METAL_FUNC static constexpr void
|
||||
load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kElemRows; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kElemCols; j++) {
|
||||
dst[i * kElemCols + j] = static_cast<T>(src[i * str_x + j * str_y]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename SrcPtrType,
|
||||
typename StrX,
|
||||
typename StrY,
|
||||
typename LimX,
|
||||
typename LimY,
|
||||
typename OffX,
|
||||
typename OffY>
|
||||
METAL_FUNC static constexpr void load_safe(
|
||||
thread frag_type& dst,
|
||||
SrcPtrType src,
|
||||
StrX str_x,
|
||||
StrY str_y,
|
||||
LimX lim_x,
|
||||
LimY lim_y,
|
||||
OffX off_x = Int<0>{},
|
||||
OffY off_y = Int<0>{}) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kElemRows; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kElemCols; j++) {
|
||||
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
|
||||
dst[i * kElemCols + j] =
|
||||
static_cast<T>(src[(off_x + i) * str_x + (off_y + j) * str_y]);
|
||||
} else {
|
||||
dst[i * kElemCols + j] = T(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DstPtrType, typename StrX, typename StrY>
|
||||
METAL_FUNC static constexpr void
|
||||
store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) {
|
||||
using U = pointer_element_t<DstPtrType>;
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kElemRows; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kElemCols; j++) {
|
||||
dst[i * str_x + j * str_y] = static_cast<U>(src[i * kElemCols + j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename DstPtrType,
|
||||
typename StrX,
|
||||
typename StrY,
|
||||
typename LimX,
|
||||
typename LimY,
|
||||
typename OffX,
|
||||
typename OffY>
|
||||
METAL_FUNC static constexpr void store_safe(
|
||||
const thread frag_type& src,
|
||||
DstPtrType dst,
|
||||
StrX str_x,
|
||||
StrY str_y,
|
||||
LimX lim_x,
|
||||
LimY lim_y,
|
||||
OffX off_x = Int<0>{},
|
||||
OffY off_y = Int<0>{}) {
|
||||
using U = pointer_element_t<DstPtrType>;
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kElemRows; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kElemCols; j++) {
|
||||
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
|
||||
dst[(off_x + i) * str_x + (off_y + j) * str_y] =
|
||||
static_cast<U>(src[i * kElemCols + j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC static constexpr void mma(
|
||||
thread frag_type& D,
|
||||
thread frag_type& A,
|
||||
thread frag_type& B,
|
||||
thread frag_type& C) {
|
||||
mat_type D_mat;
|
||||
mat_type A_mat;
|
||||
mat_type B_mat;
|
||||
mat_type C_mat;
|
||||
|
||||
reinterpret_cast<thread frag_type&>(A_mat.thread_elements()) = A;
|
||||
reinterpret_cast<thread frag_type&>(B_mat.thread_elements()) = B;
|
||||
reinterpret_cast<thread frag_type&>(C_mat.thread_elements()) = C;
|
||||
|
||||
mma(D_mat, A_mat, B_mat, C_mat);
|
||||
|
||||
D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());
|
||||
}
|
||||
|
||||
METAL_FUNC static constexpr void mma(
|
||||
thread mat_type& D,
|
||||
thread mat_type& A,
|
||||
thread mat_type& B,
|
||||
thread mat_type& C) {
|
||||
simdgroup_multiply_accumulate(D, A, B, C);
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int kTileRows_,
|
||||
int kTileCols_,
|
||||
class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
|
||||
struct MMATile {
|
||||
using MMAFrag_t = MMAFrag_;
|
||||
using elem_type = T;
|
||||
STEEL_CONST int kFragRows = MMAFrag_t::kFragRows;
|
||||
STEEL_CONST int kFragCols = MMAFrag_t::kFragCols;
|
||||
STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag;
|
||||
|
||||
STEEL_CONST int kTileRows = kTileRows_;
|
||||
STEEL_CONST int kTileCols = kTileCols_;
|
||||
|
||||
STEEL_CONST int kRows = kTileRows * kFragRows;
|
||||
STEEL_CONST int kCols = kTileCols * kFragCols;
|
||||
|
||||
STEEL_CONST int kNumFrags = kTileRows * kTileCols;
|
||||
STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;
|
||||
|
||||
typedef typename MMAFrag_t::mat_type mat_type;
|
||||
typedef typename MMAFrag_t::frag_type frag_type;
|
||||
|
||||
frag_type val_frags[kNumFrags] = {frag_type(0)};
|
||||
|
||||
METAL_FUNC MMATile() thread {}
|
||||
|
||||
METAL_FUNC constexpr void clear() {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kNumFrags; ++i) {
|
||||
val_frags[i] = frag_type(0);
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {
|
||||
return val_frags[i * kTileCols + j];
|
||||
}
|
||||
|
||||
METAL_FUNC constexpr const thread frag_type& frag_at(
|
||||
const short i,
|
||||
const short j) const {
|
||||
return val_frags[i * kTileCols + j];
|
||||
}
|
||||
|
||||
METAL_FUNC mat_type mat_at(const short i, const short j) {
|
||||
mat_type val_mat;
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short ii = 0; ii < kElemsPerFrag; ++ii) {
|
||||
val_mat.thread_elements()[ii] = frag_at(i, j)[ii];
|
||||
}
|
||||
return val_mat;
|
||||
}
|
||||
|
||||
METAL_FUNC thread elem_type* elems() {
|
||||
return reinterpret_cast<thread elem_type*>(val_frags);
|
||||
}
|
||||
|
||||
METAL_FUNC const thread elem_type* elems() const {
|
||||
return reinterpret_cast<const thread elem_type*>(val_frags);
|
||||
}
|
||||
|
||||
template <typename U, int w_x, int w_y, int str_x, int str_y>
|
||||
METAL_FUNC void load(const threadgroup U* src) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::load(
|
||||
frag_at(i, j),
|
||||
&(
|
||||
src[(i * kFragRows) * w_x * str_x +
|
||||
(j * kFragCols) * w_y * str_y]),
|
||||
Int<str_x>{},
|
||||
Int<str_y>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int w_x, int w_y, int str_x, int str_y>
|
||||
METAL_FUNC void store(threadgroup U* dst) const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::store(
|
||||
frag_at(i, j),
|
||||
&(
|
||||
dst[(i * kFragRows) * w_x * str_x +
|
||||
(j * kFragCols) * w_y * str_y]),
|
||||
Int<str_x>{},
|
||||
Int<str_y>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int w_x, int w_y>
|
||||
METAL_FUNC void load(const device U* src, const int ld) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::load(
|
||||
frag_at(i, j),
|
||||
&(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
|
||||
ld,
|
||||
Int<1>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int w_x, int w_y>
|
||||
METAL_FUNC void store(device U* dst, const int ld) const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::store(
|
||||
frag_at(i, j),
|
||||
&(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
|
||||
ld,
|
||||
Int<1>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int w_x, int w_y>
|
||||
METAL_FUNC void
|
||||
load_safe(const device U* src, const int ld, const short2 src_tile_dims) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::load_safe(
|
||||
frag_at(i, j),
|
||||
src,
|
||||
ld,
|
||||
Int<1>{},
|
||||
src_tile_dims.y,
|
||||
src_tile_dims.x,
|
||||
(i * kFragRows) * w_x,
|
||||
(j * kFragCols) * w_y);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int w_x, int w_y>
|
||||
METAL_FUNC void
|
||||
store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::store_safe(
|
||||
frag_at(i, j),
|
||||
dst,
|
||||
ld,
|
||||
Int<1>{},
|
||||
dst_tile_dims.y,
|
||||
dst_tile_dims.x,
|
||||
(i * kFragRows) * w_x,
|
||||
(j * kFragCols) * w_y);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, int M, int N, int K>
|
||||
METAL_FUNC void tile_matmad(
|
||||
thread MMATile<T, M, N>& D,
|
||||
thread MMATile<U, M, K>& A,
|
||||
thread MMATile<U, K, N>& B,
|
||||
thread MMATile<T, M, N>& C) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short m = 0; m < M; ++m) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short n = 0; n < N; ++n) {
|
||||
short n_serp = (m % 2) ? (N - 1 - n) : n;
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short k = 0; k < K; ++k) {
|
||||
MMATile<T, M, N>::MMAFrag_t::mma(
|
||||
D.frag_at(m, n_serp),
|
||||
A.frag_at(m, k),
|
||||
B.frag_at(k, n_serp),
|
||||
C.frag_at(m, n_serp));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
@@ -33,39 +375,38 @@ template <
|
||||
typename AccumType = float,
|
||||
typename Epilogue = TransformNone<U, AccumType>>
|
||||
struct BlockMMA {
|
||||
// MMAFrag size
|
||||
STEEL_CONST short kFragSize = 8;
|
||||
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
|
||||
|
||||
// Warp tile simdgroup matrix strides along M
|
||||
STEEL_CONST short TM_stride = 8 * WM;
|
||||
STEEL_CONST short TM_stride = kFragSize * WM;
|
||||
// Warp tile simdgroup matrix strides along M
|
||||
STEEL_CONST short TN_stride = 8 * WN;
|
||||
STEEL_CONST short TN_stride = kFragSize * WN;
|
||||
|
||||
// Warp tile size along M
|
||||
STEEL_CONST short TM = BM / TM_stride;
|
||||
STEEL_CONST short TM = BM / (kFragSize * WM);
|
||||
// Warp tile size along N
|
||||
STEEL_CONST short TN = BN / TN_stride;
|
||||
STEEL_CONST short TN = BN / (kFragSize * WN);
|
||||
|
||||
// Strides of A, B along reduction axis
|
||||
STEEL_CONST short simd_stride_a = {
|
||||
transpose_a ? TM_stride : TM_stride * lda_tgp};
|
||||
STEEL_CONST short simd_stride_b = {
|
||||
transpose_b ? TN_stride * ldb_tgp : TN_stride};
|
||||
// Threadgroup A strides
|
||||
STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
|
||||
STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K
|
||||
|
||||
// Jump between elements
|
||||
STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1};
|
||||
STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1};
|
||||
// Threadgroup B strides
|
||||
STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
|
||||
STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N
|
||||
|
||||
STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8};
|
||||
STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp};
|
||||
// Threadgroup strides along K
|
||||
STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
|
||||
STEEL_CONST short tile_stride_b = kFragSize * B_str_k;
|
||||
|
||||
// Simdgroup matrices
|
||||
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
|
||||
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
|
||||
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
|
||||
simdgroup_matrix<AccumType, 8, 8>(0)};
|
||||
MMATile<AccumType, TM, 1, MMAFrag_acc_t> Atile;
|
||||
MMATile<AccumType, 1, TN, MMAFrag_acc_t> Btile;
|
||||
MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile;
|
||||
|
||||
// Offsets within threadgroup
|
||||
const short tm;
|
||||
const short tn;
|
||||
|
||||
short sm;
|
||||
short sn;
|
||||
|
||||
@@ -75,18 +416,21 @@ struct BlockMMA {
|
||||
/* Constructor */
|
||||
METAL_FUNC BlockMMA(
|
||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
|
||||
ushort simd_lane_id [[thread_index_in_simdgroup]]) {
|
||||
// Determine thread position in simdgroup matrix
|
||||
short qid = simd_lane_id / 4;
|
||||
sm = (qid & 4) + (simd_lane_id / 2) % 4;
|
||||
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
|
||||
short tm = kFragSize * (simd_group_id / WN);
|
||||
short tn = kFragSize * (simd_group_id % WN);
|
||||
|
||||
short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
|
||||
sm = simd_coord.y;
|
||||
sn = simd_coord.x;
|
||||
|
||||
// Determine thread and simdgroup offset
|
||||
As_offset =
|
||||
transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp);
|
||||
Bs_offset =
|
||||
transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn));
|
||||
As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K
|
||||
Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N
|
||||
|
||||
sm += tm;
|
||||
sn += tn;
|
||||
}
|
||||
|
||||
/* (BM, BK) X (BK, BN) multiply accumulate function */
|
||||
@@ -95,47 +439,20 @@ struct BlockMMA {
|
||||
As += As_offset;
|
||||
Bs += Bs_offset;
|
||||
|
||||
// Iterate over BK in blocks of 8
|
||||
// Iterate over BK in blocks of kFragSize
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short kk = 0; kk < BK; kk += 8) {
|
||||
for (short kk = 0; kk < BK; kk += kFragSize) {
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Load elements from threadgroup A as simdgroup matrices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
Asimd[i].thread_elements()[0] =
|
||||
static_cast<AccumType>(As[i * simd_stride_a + 0]);
|
||||
Asimd[i].thread_elements()[1] =
|
||||
static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
|
||||
}
|
||||
Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Load elements from threadgroup B as simdgroup matrices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
Bsimd[j].thread_elements()[0] =
|
||||
static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
|
||||
Bsimd[j].thread_elements()[1] =
|
||||
static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
|
||||
}
|
||||
Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Multiply and accumulate into result simdgroup matrices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
short j_serp = (i % 2) ? (TN - 1 - j) : j;
|
||||
|
||||
simdgroup_multiply_accumulate(
|
||||
results[i * TN + j_serp],
|
||||
Asimd[i],
|
||||
Bsimd[j_serp],
|
||||
results[i * TN + j_serp]);
|
||||
}
|
||||
}
|
||||
tile_matmad(Ctile, Atile, Btile, Ctile);
|
||||
|
||||
// Progress to next simdgroup tile
|
||||
As += tile_stride_a;
|
||||
@@ -144,58 +461,35 @@ struct BlockMMA {
|
||||
}
|
||||
|
||||
/* Store results from simdgroup_matrix results into device memory */
|
||||
METAL_FUNC void store_result(device U* D, const int ldd) const {
|
||||
// Adjust for simdgroup and thread location
|
||||
D += (sm + tm) * ldd + tn + sn;
|
||||
|
||||
// Loop over all simdgroup tiles
|
||||
METAL_FUNC void store_result(device U* D, const int ldd) {
|
||||
// Apply epilogue
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
||||
int offset = (i * TM_stride) * ldd + (j * TN_stride);
|
||||
|
||||
// Apply epilogue
|
||||
U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
|
||||
|
||||
// Write out D
D[offset] = outs[0];
D[offset + 1] = outs[1];
}
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
}

// Adjust for simdgroup and thread location
D += sm * ldd + sn;

Ctile.template store<U, WM, WN>(D, ldd);
}

METAL_FUNC void
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) const {
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
}

// Adjust for simdgroup and thread location
D += (sm + tm) * ldd + (tn + sn);
dst_tile_dims -= short2(tn + sn, sm + tm);
D += sm * ldd + sn;
dst_tile_dims -= short2(sn, sm);

if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;

STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldd + (j * TN_stride);

// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
D[offset] = Epilogue::apply(accum[0]);
}

if (j * TN_stride + 1 < dst_tile_dims.x) {
D[offset + 1] = Epilogue::apply(accum[1]);
}
}
}
}
Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims);
}

/* Apply epilogue */
@@ -203,16 +497,8 @@ struct BlockMMA {
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto& accum = results[i * TN + j].thread_elements();

// Apply epilogue
accum[0] = epilogue_op.apply(accum[0]);
accum[1] = epilogue_op.apply(accum[1]);
}
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]);
}
}

@@ -224,7 +510,7 @@ struct BlockMMA {
const int fdc,
thread const BinaryEpilogue& epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
C += (sm)*ldc + (sn)*fdc;

// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
@@ -232,12 +518,14 @@ struct BlockMMA {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto& accum = results[i * TN + j].thread_elements();
thread auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;

// Apply epilogue
accum[0] = epilogue_op.apply(accum[0], C[offset_c]);
accum[1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
STEEL_PRAGMA_UNROLL
for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) {
accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
}
}
}
}
@@ -251,8 +539,8 @@ struct BlockMMA {
short2 dst_tile_dims,
thread const BinaryEpilogue& epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
dst_tile_dims -= short2(tn + sn, sm + tm);
C += (sm)*ldc + (sn)*fdc;
dst_tile_dims -= short2(sn, sm);

if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
@@ -263,22 +551,26 @@ struct BlockMMA {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto& accum = results[i * TN + j].thread_elements();
thread auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;

// Read C
U c_elems[2] = {0};
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

if ((j * TN_stride + 1) < dst_tile_dims.x) {
c_elems[0] = C[offset_c];
c_elems[1] = C[offset_c + fdc];
} else if ((j * TN_stride) < dst_tile_dims.x) {
c_elems[0] = C[offset_c];
// Read C
U c_elems[kelems] = {0};

STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x) {
c_elems[k] = C[offset_c + k * fdc];
}
}

// Apply epilogue
accum[0] = epilogue_op.apply(accum[0], c_elems[0]);
accum[1] = epilogue_op.apply(accum[1], c_elems[1]);
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
accum[k] = epilogue_op.apply(accum[k], c_elems[k]);
}
}
}
}
@@ -292,8 +584,10 @@ struct BlockMMA {
const int fdc,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
D += (sm + tm) * ldd + tn + sn;
C += (sm)*ldc + (sn)*fdc;
D += (sm)*ldd + sn;

constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
@@ -301,18 +595,15 @@ struct BlockMMA {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
thread const auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);

// Apply epilogue
U outs[2] = {
epilogue_op.apply(accum[0], C[offset_c]),
epilogue_op.apply(accum[1], C[offset_c + fdc])};

// Write out D
D[offset_d] = outs[0];
D[offset_d + 1] = outs[1];
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
}
}
}
}
@@ -326,30 +617,32 @@ struct BlockMMA {
short2 dst_tile_dims,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
D += (sm + tm) * ldd + tn + sn;
dst_tile_dims -= short2(tn + sn, sm + tm);
C += (sm)*ldc + (sn)*fdc;
D += (sm)*ldd + sn;
dst_tile_dims -= short2(sn, sm);

if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;

constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
thread const auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);

// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
}

if (j * TN_stride + 1 < dst_tile_dims.x) {
D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x) {
D[offset_d + k] =
epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
}
}
}
}
mlx/backend/metal/kernels/steel/utils/integral_constant.h (new file, 96 lines)
@@ -0,0 +1,96 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/utils/type_traits.h"

#pragma METAL internals : enable

namespace mlx {
namespace steel {

///////////////////////////////////////////////////////////////////////////////
// Integral constant with casting
///////////////////////////////////////////////////////////////////////////////

template <typename T, T v>
struct integral_constant {
static constexpr constant T value = v;
using value_type = T;
using type = integral_constant;

METAL_FUNC constexpr operator value_type() const noexcept {
return value;
}

// METAL_FUNC constexpr value_type operator()() const noexcept {
//   return value;
// }
};

template <bool B>
using bool_constant = integral_constant<bool, B>;
using true_type = bool_constant<true>;
using false_type = bool_constant<false>;

template <class T>
struct is_integral : bool_constant<metal::is_integral<T>::value> {};

template <class T, T v>
struct is_integral<integral_constant<T, v>>
    : bool_constant<metal::is_integral<T>::value> {};

template <typename T>
constexpr constant bool is_integral_v = is_integral<T>::value;

template <int val>
using Int = integral_constant<int, val>;

///////////////////////////////////////////////////////////////////////////////
// Binary Operators on Integral constants
///////////////////////////////////////////////////////////////////////////////

#define integral_const_binop(__op__, __operator__) \
template <typename T, T tv, typename U, U uv> \
METAL_FUNC constexpr auto __operator__( \
integral_constant<T, tv>, integral_constant<U, uv>) { \
constexpr auto res = tv __op__ uv; \
return integral_constant<decltype(res), res>{}; \
}

integral_const_binop(+, operator+);
integral_const_binop(-, operator-);
integral_const_binop(*, operator*);
integral_const_binop(/, operator/);

integral_const_binop(==, operator==);
integral_const_binop(!=, operator!=);
integral_const_binop(<, operator<);
integral_const_binop(>, operator>);
integral_const_binop(<=, operator<=);
integral_const_binop(>=, operator>=);

integral_const_binop(&&, operator&&);
integral_const_binop(||, operator||);

#undef integral_const_binop

///////////////////////////////////////////////////////////////////////////////
// Reduction operators
///////////////////////////////////////////////////////////////////////////////

template <typename T>
METAL_FUNC constexpr T sum(T x) {
return x;
}

template <typename T, typename... Us>
METAL_FUNC constexpr auto sum(T x, Us... us) {
return x + sum(us...);
}

} // namespace steel
} // namespace mlx

#pragma METAL internals : disable
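Reading note on the new header above: the Int<> alias plus the operator overloads let tile dimensions travel as types and fold at compile time. The following is a minimal sketch only (BM, BN and total are illustrative names, assumed to sit inside a kernel or METAL_FUNC body, and are not part of the diff):

    using BM = mlx::steel::Int<64>;   // tile rows carried as a type
    using BN = mlx::steel::Int<32>;   // tile cols carried as a type
    constexpr auto BS = BM{} * BN{};  // integral_constant<int, 2048>, folded at compile time
    static_assert(BS.value == 2048, "compile-time product");
    constexpr auto total = mlx::steel::sum(BM{}, BN{}, mlx::steel::Int<16>{});  // Int<112>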
mlx/backend/metal/kernels/steel/utils/type_traits.h (new file, 55 lines)
@@ -0,0 +1,55 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include <metal_stdlib>

#pragma METAL internals : enable

namespace metal {

template <typename T>
struct is_empty : metal::bool_constant<__is_empty(T)> {};

#ifdef __cpp_variable_templates
template <typename T>
constexpr constant bool is_empty_v = is_empty<T>::value;
#endif

template <typename... Ts>
struct make_void {
typedef void type;
};

template <typename... Ts>
using void_t = typename make_void<Ts...>::type;

template <class T>
struct is_static : metal::bool_constant<is_empty<remove_cv_t<T>>::value> {};

template <typename T>
struct pointer_element {};

template <typename T>
struct pointer_element<thread T*> {
using type = remove_cv_t<T>;
};
template <typename T>
struct pointer_element<device T*> {
using type = remove_cv_t<T>;
};
template <typename T>
struct pointer_element<constant T*> {
using type = remove_cv_t<T>;
};
template <typename T>
struct pointer_element<threadgroup T*> {
using type = remove_cv_t<T>;
};

template <typename T>
using pointer_element_t = typename pointer_element<remove_cv_t<T>>::type;

} // namespace metal

#pragma METAL internals : disable
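For orientation, a small sketch of what the traits above compute (illustrative only, not part of the diff):

    // pointer_element_t recovers the element type of a Metal pointer,
    // dropping the address space and cv-qualifiers.
    using E1 = metal::pointer_element_t<device const float*>;   // float
    using E2 = metal::pointer_element_t<threadgroup half*>;     // half
    // is_static<T> simply reports whether T is an empty (stateless) type.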
@@ -32,13 +32,13 @@ template <typename T, typename Op>
constant const size_t& b_strides,
constant const size_t& c_strides,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_strides);
auto b_idx = elem_to_loc_1(index, b_strides);
auto c_idx = elem_to_loc_1(index, c_strides);
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_strides);
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_strides);
auto c_idx = elem_to_loc_1<size_t, uint>(index, c_strides);
d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}

template <typename T, typename Op>
template <typename T, typename Op, typename IdxT = size_t>
[[kernel]] void ternary_g_nd2(
device const bool* a,
device const T* b,
@@ -49,14 +49,14 @@ template <typename T, typename Op>
constant const size_t c_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
auto c_idx = elem_to_loc_2(index, c_strides);
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
auto c_idx = elem_to_loc_2<size_t, IdxT>(index, c_strides);
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}

template <typename T, typename Op>
template <typename T, typename Op, typename IdxT = size_t>
[[kernel]] void ternary_g_nd3(
device const bool* a,
device const T* b,
@@ -67,15 +67,14 @@ template <typename T, typename Op>
constant const size_t c_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
auto c_idx = elem_to_loc_3(index, c_strides);
size_t out_idx =
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
auto c_idx = elem_to_loc_3<size_t, IdxT>(index, c_strides);
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}

template <typename T, typename Op, int N = 1>
template <typename T, typename Op, int N = 1, typename IdxT = size_t>
[[kernel]] void ternary_g(
device const bool* a,
device const T* b,
@@ -88,7 +87,7 @@ template <typename T, typename Op, int N = 1>
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_3_nd(
auto idx = elem_to_loc_3_nd<IdxT>(
{N * index.x, index.y, index.z},
shape,
a_strides,
@@ -96,11 +95,10 @@ template <typename T, typename Op, int N = 1>
c_strides,
ndim);
auto xshape = shape[ndim - 1];
size_t out_idx =
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
auto a_xstride = a_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
auto c_xstride = c_strides[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
IdxT a_xstride = a_strides[ndim - 1];
IdxT b_xstride = b_strides[ndim - 1];
IdxT c_xstride = c_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
d[out_idx++] = Op()(a[idx.x], b[idx.y], c[idx.z]);
idx.x += a_xstride;
@@ -4,18 +4,20 @@
#include <metal_math>

// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/ternary_ops.h"
#include "mlx/backend/metal/kernels/ternary.h"

#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
instantiate_kernel("gn4_" #op #tname, ternary_g, type, op, 4) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op)
#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 1, uint) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, uint) \
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, uint) \
instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \

#define instantiate_ternary_types(op) \
instantiate_ternary_all(op, bool_, bool) \
@@ -18,7 +18,12 @@ template <typename T, typename U, typename Op>
out[offset] = Op()(in[offset]);
}

template <typename T, typename U, typename Op, int N = 1>
template <
typename T,
typename U,
typename Op,
int N = 1,
typename IdxT = size_t>
[[kernel]] void unary_g(
device const T* in,
device U* out,
@@ -27,12 +32,11 @@ template <typename T, typename U, typename Op, int N = 1>
device const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx =
elem_to_loc({N * index.x, index.y, index.z}, in_shape, in_strides, ndim);
auto idx = elem_to_loc<size_t, IdxT>(
{N * index.x, index.y, index.z}, in_shape, in_strides, ndim);
auto xshape = in_shape[ndim - 1];
auto xstride = in_strides[ndim - 1];
size_t out_idx =
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
IdxT xstride = in_strides[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
out[out_idx++] = Op()(in[idx]);
idx += xstride;
@@ -5,11 +5,13 @@
#include "mlx/backend/metal/kernels/unary_ops.h"
#include "mlx/backend/metal/kernels/unary.h"

#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
instantiate_kernel("gn4_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)

#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
instantiate_kernel( \
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, uint) \
instantiate_kernel( \
"gn4large" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)

#define instantiate_unary_all_same(op, tname, type) \
instantiate_unary_all(op, tname, tname, type, type)
@@ -3,7 +3,13 @@
#pragma once

#include <metal_math>
#include "mlx/backend/metal/kernels/bf16.h"

// The correct bf16.h is included based on the metal version
// by giving the correct path to -I during compilation
// e.g. mlx/backend/metal/kernels/metal_3_0/ for Metal 3.0
#include "bf16.h"

#include "mlx/backend/metal/kernels/bf16_math.h"
#include "mlx/backend/metal/kernels/complex.h"
#include "mlx/backend/metal/kernels/defines.h"

@@ -83,44 +89,45 @@ struct Limits<complex64_t> {
///////////////////////////////////////////////////////////////////////////////
// Single Array with generic dims

template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc(
uint elem,
constant const int* shape,
constant const stride_t* strides,
constant const StrideT* strides,
int ndim) {
stride_t loc = 0;
IdxT loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * strides[i];
loc += (elem % shape[i]) * IdxT(strides[i]);
elem /= shape[i];
}
return loc;
}

template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
stride_t elem,
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc(
StrideT elem,
constant const int* shape,
constant const stride_t* strides,
constant const StrideT* strides,
int ndim) {
stride_t loc = 0;
IdxT loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * strides[i];
loc += (elem % shape[i]) * IdxT(strides[i]);
elem /= shape[i];
}
return loc;
}

// Non templated version to handle arbitrary dims
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc(
uint3 elem,
constant const int* shape,
constant const stride_t* strides,
constant const StrideT* strides,
int ndim) {
stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
IdxT loc =
elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]);
for (int d = ndim - 3; d >= 0; --d) {
loc += (elem.z % shape[d]) * strides[d];
loc += (elem.z % shape[d]) * IdxT(strides[d]);
elem.z /= shape[d];
}
return loc;
@@ -129,61 +136,65 @@ METAL_FUNC stride_t elem_to_loc(
///////////////////////////////////////////////////////////////////////////////
// Single Array with fixed N dims

template <typename stride_t>
METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t& stride) {
return elem * stride;
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const StrideT& stride) {
return elem * IdxT(stride);
}

template <typename stride_t>
METAL_FUNC stride_t
elem_to_loc_2(uint2 elem, constant const stride_t strides[2]) {
return elem.x * strides[1] + elem.y * strides[0];
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const StrideT strides[2]) {
return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]);
}

template <typename stride_t>
METAL_FUNC stride_t
elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) {
return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const StrideT strides[3]) {
return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) +
elem.z * IdxT(strides[0]);
}

///////////////////////////////////////////////////////////////////////////////
// Multiple Arrays with generic dims

template <typename stride_t>
METAL_FUNC ulong2 elem_to_loc_2_nd(
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC vec<IdxT, 2> elem_to_loc_2_nd(
uint3 elem,
constant const int* shape,
constant const stride_t* a_strides,
constant const stride_t* b_strides,
constant const StrideT* a_strides,
constant const StrideT* b_strides,
int ndim) {
ulong2 loc = {
ulong(elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
ulong(elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
vec<IdxT, 2> loc = {
IdxT(
elem.x * IdxT(a_strides[ndim - 1]) +
IdxT(elem.y) * IdxT(a_strides[ndim - 2])),
IdxT(
elem.x * IdxT(b_strides[ndim - 1]) +
elem.y * IdxT(b_strides[ndim - 2]))};
for (int d = ndim - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
loc.x += l * IdxT(a_strides[d]);
loc.y += l * IdxT(b_strides[d]);
elem.z /= shape[d];
}
return loc;
}

METAL_FUNC ulong3 elem_to_loc_3_nd(
template <typename IdxT = size_t>
METAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(
uint3 elem,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const size_t* c_strides,
int ndim) {
ulong3 loc = {
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2],
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2],
elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2]};
vec<IdxT, 3> loc = {
elem.x * IdxT(a_strides[ndim - 1]) + elem.y * IdxT(a_strides[ndim - 2]),
elem.x * IdxT(b_strides[ndim - 1]) + elem.y * IdxT(b_strides[ndim - 2]),
elem.x * IdxT(c_strides[ndim - 1]) + elem.y * IdxT(c_strides[ndim - 2])};
for (int d = ndim - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
loc.z += l * c_strides[d];
loc.x += l * IdxT(a_strides[d]);
loc.y += l * IdxT(b_strides[d]);
loc.z += l * IdxT(c_strides[d]);
elem.z /= shape[d];
}
return loc;
@@ -193,16 +204,21 @@ METAL_FUNC ulong3 elem_to_loc_3_nd(
// Elem to loc in a loop utils
///////////////////////////////////////////////////////////////////////////////

template <int dim, typename offset_t = size_t>
struct looped_elem_to_loc {
looped_elem_to_loc<dim - 1, offset_t> inner_looper;
offset_t offset{0};
template <int DIM, typename OffsetT = size_t, bool General = true>
struct LoopedElemToLoc {
int dim;
LoopedElemToLoc<DIM - 1, OffsetT, General> inner_looper;
OffsetT offset{0};
int index{0};

void next(const constant int* shape, const constant size_t* strides) {
index++;
offset += strides[dim - 1];
LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}

void next(const constant int* shape, const constant size_t* strides) {
if (dim == 0) {
return;
}
index++;
offset += OffsetT(strides[dim - 1]);
if (index >= shape[dim - 1]) {
index = 0;
inner_looper.next(shape, strides);
@@ -211,13 +227,21 @@ struct looped_elem_to_loc {
}

void next(int n, const constant int* shape, const constant size_t* strides) {
if (dim == 0) {
return;
}
index += n;
offset += n * strides[dim - 1];
offset += n * OffsetT(strides[dim - 1]);

if (index >= shape[dim - 1]) {
int extra = index - shape[dim - 1];
if (extra >= shape[dim - 1]) {
inner_looper.next(1 + extra / shape[dim - 1], shape, strides);
extra = extra % shape[dim - 1];
} else {
inner_looper.next(shape, strides);
}
index = 0;
inner_looper.next(shape, strides);
offset = inner_looper.offset;
if (extra > 0) {
next(extra, shape, strides);
@@ -225,44 +249,61 @@ struct looped_elem_to_loc {
}
}

offset_t
location(offset_t, const constant int*, const constant size_t*, int) {
OffsetT location() {
return offset;
}
};

template <typename offset_t>
struct looped_elem_to_loc<1, offset_t> {
offset_t offset{0};
template <typename OffsetT>
struct LoopedElemToLoc<1, OffsetT, true> {
int dim;
OffsetT offset{0};
uint index{0};

LoopedElemToLoc(int dim) : dim(dim) {}

void next(const constant int* shape, const constant size_t* strides) {
index++;
if (dim > 1) {
offset = elem_to_loc<size_t, OffsetT>(index, shape, strides, dim);
} else {
offset += OffsetT(strides[0]);
}
}

void next(int n, const constant int* shape, const constant size_t* strides) {
index += n;
if (dim > 1) {
offset = elem_to_loc<size_t, OffsetT>(index, shape, strides, dim);
} else {
offset = index * OffsetT(strides[0]);
}
}

OffsetT location() {
return offset;
}
};

template <typename OffsetT>
struct LoopedElemToLoc<1, OffsetT, false> {
OffsetT offset{0};

LoopedElemToLoc(int) {}

void next(const constant int*, const constant size_t* strides) {
offset += strides[0];
offset += OffsetT(strides[0]);
}

void next(int n, const constant int*, const constant size_t* strides) {
offset += n * strides[0];
offset += n * OffsetT(strides[0]);
}

offset_t
location(offset_t, const constant int*, const constant size_t*, int) {
OffsetT location() {
return offset;
}
};

template <typename offset_t>
struct looped_elem_to_loc<0, offset_t> {
void next(const constant int*, const constant size_t*) {}
void next(int, const constant int*, const constant size_t*) {}

offset_t location(
offset_t idx,
const constant int* shape,
const constant size_t* strides,
int ndim) {
return elem_to_loc(idx, shape, strides, ndim);
}
};

///////////////////////////////////////////////////////////////////////////////
// Calculation utils
///////////////////////////////////////////////////////////////////////////////
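As a rough orientation for the refactor above, the new LoopedElemToLoc is constructed with the runtime ndim and advanced explicitly, with location() replacing the old location(idx, shape, strides, ndim) call. A sketch only; the kernel shape and the names in, ndim, shape, strides, n_rows, row_size are hypothetical, not taken from the diff:

    // Walk rows of a strided input inside a kernel body.
    LoopedElemToLoc<2, size_t> loop(ndim);      // general (General = true) case
    for (int r = 0; r < n_rows; ++r) {
      device const float* row = in + loop.location();
      // ... consume row[0 .. row_size) ...
      loop.next(row_size, shape, strides);      // advance by one logical row
    }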
@@ -11,12 +11,12 @@ SRC_DIR=$3
SRC_FILE=$4
CFLAGS=$5
SRC_NAME=$(basename -- "${SRC_FILE}")
JIT_INCLUDES=${SRC_DIR}/mlx/backend/metal/kernels/jit
INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h
OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp

mkdir -p "$OUTPUT_DIR"

CONTENT=$($CC -I "$SRC_DIR" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
CONTENT=$($CC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)

cat << EOF > "$OUTPUT_FILE"
namespace mlx::core::metal {
@@ -88,6 +88,83 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
// Steel matmul fallback
///////////////////////////////////////////////////////////////////////////////

#define GEMM_TPARAM_MACRO(devc) \
if (devc == 'g') { /* Small device */ \
if (!transpose_a && transpose_b) { /* nt */ \
bm = 64; \
bn = 32; \
bk = 32; \
wm = 2; \
wn = 2; \
} else if (out.dtype() != float32) { /* half and bfloat */ \
bm = 64; \
bn = 64; \
bk = 16; \
wm = 1; \
wn = 2; \
} \
} else if (devc == 'd') { /* Large device */ \
if ((size_t)batch_size_out * M * N >= 1ul << 20) { /* large matmul */ \
if (out.dtype() != float32) { /* half and bfloat */ \
if (2 * std::max(M, N) > K) { /* Reasonable K */ \
bm = 64; \
bn = 64; \
bk = 16; \
wm = 1; \
wn = 2; \
} else if (!transpose_a && transpose_b) { /* nt with large k */ \
bm = 64; \
bn = 32; \
bk = 32; \
wm = 2; \
wn = 2; \
} else { /* nn with large K */ \
bm = 32; \
bn = 64; \
bk = 16; \
wm = 1; \
wn = 2; \
} \
} /* float takes default */ \
} else { /* smaller matmul */ \
if (out.dtype() != float32) { /* half and bfloat */ \
if (!transpose_a && transpose_b) { /* nt */ \
bm = 64; \
bn = 32; \
bk = 32; \
wm = 2; \
wn = 2; \
} else { /* nn */ \
bm = 64; \
bn = 64; \
bk = 16; \
wm = 1; \
wn = 2; \
} \
} else { /* floats */ \
if (!transpose_a && transpose_b) { /* nt */ \
bm = 32; \
bn = 64; \
bk = 16; \
wm = 1; \
wn = 2; \
} else { /* nn */ \
bm = 64; \
bn = 32; \
bk = 32; \
wm = 2; \
wn = 2; \
} \
} \
} \
} else { /* Medium device */ \
bm = 64; \
bn = 64; \
bk = 16; \
wm = 2; \
wn = 2; \
}

void steel_matmul_regular(
const Stream& s,
metal::Device& d,
@@ -112,19 +189,11 @@ void steel_matmul_regular(
using namespace mlx::steel;

// Determine dispatch kernel
int bm = 32, bn = 32, bk = 16;
int bm = 64, bn = 64, bk = 16;
int wm = 2, wn = 2;

if ((size_t)batch_size_out * M * N >= 1ul << 20) {
if (!transpose_a && transpose_b) {
bm = 64;
bn = (out.dtype() == float32) ? 64 : 32;
bk = (out.dtype() == float32) ? 16 : 32;
} else {
bm = 64;
bn = 64;
}
}
char devc = d.get_architecture().back();
GEMM_TPARAM_MACRO(devc)

// Prepare kernel name
std::ostringstream kname;
@@ -180,7 +249,7 @@ void steel_matmul_regular(
wm,
wn);

compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

// Use problem size to determine threadblock swizzle
int tn = (N + bn - 1) / bn;
@@ -219,12 +288,12 @@ void steel_matmul_regular(
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3);

compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
compute_encoder.set_bytes(params, 4);

set_vector_bytes(compute_encoder, batch_shape, 6);
set_vector_bytes(compute_encoder, batch_strides, 7);
compute_encoder.set_vector_bytes(batch_shape, 6);
compute_encoder.set_vector_bytes(batch_strides, 7);

compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

// Record copies
d.add_temporaries(std::move(copies), s.index);
@@ -321,7 +390,7 @@ void steel_matmul(
wn,
mn_aligned,
k_aligned);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;
@@ -347,34 +416,29 @@ void steel_matmul(
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(C_split, 2);

compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.set_bytes(params, 3);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

// Do accum kernel
{
auto c_split_buf =
static_cast<const MTL::Resource*>(C_split.buffer().ptr());
const class MTL::Resource* const resources[1] = {c_split_buf};
compute_encoder->memoryBarrier(resources, 1);
auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
type_to_name(C_split);

auto kernel = get_steel_gemm_splitk_accum_kernel(
d, kernel_name, C_split, out, false);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

// Set the arguments for the kernel
compute_encoder.set_input_array(C_split, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2);
compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3);
compute_encoder->setBytes(&N, sizeof(int), 4);
compute_encoder.set_bytes(split_k_partitions, 2);
compute_encoder.set_bytes(split_k_partition_stride, 3);
compute_encoder.set_bytes(N, 4);

// Launch enough thread groups for each output
MTL::Size grid_dims = MTL::Size(N, M, 1);
MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);

compute_encoder.dispatchThreads(grid_dims, group_dims);
auto group_dims = get_block_dims(N, M, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

d.add_temporaries(std::move(copies), s.index);
@@ -556,7 +620,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm);
@@ -566,16 +630,16 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3);

compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
compute_encoder.set_bytes(in_vector_len, 4);
compute_encoder.set_bytes(out_vector_len, 5);
compute_encoder.set_bytes(mat_ld, 6);

compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
set_vector_bytes(compute_encoder, batch_shape, 10);
set_vector_bytes(compute_encoder, batch_strides_vec, 11);
set_vector_bytes(compute_encoder, batch_strides_mat, 12);
compute_encoder.set_bytes(batch_ndim, 9);
compute_encoder.set_vector_bytes(batch_shape, 10);
compute_encoder.set_vector_bytes(batch_strides_vec, 11);
compute_encoder.set_vector_bytes(batch_strides_mat, 12);

compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

d.add_temporaries(std::move(copies), s.index);
return;
@@ -753,7 +817,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm);
@@ -764,23 +828,23 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(c, 2);
compute_encoder.set_output_array(out, 3);

compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
compute_encoder.set_bytes(in_vector_len, 4);
compute_encoder.set_bytes(out_vector_len, 5);
compute_encoder.set_bytes(mat_ld, 6);

compute_encoder->setBytes(&alpha_, sizeof(float), 7);
compute_encoder->setBytes(&beta_, sizeof(float), 8);
compute_encoder.set_bytes(alpha_, 7);
compute_encoder.set_bytes(beta_, 8);

compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
set_vector_bytes(compute_encoder, batch_shape, 10);
set_vector_bytes(compute_encoder, batch_strides_vec, 11);
set_vector_bytes(compute_encoder, batch_strides_mat, 12);
set_vector_bytes(compute_encoder, C_batch_stride, 13);
compute_encoder.set_bytes(batch_ndim, 9);
compute_encoder.set_vector_bytes(batch_shape, 10);
compute_encoder.set_vector_bytes(batch_strides_vec, 11);
compute_encoder.set_vector_bytes(batch_strides_mat, 12);
compute_encoder.set_vector_bytes(C_batch_stride, 13);

int bias_stride = c.strides()[c.ndim() - 1];
compute_encoder->setBytes(&bias_stride, sizeof(int), 14);
compute_encoder.set_bytes(bias_stride, 14);

compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

d.add_temporaries(std::move(copies), s.index);
return;
@@ -838,7 +902,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
mn_aligned,
k_aligned);

compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;
@@ -864,8 +928,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(C_split, 2);

compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.set_bytes(params, 3);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

// Do accum kernel
{
@@ -874,25 +938,24 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = get_steel_gemm_splitk_accum_kernel(
d, kernel_name, C_split, out, true);

compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

// Set the arguments for the kernel
compute_encoder.set_input_array(C_split, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2);
compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3);
compute_encoder->setBytes(&N, sizeof(int), 4);
compute_encoder.set_bytes(split_k_partitions, 2);
compute_encoder.set_bytes(split_k_partition_stride, 3);
compute_encoder.set_bytes(N, 4);
compute_encoder.set_input_array(c, 5);
compute_encoder->setBytes(&ldc, sizeof(int), 6);
compute_encoder->setBytes(&fdc, sizeof(int), 7);
compute_encoder->setBytes(&alpha_, sizeof(float), 8);
compute_encoder->setBytes(&beta_, sizeof(float), 9);
compute_encoder.set_bytes(ldc, 6);
compute_encoder.set_bytes(fdc, 7);
compute_encoder.set_bytes(alpha_, 8);
compute_encoder.set_bytes(beta_, 9);

// Launch enough thread groups for each output
MTL::Size grid_dims = MTL::Size(N, M, 1);
MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);

compute_encoder.dispatchThreads(grid_dims, group_dims);
auto group_dims = get_block_dims(N, M, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

d.add_temporaries(std::move(copies), s.index);
@@ -903,19 +966,11 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Regular addmm dispatch

// Determine dispatch kernel
int bm = 32, bn = 32, bk = 16;
int bm = 64, bn = 64, bk = 16;
int wm = 2, wn = 2;

if ((size_t)batch_size_out * M * N >= 1ul << 20) {
if (!transpose_a && transpose_b) {
bm = 64;
bn = (out.dtype() == float32) ? 64 : 32;
bk = (out.dtype() == float32) ? 16 : 32;
} else {
bm = 64;
bn = 64;
}
}
char devc = d.get_architecture().back();
GEMM_TPARAM_MACRO(devc)

// Prepare kernel name
std::ostringstream kname;
@@ -971,7 +1026,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
wm,
wn);

compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;
@@ -1022,13 +1077,13 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(c, 2);
compute_encoder.set_output_array(out, 3);

compute_encoder->setBytes(&gemm_params, sizeof(GEMMParams), 4);
compute_encoder->setBytes(&params, sizeof(GEMMAddMMParams), 5);
compute_encoder.set_bytes(gemm_params, 4);
compute_encoder.set_bytes(params, 5);

set_vector_bytes(compute_encoder, batch_shape, 6);
set_vector_bytes(compute_encoder, batch_strides, 7);
compute_encoder.set_vector_bytes(batch_shape, 6);
compute_encoder.set_vector_bytes(batch_strides, 7);

compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

d.add_temporaries(std::move(copies), s.index);
}
@@ -1243,7 +1298,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
contiguous_kernel);

auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm);
@@ -1311,18 +1366,18 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3);

compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
set_vector_bytes(compute_encoder, batch_shape, 10);
set_vector_bytes(compute_encoder, batch_strides_vec, 11);
set_vector_bytes(compute_encoder, batch_strides_mat, 12);
compute_encoder.set_bytes(in_vector_len, 4);
compute_encoder.set_bytes(out_vector_len, 5);
compute_encoder.set_bytes(mat_ld, 6);
compute_encoder.set_bytes(batch_ndim, 9);
compute_encoder.set_vector_bytes(batch_shape, 10);
compute_encoder.set_vector_bytes(batch_strides_vec, 11);
compute_encoder.set_vector_bytes(batch_strides_mat, 12);

set_vector_bytes(compute_encoder, mask_strides, 23);
set_vector_bytes(compute_encoder, mask_batch_strides, 24);
compute_encoder.set_vector_bytes(mask_strides, 23);
compute_encoder.set_vector_bytes(mask_batch_strides, 24);

compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

d.add_temporaries(std::move(copies), s.index);
return;
@@ -1362,7 +1417,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
wn,
mn_aligned,
k_aligned);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

// Use problem size to determine threadblock swizzle
int tn = (N + bn - 1) / bn;
@@ -1425,14 +1480,14 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3);

compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
compute_encoder.set_bytes(params, 4);

set_vector_bytes(compute_encoder, batch_shape, 6);
set_vector_bytes(compute_encoder, batch_strides, 7);
compute_encoder.set_vector_bytes(batch_shape, 6);
compute_encoder.set_vector_bytes(batch_strides, 7);

set_vector_bytes(compute_encoder, mask_strides, 13);
compute_encoder.set_vector_bytes(mask_strides, 13);

compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

d.add_temporaries(std::move(copies), s.index);
}
@@ -1626,7 +1681,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm);
@@ -1636,28 +1691,28 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3);

compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
compute_encoder.set_bytes(in_vector_len, 4);
compute_encoder.set_bytes(out_vector_len, 5);
compute_encoder.set_bytes(mat_ld, 6);

compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
set_vector_bytes(compute_encoder, batch_shape, 10);
set_vector_bytes(compute_encoder, batch_strides, 11);
compute_encoder.set_bytes(batch_ndim, 9);
compute_encoder.set_vector_bytes(batch_shape, 10);
compute_encoder.set_vector_bytes(batch_strides, 11);

int batch_ndim_vec = batch_shape_vec.size();
compute_encoder->setBytes(&batch_ndim_vec, sizeof(int), 12);
set_vector_bytes(compute_encoder, batch_shape_vec, 13);
set_vector_bytes(compute_encoder, batch_strides_vec, 14);
compute_encoder.set_bytes(batch_ndim_vec, 12);
compute_encoder.set_vector_bytes(batch_shape_vec, 13);
compute_encoder.set_vector_bytes(batch_strides_vec, 14);

int batch_ndim_mat = batch_shape_mat.size();
compute_encoder->setBytes(&batch_ndim_mat, sizeof(int), 15);
set_vector_bytes(compute_encoder, batch_shape_mat, 16);
set_vector_bytes(compute_encoder, batch_strides_mat, 17);
compute_encoder.set_bytes(batch_ndim_mat, 15);
compute_encoder.set_vector_bytes(batch_shape_mat, 16);
compute_encoder.set_vector_bytes(batch_strides_mat, 17);

compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix));
compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix));

compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

d.add_temporaries(std::move(copies), s.index);
return;
@@ -1667,19 +1722,11 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Regular kernel dispatch

// Determine dispatch kernel
int bm = 32, bn = 32, bk = 16;
int bm = 64, bn = 64, bk = 16;
int wm = 2, wn = 2;

if ((size_t)batch_size_out * M * N >= 1ul << 20) {
if (!transpose_a && transpose_b) {
bm = 64;
bn = (out.dtype() == float32) ? 64 : 32;
bk = (out.dtype() == float32) ? 16 : 32;
} else {
bm = 64;
bn = 64;
}
}
char devc = d.get_architecture().back();
GEMM_TPARAM_MACRO(devc)

// Prepare kernel name
std::ostringstream kname;
@@ -1735,7 +1782,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
wm,
wn);

compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

// Use problem size to determine threadblock swizzle
int tn = (N + bn - 1) / bn;
@@ -1774,10 +1821,10 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3);

compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
compute_encoder.set_bytes(params, 4);

set_vector_bytes(compute_encoder, batch_shape, 6);
set_vector_bytes(compute_encoder, batch_strides, 7);
compute_encoder.set_vector_bytes(batch_shape, 6);
compute_encoder.set_vector_bytes(batch_strides, 7);

compute_encoder.set_input_array(lhs_indices, 10);
compute_encoder.set_input_array(rhs_indices, 11);
@@ -1792,11 +1839,11 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {

operand_batch_ndim.push_back(0);

set_vector_bytes(compute_encoder, operand_shape, 13);
set_vector_bytes(compute_encoder, operand_strides, 14);
set_vector_bytes(compute_encoder, operand_batch_ndim, 15);
compute_encoder.set_vector_bytes(operand_shape, 13);
compute_encoder.set_vector_bytes(operand_strides, 14);
compute_encoder.set_vector_bytes(operand_batch_ndim, 15);

compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

d.add_temporaries(std::move(copies), s.index);
}
@@ -1,11 +1,11 @@
// Copyright © 2023-2024 Apple Inc.
#include <cstdlib>
#include <memory>

#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
#include "mlx/utils.h"

namespace mlx::core::metal {

@@ -13,20 +13,6 @@ bool is_available() {
return true;
}

int max_ops_per_buffer() {
auto get_val = []() {
if (const char* buff_str = std::getenv("MLX_MAX_OPS_PER_BUFFER")) {
return atoi(buff_str);
} else {
return 10;
}
};
static int max_ops_per_buffer_ = get_val();
return max_ops_per_buffer_;
}

#define MAX_OPS_PER_BUFFER max_ops_per_buffer()

inline void check_error(MTL::CommandBuffer* cbuf) {
if (cbuf->status() == MTL::CommandBufferStatusError) {
std::ostringstream msg;
@@ -77,7 +63,8 @@ std::function<void()> make_task(array arr, bool signal) {
out.set_status(array::Status::evaluated);
}

if (signal || d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
if (signal ||
d.get_command_buffer_ops(s.index) >= env::max_ops_per_buffer()) {
d.end_encoding(s.index);
if (signal) {
command_buffer->encodeSignalEvent(
@@ -1,3 +1,4 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
@@ -97,7 +98,9 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
MTL::ComputePipelineState* get_reduce_init_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&) {
const std::string&,
const std::string&,
const Dtype&) {
return d.get_kernel(kernel_name);
}

@@ -106,8 +109,9 @@ MTL::ComputePipelineState* get_reduce_kernel(
const std::string& kernel_name,
const std::string&,
const std::string&,
const array&,
const array&,
const Dtype&,
const Dtype&,
const std::string&,
int,
int,
int) {