Compare commits


14 Commits

Author | SHA1 | Message | Date
Ronan Collobert | 87b680766e | Gloo backend support | 2024-11-13 13:52:37 -08:00
Ronan Collobert | 70ffaa50d2 | be more relaxed on OpenMPI version | 2024-11-13 13:51:37 -08:00
Angelos Katharopoulos | d82699f0f1 | Merge branch 'distributed-layers' into socket-distributed-layers | 2024-11-05 11:36:16 -08:00
Angelos Katharopoulos | 6fc00d2c10 | Add rudimentary barrier | 2024-11-05 11:34:55 -08:00
Angelos Katharopoulos | 44f0de2854 | Fix run without distributed | 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos | 29ec3539ed | TCP socket distributed | 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos | e94f0028c3 | Change the send message size | 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos | e5354fcddb | Make it work even for donated inputs | 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos | 34dd079a64 | Start a sockets based distributed backend | 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos | 16975815e9 | Fixes in distributed layers | 2024-11-05 11:27:26 -08:00
Angelos Katharopoulos | a8b3da7946 | Add distributed layers to nn top-level | 2024-11-05 11:27:26 -08:00
Angelos Katharopoulos | 060e1c9f92 | Add quantized distributed layers | 2024-11-05 11:27:26 -08:00
Angelos Katharopoulos | 0b04742985 | Add the distributed linear layers | 2024-11-05 11:27:26 -08:00
Angelos Katharopoulos | c3ccd4919f | Add MPI barrier | 2024-11-05 11:26:53 -08:00
147 changed files with 4462 additions and 5798 deletions

View File

@@ -24,7 +24,7 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
 if(NOT MLX_VERSION)
-set(MLX_VERSION 0.21.0)
+set(MLX_VERSION 0.19.3)
 endif()
 # --------------------- Processor tests -------------------------
@@ -89,27 +89,25 @@ elseif(MLX_BUILD_METAL)
 # Throw an error if xcrun not found
 execute_process(
 COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
-OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
-if(${MACOS_SDK_VERSION} LESS 14.0)
+OUTPUT_VARIABLE MACOS_VERSION COMMAND_ERROR_IS_FATAL ANY)
+if(${MACOS_VERSION} LESS 14.0)
 message(
 FATAL_ERROR
 "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
 endif()
-message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
+message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
 set(METAL_CPP_URL
 https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip
 )
-# Get the metal version
-if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
-set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
-endif()
 execute_process(
 COMMAND
 zsh "-c"
-"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
+"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
 OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
 FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
 FetchContent_MakeAvailable(metal_cpp)
@@ -117,6 +115,8 @@ elseif(MLX_BUILD_METAL)
 mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
 $<INSTALL_INTERFACE:include/metal_cpp>)
 target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
+add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
 endif()
 if(MLX_BUILD_CPU)
@@ -168,11 +168,12 @@ endif()
 find_package(MPI)
 if(MPI_FOUND)
 execute_process(
-COMMAND zsh "-c" "mpirun --version"
+COMMAND zsh "-c" "${MPIEXEC_EXECUTABLE} --version"
 OUTPUT_VARIABLE MPI_VERSION
 ERROR_QUIET)
-if(${MPI_VERSION} MATCHES ".*Open MPI.*")
+if(${MPI_VERSION} MATCHES ".*Open MPI.*" OR ${MPI_VERSION} MATCHES ".*OpenRTE.*")
 target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
+target_link_libraries(mlx PRIVATE ${MPI_CXX_LIBRARIES})
 elseif(MPI_VERSION STREQUAL "")
 set(MPI_FOUND FALSE)
 message(

View File

@@ -1,189 +1,62 @@
# Copyright © 2024 Apple Inc.
import argparse import argparse
import math import math
import os
import subprocess
import time
import mlx.core as mx import mlx.core as mx
import numpy as np from time_utils import time_fn
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"]) MAX_SEQ = 300
device_name = device_name.decode("utf-8").strip("\n") START_SEQ = 100
SEQ_INCREMENT = 50
N_warmup = 5
N_iter_bench = 40
N_iter_func = 8
def bench(f, *args): def time_self_attention_primitives():
for i in range(N_warmup): mx.random.seed(3)
f(*args) B = 2
H = 38
D = 64
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
q = mx.random.uniform(shape=(B, H, R, D))
k = mx.random.uniform(shape=(B, H, R, D))
v = mx.random.uniform(shape=(B, H, R, D))
scale = 1.0 / math.sqrt(float(D))
mx.eval(q, k, v)
s = time.perf_counter_ns() def sdpa_primitives(qs, ks, vs, alpha):
for i in range(N_iter_bench): s = (alpha * qs) @ ks.transpose(0, 1, 3, 2)
f(*args) p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
e = time.perf_counter_ns() o = p @ vs
return (e - s) * 1e-9 return o
time_fn(sdpa_primitives, q, k, v, scale)
def mlx_sdpa_fused_inner(q, k, v, scale): def time_self_attention_sdpa():
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None) mx.random.seed(3)
B = 2
H = 38
D = 64
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
q = mx.random.uniform(shape=(B, H, R, D))
k = mx.random.uniform(shape=(B, H, R, D))
v = mx.random.uniform(shape=(B, H, R, D))
scale = 1.0 / math.sqrt(float(D))
mx.eval(q, k, v)
def sdpa_fused(qs, ks, vs, alpha):
o = mx.fast.scaled_dot_product_attention(qs, ks, vs, scale=alpha)
return o
def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False): time_fn(sdpa_fused, q, k, v, scale)
q_dtype = q.dtype
q = q * mx.array(scale, q_dtype)
n_q_heads = q.shape[-3]
n_kv_heads = k.shape[-3]
n_repeats = n_q_heads // n_kv_heads
B = q.shape[0]
L = q.shape[2]
if n_repeats > 1:
q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
k = mx.expand_dims(k, 2)
v = mx.expand_dims(v, 2)
scores = q @ mx.swapaxes(k, -1, -2)
if f32softmax:
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype)
else:
scores = mx.softmax(scores, axis=-1)
out = scores @ v
if n_repeats > 1:
out = mx.reshape(out, [B, n_q_heads, L, -1])
return out
def mlx_spda_unfused(q, k, v, scale, transpose):
q_out = q
if transpose:
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
for i in range(N_iter_func):
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
mx.eval(q_out)
return q_out
def mlx_spda_fused(q, k, v, scale, transpose):
q_out = q
if transpose:
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
for i in range(N_iter_func):
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
mx.eval(q_out)
return q_out
def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
shape_q = (
(B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
)
shape_kv = (
(B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
)
q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype)
k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
scale = math.sqrt(1.0 / head_dim)
q_mx = mx.array(q_np)
k_mx = mx.array(k_np)
v_mx = mx.array(v_np)
time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose)
time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose)
if transpose:
q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
v_mx = mx.transpose(v_mx, (0, 2, 1, 3))
o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)
atol = 1e-5 if np_dtype == np.float32 else 1e-4
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
print(
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
)
return time_mlx_fused, time_mlx_unfused
def get_gflop_count(B, M, N, K):
return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run gemm benchmarks") parser = argparse.ArgumentParser("MLX benchmarks.")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
args = parser.parse_args()
if args.gpu:
mx.set_default_device(mx.gpu)
else:
mx.set_default_device(mx.cpu)
dtypes = ("float16", "float32")[:1] time_self_attention_sdpa()
transposes = (False,) time_self_attention_primitives()
# fmt: off
shapes_64 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 32, 32, 64, 32, 32),
( 1, 64, 64, 64, 32, 32),
( 1, 128, 128, 64, 32, 32),
( 1, 256, 256, 64, 32, 32),
( 1, 512, 512, 64, 32, 32),
( 1, 1024, 1024, 64, 32, 32),
( 1, 2048, 2048, 64, 32, 32),
( 1, 4096, 4096, 64, 32, 32),
)
shapes_80 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 80, 32, 32),
( 1, 2048, 2048, 80, 32, 32),
( 1, 4096, 4096, 80, 32, 32),
)
shapes_128 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 128, 32, 32),
( 1, 2048, 2048, 128, 32, 32),
( 1, 4096, 4096, 128, 32, 32),
)
# fmt: on
shapes = shapes_64 + shapes_80 + shapes_128
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")
for dtype in dtypes:
for transpose in transposes:
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
np_dtype = getattr(np, dtype)
time_mlx_fused, time_mlx_unfused = bench_shape(
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose
)
diff = time_mlx_unfused / time_mlx_fused - 1.0
t_str = 1 if transpose else 0
print(
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
)

View File

@@ -4,51 +4,42 @@ import math
import mlx.core as mx import mlx.core as mx
from time_utils import time_fn from time_utils import time_fn
L = 16384 L = 1024
H = 32 H = 32
H_k = H // 4 H_k = 32 // 4
D = 128 D = 128
dtype = mx.float16
loops = 10
def attention(q, k, v): def attention(q, k, v):
def _sdpa(q, k, v): B, Hq, L, D = q.shape
B, Hq, L, D = q.shape _, Hk, S, _ = k.shape
_, Hk, S, _ = k.shape q = q.reshape(B, Hk, Hq // Hk, L, D)
q = q.reshape(B, Hk, Hq // Hk, L, D) k = k[:, :, None, :, :]
k = k[:, :, None, :, :] v = v[:, :, None, :, :]
v = v[:, :, None, :, :] s = q @ k.transpose(0, 1, 2, 4, 3)
s = q @ k.transpose(0, 1, 2, 4, 3) p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype) o = p @ v
o = p @ v return o.reshape(B, Hq, L, D)
return o.reshape(B, Hq, L, D)
for i in range(loops):
q = _sdpa(q, k, v)
return q
def sdpa(q, k, v): def sdpa(q, k, v):
for i in range(loops): return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
return q
def time_self_attention_primitives(): def time_self_attention_primitives():
mx.random.seed(3) mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype) q = mx.random.uniform(shape=(1, H, 1, D))
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) k = mx.random.uniform(shape=(1, H_k, L, D))
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) v = mx.random.uniform(shape=(1, H_k, L, D))
mx.eval(q, k, v) mx.eval(q, k, v)
time_fn(attention, q, k, v) time_fn(attention, q, k, v)
def time_self_attention_sdpa(): def time_self_attention_sdpa():
mx.random.seed(3) mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype) q = mx.random.uniform(shape=(1, H, 1, D))
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) k = mx.random.uniform(shape=(1, H_k, L, D))
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) v = mx.random.uniform(shape=(1, H_k, L, D))
mx.eval(q, k, v) mx.eval(q, k, v)
time_fn(sdpa, q, k, v) time_fn(sdpa, q, k, v)

View File

@@ -494,7 +494,7 @@ below.
 // Prepare to encode kernel
 auto& compute_encoder = d.get_command_encoder(s.index);
-compute_encoder.set_compute_pipeline_state(kernel);
+compute_encoder->setComputePipelineState(kernel);
 // Kernel parameters are registered with buffer indices corresponding to
 // those in the kernel declaration at axpby.metal
@@ -509,14 +509,14 @@ below.
 compute_encoder.set_output_array(out, 2);
 // Encode alpha and beta
-compute_encoder.set_bytes(alpha_, 3);
-compute_encoder.set_bytes(beta_, 4);
+compute_encoder->setBytes(&alpha_, sizeof(float), 3);
+compute_encoder->setBytes(&beta_, sizeof(float), 4);
 // Encode shape, strides and ndim
-compute_encoder.set_vector_bytes(x.shape(), 5);
-compute_encoder.set_vector_bytes(x.strides(), 6);
-compute_encoder.set_bytes(y.strides(), 7);
-compute_encoder.set_bytes(ndim, 8);
+compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
+compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
+compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
+compute_encoder->setBytes(&ndim, sizeof(int), 8);
 // We launch 1 thread for each input and make sure that the number of
 // threads in any given threadgroup is not higher than the max allowed
@@ -530,7 +530,7 @@ below.
 // Launch the grid with the given number of threads divided among
 // the given threadgroups
-compute_encoder.dispatch_threads(grid_dims, group_dims);
+compute_encoder.dispatchThreads(grid_dims, group_dims);
 }
 We can now call the :meth:`axpby` operation on both the CPU and the GPU!
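Once the extension documented above is built, the same axpby primitive can be exercised from Python on either device. The following is a minimal sketch only: the package name mlx_sample_extensions and the exact axpby(x, y, alpha, beta, stream=...) signature are assumptions based on the custom-extension walkthrough, not part of this diff.

import mlx.core as mx
# Assumed name for the built example-extension package.
from mlx_sample_extensions import axpby

x = mx.ones((3, 4))
y = mx.ones((3, 4))

# The same primitive dispatches to the CPU or the Metal GPU back-end
# depending on the stream/device it is evaluated on.
out_gpu = axpby(x, y, 2.0, 0.5, stream=mx.gpu)
out_cpu = axpby(x, y, 2.0, 0.5, stream=mx.cpu)
mx.eval(out_gpu, out_cpu)  # both equal alpha * x + beta * y = 2.5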

View File

@@ -209,7 +209,7 @@ Metal library by run-time compiling kernels the first time they are used in MLX
 on a given machine. Note run-time compilation incurs a cold-start cost which can
 be anwywhere from a few hundred millisecond to a few seconds depending on the
 application. Once a kernel is compiled, it will be cached by the system. The
-Metal kernel cache persists across reboots.
+Metal kernel cache persists accross reboots.
 Troubleshooting
 ^^^^^^^^^^^^^^^
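The install-docs passage above says run-time compilation has a one-time cold-start cost and that compiled kernels are cached afterwards. A rough way to observe this from Python is to time the first use of an op against a repeat use; this sketch assumes a JIT-enabled build (MLX_METAL_JIT=ON, the option shown in the CMake hunk above), and the absolute numbers are machine dependent.

import time
import mlx.core as mx

x = mx.random.uniform(shape=(1024, 1024))
mx.eval(x)

# The first use of an op may trigger run-time kernel compilation (cold start).
t0 = time.perf_counter()
mx.eval(mx.softmax(x, axis=-1))
t1 = time.perf_counter()

# A repeat use should hit the compiled-kernel cache and return much faster.
mx.eval(mx.softmax(x, axis=-1))
t2 = time.perf_counter()

print(f"first call: {t1 - t0:.3f}s, repeat call: {t2 - t1:.3f}s")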

View File

@@ -12,4 +12,5 @@ Fast
 layer_norm
 rope
 scaled_dot_product_attention
+affine_quantize
 metal_kernel

View File

@@ -12,7 +12,6 @@ Layers
 ALiBi
 AvgPool1d
 AvgPool2d
-AvgPool3d
 BatchNorm
 CELU
 Conv1d
@@ -42,7 +41,6 @@ Layers
 LSTM
 MaxPool1d
 MaxPool2d
-MaxPool3d
 Mish
 MultiHeadAttention
 PReLU

View File

@@ -184,8 +184,8 @@ Let's time these two different versions:
 print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
 print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))
-On an M1 Max the naive version takes in total ``5.639`` seconds whereas the
-vectorized version takes only ``0.024`` seconds, more than 200 times faster.
+On an M1 Max the naive version takes in total ``0.390`` seconds whereas the
+vectorized version takes only ``0.025`` seconds, more than ten times faster.
 Of course, this operation is quite contrived. A better approach is to simply do
 ``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.
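For context, the naive_add and vmap_add functions timed above are defined earlier on that docs page and are not part of this hunk; the sketch below is a plausible reconstruction with illustrative shapes and in_axes, not the verbatim file content.

import mlx.core as mx

xs = mx.random.uniform(shape=(4096, 100))
ys = mx.random.uniform(shape=(100, 4096))

def naive_add(xs, ys):
    # Add row i of xs to column i of ys, one pair at a time.
    return [xs[i] + ys[:, i] for i in range(xs.shape[0])]

# Vectorize the elementwise add over axis 0 of xs and axis 1 of ys.
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))

Both versions are lazy, which is why the timings above wrap each call in mx.eval.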

View File

@@ -257,7 +257,7 @@ void Axpby::eval_gpu(
 // Prepare to encode kernel
 auto& compute_encoder = d.get_command_encoder(s.index);
-compute_encoder.set_compute_pipeline_state(kernel);
+compute_encoder->setComputePipelineState(kernel);
 // Kernel parameters are registered with buffer indices corresponding to
 // those in the kernel declaration at axpby.metal
@@ -272,15 +272,15 @@ void Axpby::eval_gpu(
 compute_encoder.set_output_array(out, 2);
 // Encode alpha and beta
-compute_encoder.set_bytes(alpha_, 3);
-compute_encoder.set_bytes(beta_, 4);
+compute_encoder->setBytes(&alpha_, sizeof(float), 3);
+compute_encoder->setBytes(&beta_, sizeof(float), 4);
 // Encode shape, strides and ndim if needed
 if (!contiguous_kernel) {
-compute_encoder.set_vector_bytes(x.shape(), 5);
-compute_encoder.set_vector_bytes(x.strides(), 6);
-compute_encoder.set_bytes(y.strides(), 7);
-compute_encoder.set_bytes(ndim, 8);
+compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
+compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
+compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
+compute_encoder->setBytes(&ndim, sizeof(int), 8);
 }
 // We launch 1 thread for each input and make sure that the number of
@@ -295,7 +295,7 @@ void Axpby::eval_gpu(
 // Launch the grid with the given number of threads divided among
 // the given threadgroups
-compute_encoder.dispatch_threads(grid_dims, group_dims);
+compute_encoder.dispatchThreads(grid_dims, group_dims);
 }
 #else // Metal is not available

View File

@@ -2,6 +2,7 @@
 #include <metal_stdlib>
+#include "mlx/backend/metal/kernels/bf16.h"
 #include "mlx/backend/metal/kernels/utils.h"
 template <typename T>
@@ -59,4 +60,4 @@ template <typename T>
 instantiate_axpby(float32, float);
 instantiate_axpby(float16, half);
 instantiate_axpby(bfloat16, bfloat16_t);
 instantiate_axpby(complex64, complex64_t);

View File

@@ -1,4 +1,4 @@
 setuptools>=42
 cmake>=3.24
-mlx>=0.21.0
+mlx>=0.18.1
 nanobind==2.2.0

View File

@@ -28,19 +28,10 @@ endif()
 if (@MLX_BUILD_METAL@)
 set(MLX_BUILD_METAL @MLX_BUILD_METAL@)
 set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_)
-set(MLX_INCLUDE_DIRS
-"${MLX_INCLUDE_DIRS};"
+set_and_check(MLX_INCLUDE_DIRS
+${MLX_INCLUDE_DIRS}
 @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp
 )
-if(@MLX_METAL_VERSION@ GREATER_EQUAL 310)
-set(MLX_INCLUDE_DIRS
-"${MLX_INCLUDE_DIRS};"
-@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_1)
-else()
-set(MLX_INCLUDE_DIRS
-"${MLX_INCLUDE_DIRS};"
-@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_0)
-endif()
 endif()
 set_target_properties(mlx PROPERTIES
@@ -49,4 +40,4 @@ set_target_properties(mlx PROPERTIES
 )
 include(FindPackageHandleStandardArgs)
 find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)

View File

@@ -19,7 +19,7 @@ Buffer malloc(size_t size) {
 }
 void free(Buffer buffer) {
-allocator().free(buffer);
+return allocator().free(buffer);
 }
 Buffer CommonAllocator::malloc(size_t size, bool) {

View File

@@ -214,8 +214,6 @@ array::~array() {
if (do_detach) { if (do_detach) {
for (auto& s : siblings()) { for (auto& s : siblings()) {
for (auto& ss : s.siblings()) { for (auto& ss : s.siblings()) {
// Set to null here to avoid descending into array destructor
// for siblings
ss.array_desc_ = nullptr; ss.array_desc_ = nullptr;
} }
s.array_desc_->siblings.clear(); s.array_desc_->siblings.clear();
@@ -294,14 +292,6 @@ array::ArrayDesc::~ArrayDesc() {
auto top = std::move(for_deletion.back()); auto top = std::move(for_deletion.back());
for_deletion.pop_back(); for_deletion.pop_back();
append_deletable_inputs(*top); append_deletable_inputs(*top);
// Clear out possible siblings to break circular references
for (auto& s : top->siblings) {
// Set to null here to avoid descending into top-level
// array destructor for siblings
s.array_desc_ = nullptr;
}
top->siblings.clear();
} }
} }

View File

@@ -39,7 +39,7 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
// rely on data_size anyway. // rely on data_size anyway.
size_t data_size = out.size(); size_t data_size = out.size();
return move_or_copy(in, out, strides_, flags, data_size, offset_); return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
} }
void Broadcast::eval(const std::vector<array>& inputs, array& out) { void Broadcast::eval(const std::vector<array>& inputs, array& out) {
@@ -58,12 +58,12 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
if (out.size() > in.size()) { if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false; flags.row_contiguous = flags.col_contiguous = false;
} }
move_or_copy(in, out, strides, flags, in.data_size()); out.copy_shared_buffer(in, strides, flags, in.data_size());
} }
void Copy::eval(const std::vector<array>& inputs, array& out) { void Copy::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
move_or_copy(inputs[0], out); out.copy_shared_buffer(inputs[0]);
} }
void CustomTransforms::eval( void CustomTransforms::eval(
@@ -72,7 +72,7 @@ void CustomTransforms::eval(
assert(inputs.size() > outputs.size()); assert(inputs.size() > outputs.size());
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size(); for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
i++, j++) { i++, j++) {
move_or_copy(inputs[j], outputs[i]); outputs[i].copy_shared_buffer(inputs[j]);
} }
} }
@@ -81,7 +81,7 @@ void Depends::eval(
std::vector<array>& outputs) { std::vector<array>& outputs) {
assert(inputs.size() > outputs.size()); assert(inputs.size() > outputs.size());
for (int i = 0; i < outputs.size(); i++) { for (int i = 0; i < outputs.size(); i++) {
move_or_copy(inputs[i], outputs[i]); outputs[i].copy_shared_buffer(inputs[i]);
} }
} }
@@ -194,7 +194,7 @@ void Reshape::shared_buffer_reshape(
auto max_dim = std::max_element(out.shape().begin(), out.shape().end()); auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim; flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
} }
move_or_copy(in, out, out_strides, flags, in.data_size()); out.copy_shared_buffer(in, out_strides, flags, in.data_size());
} }
void Split::eval( void Split::eval(
@@ -263,7 +263,7 @@ std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
void StopGradient::eval(const std::vector<array>& inputs, array& out) { void StopGradient::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
move_or_copy(inputs[0], out); out.copy_shared_buffer(inputs[0]);
} }
void Transpose::eval(const std::vector<array>& inputs, array& out) { void Transpose::eval(const std::vector<array>& inputs, array& out) {
@@ -297,7 +297,7 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
b_stride *= out.shape(ri); b_stride *= out.shape(ri);
} }
} }
move_or_copy(in, out, out_strides, flags, in.data_size()); out.copy_shared_buffer(in, out_strides, flags, in.data_size());
} }
} // namespace mlx::core } // namespace mlx::core

View File

@@ -279,7 +279,7 @@ void Compiled::eval_cpu(
 // Figure out which kernel we are using
 auto& shape = outputs[0].shape();
-auto contiguous = compiled_check_contiguity(inputs, shape);
+bool contiguous = compiled_check_contiguity(inputs, shape);
 // Handle all broadcasting and collect function input arguments
 std::vector<void*> args;

View File

@@ -159,17 +159,6 @@ void Conjugate::eval(const std::vector<array>& inputs, array& out) {
} }
} }
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous)) {
out.copy_shared_buffer(in);
} else {
copy(in, out, CopyType::General);
}
}
void Cos::eval(const std::vector<array>& inputs, array& out) { void Cos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
@@ -617,7 +606,7 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
if (ibytes == obytes || obytes < ibytes && in.strides().back() == 1 || if (ibytes == obytes || obytes < ibytes && in.strides().back() == 1 ||
in.flags().row_contiguous) { in.flags().row_contiguous) {
auto strides = in.strides(); auto strides = in.strides();
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) { for (int i = 0; i < strides.size() - 1; ++i) {
strides[i] *= ibytes; strides[i] *= ibytes;
strides[i] /= obytes; strides[i] /= obytes;
} }

View File

@@ -2,38 +2,13 @@
#include <cassert> #include <cassert>
#include "mlx/backend/common/copy.h" #include "mlx/backend/metal/copy.h"
#include "mlx/backend/common/ops.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core { namespace mlx::core {
namespace { namespace {
template <typename T, int bits>
void extract_bits(const uint8_t* w_in, T* w_out) {
assert(bits == 3 || bits == 6);
if (bits == 3) {
w_out[0] = static_cast<T>(w_in[0] & 0x7);
w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
w_out[2] = static_cast<T>(((w_in[0] & 0xc0) >> 6) + ((w_in[1] & 0x1) << 2));
w_out[3] = static_cast<T>((w_in[1] & 0xe) >> 1);
w_out[4] = static_cast<T>((w_in[1] & 0x70) >> 4);
w_out[5] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1));
w_out[6] = static_cast<T>((w_in[2] & 0x1c) >> 2);
w_out[7] = static_cast<T>((w_in[2] & 0xe0) >> 5);
} else if (bits == 6) {
w_out[0] = static_cast<T>(w_in[0] & 0x3f);
w_out[1] =
static_cast<T>(((w_in[0] >> 6) & 0x03) + ((w_in[1] & 0x0f) << 2));
w_out[2] =
static_cast<T>(((w_in[1] >> 4) & 0x0f) + ((w_in[2] & 0x03) << 4));
w_out[3] = static_cast<T>((w_in[2] >> 2) & 0x3f);
}
}
template <typename T, int bits, int group_size> template <typename T, int bits, int group_size>
void _qmm( void _qmm(
T* result, T* result,
@@ -45,12 +20,13 @@ void _qmm(
int N, int N,
int K) { int K) {
constexpr int bitmask = (1 << bits) - 1; constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; constexpr int pack_factor = 32 / bits;
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
constexpr int packs_in_group = group_size / pack_factor; constexpr int packs_in_group = group_size / pack_factor;
const int Ng = N / group_size;
const int Nw = N / pack_factor;
for (int m = 0; m < M; m++) { for (int m = 0; m < M; m++) {
const uint8_t* w_local = (const uint8_t*)w; const uint32_t* w_local = w;
const T* scales_local = scales; const T* scales_local = scales;
const T* biases_local = biases; const T* biases_local = biases;
@@ -64,25 +40,13 @@ void _qmm(
T scale = *scales_local++; T scale = *scales_local++;
T bias = *biases_local++; T bias = *biases_local++;
for (int ng = 0; ng < packs_in_group; ng++) { for (int ng = 0; ng < packs_in_group; ng++) {
if (bits == 3 || bits == 6) { uint32_t wi = *w_local++;
T wl[pack_factor];
extract_bits<T, bits>(w_local, wl);
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
(*result_local++) += xi * (scale * wl[p] + bias);
}
w_local += bytes_per_pack;
} else {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full) #pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) { for (int p = 0; p < pack_factor; p++) {
(*result_local++) += (*result_local++) +=
xi * (scale * static_cast<T>(wi & bitmask) + bias); xi * (scale * static_cast<T>(wi & bitmask) + bias);
if (bits != 8) { wi >>= bits;
wi >>= bits;
}
}
} }
} }
} }
@@ -103,12 +67,13 @@ void _qmm_t(
int N, int N,
int K) { int K) {
constexpr int bitmask = (1 << bits) - 1; constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; constexpr int pack_factor = 32 / bits;
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
constexpr int packs_in_group = group_size / pack_factor; constexpr int packs_in_group = group_size / pack_factor;
const int Kg = K / group_size;
const int Kw = K / pack_factor;
for (int m = 0; m < M; m++) { for (int m = 0; m < M; m++) {
const uint8_t* w_local = (const uint8_t*)w; const uint32_t* w_local = w;
const T* scales_local = scales; const T* scales_local = scales;
const T* biases_local = biases; const T* biases_local = biases;
@@ -120,26 +85,12 @@ void _qmm_t(
T bias = *biases_local++; T bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw++) { for (int kw = 0; kw < packs_in_group; kw++) {
if (bits == 3 || bits == 6) { uint32_t wi = *w_local++;
T wl[pack_factor];
extract_bits<T, bits>(w_local, wl);
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
sum += x_local[p] * (scale * wl[p] + bias);
}
w_local += bytes_per_pack;
x_local += pack_factor;
} else {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full) #pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) { for (int p = 0; p < pack_factor; p++) {
sum += sum += (*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
(*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias); wi >>= bits;
if (bits != 8) {
wi >>= bits;
}
}
} }
} }
} }
@@ -151,55 +102,6 @@ void _qmm_t(
} }
} }
template <typename T, int bits, int group_size>
void _qmm_dispatch_transpose(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K,
bool transposed_w) {
if (transposed_w) {
return _qmm_t<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
}
}
template <typename T, int bits>
void _qmm_dispatch_group(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K,
int group_size,
bool transposed_w) {
switch (group_size) {
case 32:
_qmm_dispatch_transpose<T, bits, 32>(
result, x, w, scales, biases, M, N, K, transposed_w);
break;
case 64:
_qmm_dispatch_transpose<T, bits, 64>(
result, x, w, scales, biases, M, N, K, transposed_w);
break;
case 128:
_qmm_dispatch_transpose<T, bits, 128>(
result, x, w, scales, biases, M, N, K, transposed_w);
break;
default:
throw std::invalid_argument(
"Quantization group size must be 32, 64 or 128.");
}
}
template <typename T> template <typename T>
void _qmm_dispatch_typed( void _qmm_dispatch_typed(
T* result, T* result,
@@ -214,29 +116,79 @@ void _qmm_dispatch_typed(
int bits, int bits,
bool transposed_w) { bool transposed_w) {
switch (bits) { switch (bits) {
case 2: case 2: {
_qmm_dispatch_group<T, 2>( switch (group_size) {
result, x, w, scales, biases, M, N, K, group_size, transposed_w); case 32:
break; if (transposed_w) {
case 3: return _qmm_t<T, 2, 32>(result, x, w, scales, biases, M, N, K);
_qmm_dispatch_group<T, 3>( } else {
result, x, w, scales, biases, M, N, K, group_size, transposed_w); return _qmm<T, 2, 32>(result, x, w, scales, biases, M, N, K);
break; }
case 4: case 64:
_qmm_dispatch_group<T, 4>( if (transposed_w) {
result, x, w, scales, biases, M, N, K, group_size, transposed_w); return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
break; } else {
case 6: return _qmm<T, 2, 64>(result, x, w, scales, biases, M, N, K);
_qmm_dispatch_group<T, 6>( }
result, x, w, scales, biases, M, N, K, group_size, transposed_w); case 128:
break; if (transposed_w) {
case 8: return _qmm_t<T, 2, 128>(result, x, w, scales, biases, M, N, K);
_qmm_dispatch_group<T, 8>( } else {
result, x, w, scales, biases, M, N, K, group_size, transposed_w); return _qmm<T, 2, 128>(result, x, w, scales, biases, M, N, K);
break; }
default: }
throw std::invalid_argument("Quantization bits must be 2, 3, 4, 6 or 8."); }
case 4: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 4, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 64>(result, x, w, scales, biases, M, N, K);
}
case 128:
if (transposed_w) {
return _qmm_t<T, 4, 128>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 128>(result, x, w, scales, biases, M, N, K);
}
}
}
case 8: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 8, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 64>(result, x, w, scales, biases, M, N, K);
}
case 128:
if (transposed_w) {
return _qmm_t<T, 8, 128>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 128>(result, x, w, scales, biases, M, N, K);
}
}
}
} }
std::ostringstream msg;
msg << "Quantization type not supported. Provided bits=" << bits
<< " and group_size=" << group_size
<< ". The supported options are bits in "
<< "{2, 4, 8} and group_size in {64, 128}.";
throw std::invalid_argument(msg.str());
} }
void _qmm_dispatch( void _qmm_dispatch(
@@ -452,114 +404,4 @@ void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
transpose_); transpose_);
} }
template <typename T, typename U>
void quantize(
const array& w_,
array& out_,
array& scales_,
array& biases_,
int bits,
int group_size) {
const T* w = w_.data<T>();
auto out = out_.data<U>();
T* scales = scales_.data<T>();
T* biases = biases_.data<T>();
T n_bins = (1 << bits) - 1;
T eps = 1e-7;
bool power_of_2_bits = is_power_of_2(bits);
int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
// For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
int bytes_per_pack = power_of_2_bits ? 1 : 3;
int int_per_group = group_size * bytes_per_pack / el_per_int;
size_t n_groups = w_.size() / group_size;
for (size_t i = 0; i < n_groups; ++i) {
size_t w_idx = i * group_size;
T w_min = std::numeric_limits<float>::infinity();
T w_max = -w_min;
for (int j = 0; j < group_size; ++j) {
w_max = std::max(w_max, w[w_idx + j]);
w_min = std::min(w_min, w[w_idx + j]);
}
bool mask = std::abs(w_min) > std::abs(w_max);
T scale = std::max(T((w_max - w_min) / n_bins), eps);
scale = mask ? scale : -scale;
auto edge = mask ? w_min : w_max;
auto q0 = std::rint(edge / scale);
if (q0 == 0) {
scales[i] = scale;
biases[i] = 0;
} else {
scales[i] = edge / q0;
biases[i] = edge;
}
size_t out_idx = i * int_per_group;
for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
uint32_t out_el = 0;
for (int k = 0; k < el_per_int; ++k) {
T w_el = w[w_idx + j * el_per_int + k];
w_el = std::rint((w_el - biases[i]) / scales[i]);
w_el = std::min(std::max(w_el, T(0)), n_bins);
out_el |= static_cast<uint32_t>(w_el) << (k * bits);
}
if (power_of_2_bits) {
out[out_idx + j] = out_el;
} else {
out[out_idx + bytes_per_pack * j] = out_el & 0xff;
out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
}
}
}
}
void fast::AffineQuantize::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto ensure_row_contiguous = [](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
}
};
auto w = ensure_row_contiguous(inputs[0]);
auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& scales = outputs[1];
auto& biases = outputs[2];
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
if (w.dtype() == float16) {
if (is_power_of_2(bits_)) {
quantize<float16_t, uint32_t>(w, out, scales, biases, bits_, group_size_);
} else {
quantize<float16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == bfloat16) {
if (is_power_of_2(bits_)) {
quantize<bfloat16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_);
} else {
quantize<bfloat16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == float32) {
if (is_power_of_2(bits_)) {
quantize<float, uint32_t>(w, out, scales, biases, bits_, group_size_);
} else {
quantize<float, uint8_t>(w, out, scales, biases, bits_, group_size_);
}
} else {
throw std::runtime_error(
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
}
}
} // namespace mlx::core } // namespace mlx::core

View File

@@ -120,53 +120,45 @@ struct MinReduce {
}; };
template <typename InT> template <typename InT>
void reduce_dispatch_and_or( void reduce_dispatch_out(
const array& in, const array& in,
array& out, array& out,
Reduce::ReduceType rtype, Reduce::ReduceType rtype,
const std::vector<int>& axes) { const std::vector<int>& axes) {
if (rtype == Reduce::And) { switch (rtype) {
reduction_op<InT, bool>(in, out, axes, true, AndReduce()); case Reduce::And: {
} else { reduction_op<InT, bool>(in, out, axes, true, AndReduce());
reduction_op<InT, bool>(in, out, axes, false, OrReduce()); break;
}
}
template <typename InT>
void reduce_dispatch_sum_prod(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
if (rtype == Reduce::Sum) {
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t>(in, out, axes, 0, op);
} else {
reduction_op<InT, InT>(in, out, axes, 0, op);
} }
} else { case Reduce::Or: {
auto op = [](auto y, auto x) { (*y) *= x; }; reduction_op<InT, bool>(in, out, axes, false, OrReduce());
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) { break;
reduction_op<InT, int32_t>(in, out, axes, 1, op); }
} else { case Reduce::Sum: {
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
if (out.dtype() == int32) {
// special case since the input type can be bool
reduction_op<InT, int32_t>(in, out, axes, 0, op);
} else {
reduction_op<InT, InT>(in, out, axes, 0, op);
}
break;
}
case Reduce::Prod: {
auto op = [](auto y, auto x) { (*y) *= x; };
reduction_op<InT, InT>(in, out, axes, 1, op); reduction_op<InT, InT>(in, out, axes, 1, op);
break;
}
case Reduce::Max: {
auto init = Limits<InT>::min;
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
break;
}
case Reduce::Min: {
auto init = Limits<InT>::max;
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
break;
} }
}
}
template <typename InT>
void reduce_dispatch_min_max(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
if (rtype == Reduce::Max) {
auto init = Limits<InT>::min;
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
} else {
auto init = Limits<InT>::max;
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
} }
} }
@@ -198,114 +190,46 @@ void nd_loop(
void Reduce::eval(const std::vector<array>& inputs, array& out) { void Reduce::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
switch (reduce_type_) { switch (in.dtype()) {
case Reduce::And: case bool_:
case Reduce::Or: { reduce_dispatch_out<bool>(in, out, reduce_type_, axes_);
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
case float16:
case bfloat16:
reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
case int32:
case float32:
reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
case int64:
case complex64:
reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
break;
}
break; break;
} case uint8:
case Reduce::Sum: reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_);
case Reduce::Prod: {
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
break;
case int32:
case uint32:
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
case uint64:
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
break;
}
break; break;
} case uint16:
case Reduce::Max: reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_);
case Reduce::Min: { break;
switch (in.dtype()) { case uint32:
case bool_: reduce_dispatch_out<uint32_t>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_); break;
break; case uint64:
case uint8: reduce_dispatch_out<uint64_t>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_); break;
break; case int8:
case uint16: reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_); break;
break; case int16:
case uint32: reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_); break;
break; case int32:
case uint64: reduce_dispatch_out<int32_t>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_); break;
break; case int64:
case int8: reduce_dispatch_out<int64_t>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_); break;
break; case float16:
case int16: reduce_dispatch_out<float16_t>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_); break;
break; case float32:
case int32: reduce_dispatch_out<float>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_); break;
break; case bfloat16:
case int64: reduce_dispatch_out<bfloat16_t>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_); break;
break; case complex64:
case float16: reduce_dispatch_out<complex64_t>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
break;
}
break; break;
}
} }
} }

View File

@@ -34,7 +34,7 @@ void shared_buffer_slice(
 flags.col_contiguous = is_col_contiguous;
 flags.contiguous = (no_bsx_size == data_size);
-move_or_copy(in, out, out_strides, flags, data_size, data_offset);
+out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
 }
 } // namespace mlx::core

View File

@@ -4,28 +4,6 @@
namespace mlx::core { namespace mlx::core {
void move_or_copy(const array& in, array& out) {
if (in.is_donatable()) {
out.move_shared_buffer(in);
} else {
out.copy_shared_buffer(in);
}
}
void move_or_copy(
const array& in,
array& out,
const std::vector<size_t>& strides,
array::Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
if (in.is_donatable()) {
out.move_shared_buffer(in, strides, flags, data_size, offset);
} else {
out.copy_shared_buffer(in, strides, flags, data_size, offset);
}
}
template <typename StrideT> template <typename StrideT>
std::tuple<std::vector<int>, std::vector<std::vector<StrideT>>> std::tuple<std::vector<int>, std::vector<std::vector<StrideT>>>
collapse_contiguous_dims_impl( collapse_contiguous_dims_impl(

View File

@@ -178,13 +178,4 @@ inline bool is_donatable(const array& in, const array& out) {
in.buffer_size() <= out.nbytes() + donation_extra; in.buffer_size() <= out.nbytes() + donation_extra;
} }
void move_or_copy(const array& in, array& out);
void move_or_copy(
const array& in,
array& out,
const std::vector<size_t>& strides,
array::Flags flags,
size_t data_size,
size_t offset = 0);
} // namespace mlx::core } // namespace mlx::core

View File

@@ -14,21 +14,14 @@ function(make_jit_source SRC_FILE)
COMMAND COMMAND
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh /bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/jit ${CMAKE_C_COMPILER} ${PROJECT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR}/jit ${CMAKE_C_COMPILER} ${PROJECT_SOURCE_DIR}
${SRC_FILE} ${SRC_FILE} "-DMLX_METAL_VERSION=${MLX_METAL_VERSION}"
DEPENDS make_compiled_preamble.sh kernels/${SRC_FILE}.h ${ARGN}) DEPENDS make_compiled_preamble.sh kernels/${SRC_FILE}.h ${ARGN})
add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp) add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
add_dependencies(mlx ${SRC_NAME}) add_dependencies(mlx ${SRC_NAME})
target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp) target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp)
endfunction(make_jit_source) endfunction(make_jit_source)
make_jit_source( make_jit_source(utils kernels/bf16.h kernels/complex.h kernels/defines.h)
utils
kernels/jit/bf16.h
kernels/metal_3_0/bf16.h
kernels/metal_3_1/bf16.h
kernels/bf16_math.h
kernels/complex.h
kernels/defines.h)
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h) make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
make_jit_source(binary_ops) make_jit_source(binary_ops)
make_jit_source(ternary_ops) make_jit_source(ternary_ops)

View File

@@ -242,9 +242,6 @@ void MetalAllocator::clear_cache() {
 void MetalAllocator::free(Buffer buffer) {
 auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
-if (buf == nullptr) {
-return;
-}
 std::unique_lock lk(mutex_);
 residency_set_.erase(buf);
 active_memory_ -= buf->length();

View File

@@ -22,37 +22,37 @@ std::string get_kernel_name(
BinaryOpType bopt, BinaryOpType bopt,
const std::string& op, const std::string& op,
const array& a, const array& a,
bool large, bool use_2d,
int ndim, int ndim,
int work_per_thread) { int work_per_thread) {
std::string kname; std::ostringstream kname;
switch (bopt) { switch (bopt) {
case BinaryOpType::ScalarScalar: case BinaryOpType::ScalarScalar:
kname = "ss"; kname << "ss";
break; break;
case BinaryOpType::ScalarVector: case BinaryOpType::ScalarVector:
kname = (large ? "sv2" : "sv"); kname << (use_2d ? "sv2" : "sv");
break; break;
case BinaryOpType::VectorScalar: case BinaryOpType::VectorScalar:
kname = (large ? "vs2" : "vs"); kname << (use_2d ? "vs2" : "vs");
break; break;
case BinaryOpType::VectorVector: case BinaryOpType::VectorVector:
kname = (large ? "vv2" : "vv"); kname << (use_2d ? "vv2" : "vv");
break; break;
case BinaryOpType::General: case BinaryOpType::General:
kname = "g"; kname << "g";
if (ndim <= 3) { if (ndim <= 3) {
kname += std::to_string(ndim); kname << ndim;
} else { } else {
concatenate(kname, "n", std::to_string(work_per_thread)); kname << "n";
} if (work_per_thread > 1) {
if (large) { kname << work_per_thread;
kname += "large"; }
} }
break; break;
} }
concatenate(kname, "_", op, type_to_name(a)); kname << "_" << op << type_to_name(a);
return kname; return kname.str();
} }
void binary_op_gpu_inplace( void binary_op_gpu_inplace(
@@ -81,23 +81,18 @@ void binary_op_gpu_inplace(
}; };
auto [shape, strides_a, strides_b, strides_out] = maybe_collapse(); auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();
bool large = out.data_size() > UINT32_MAX; bool use_2d = out.data_size() > UINT32_MAX;
auto ndim = shape.size(); auto ndim = shape.size();
int work_per_thread; int work_per_thread = (bopt == BinaryOpType::General) ? 4 : 1;
if (bopt == BinaryOpType::General) {
work_per_thread = large ? 4 : 2;
} else {
work_per_thread = 1;
}
std::string kernel_name = std::string kernel_name =
get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread); get_kernel_name(bopt, op, a, use_2d, shape.size(), work_per_thread);
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
auto kernel = outputs.size() == 2 auto kernel = outputs.size() == 2
? get_binary_two_kernel(d, kernel_name, a.dtype(), out.dtype(), op) ? get_binary_two_kernel(d, kernel_name, a.dtype(), out.dtype(), op)
: get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op); : get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
// - If a is donated it goes to the first output // - If a is donated it goes to the first output
// - If b is donated it goes to the first output if a was not donated // - If b is donated it goes to the first output if a was not donated
@@ -122,15 +117,19 @@ void binary_op_gpu_inplace(
size_t rest = out.size() / (dim0 * dim1); size_t rest = out.size() / (dim0 * dim1);
if (ndim > 3) { if (ndim > 3) {
compute_encoder.set_vector_bytes(shape, arg_idx++); compute_encoder->setBytes(shape.data(), ndim * sizeof(int), arg_idx++);
compute_encoder.set_vector_bytes(strides_a, arg_idx++); compute_encoder->setBytes(
compute_encoder.set_vector_bytes(strides_b, arg_idx++); strides_a.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder.set_bytes<int>(ndim, arg_idx++); compute_encoder->setBytes(
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder->setBytes(&ndim, sizeof(int), arg_idx++);
dim0 = (dim0 + work_per_thread - 1) / work_per_thread; dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
} else { } else {
// The shape is implicit in the grid for <= 3D // The shape is implicit in the grid for <= 3D
compute_encoder.set_vector_bytes(strides_a, arg_idx++); compute_encoder->setBytes(
compute_encoder.set_vector_bytes(strides_b, arg_idx++); strides_a.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder->setBytes(
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
} }
if (thread_group_size != 1024) { if (thread_group_size != 1024) {
@@ -138,7 +137,7 @@ void binary_op_gpu_inplace(
} }
auto group_dims = get_block_dims(dim0, dim1, rest); auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} else { } else {
// Launch a 1D or 2D grid of threads // Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size(); size_t nthreads = out.data_size();
@@ -146,9 +145,9 @@ void binary_op_gpu_inplace(
thread_group_size = nthreads; thread_group_size = nthreads;
} }
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1); : MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
} }

View File

@@ -1,6 +1,5 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <fmt/format.h>
#include <iostream> //TODO
#include <sstream> #include <sstream>
#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled.h"
@@ -12,12 +11,12 @@
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/utils.h" #include "mlx/utils.h"
using namespace fmt::literals;
namespace mlx::core { namespace mlx::core {
constexpr int WORK_PER_THREAD = 4;
inline void build_kernel( inline void build_kernel(
std::string& os, std::ostream& os,
const std::string& kernel_name, const std::string& kernel_name,
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<array>& outputs, const std::vector<array>& outputs,
@@ -42,8 +41,8 @@ inline void build_kernel(
int cnt = 0; int cnt = 0;
// Start the kernel // Start the kernel
os += fmt::format( os << "[[host_name(\"" << kernel_name << "\")]]\n"
"[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name); << "[[kernel]] void " << kernel_name << "(\n";
// Add the input arguments // Add the input arguments
for (auto& x : inputs) { for (auto& x : inputs) {
@@ -55,61 +54,51 @@ inline void build_kernel(
} }
// Scalars and contiguous need no strides // Scalars and contiguous need no strides
if (!is_scalar(x) && !contiguous) { if (is_scalar(x) || contiguous) {
os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]],\n";
} else {
add_indices = true; add_indices = true;
os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]],\n";
} }
os += fmt::format(
" device const {0}* {1} [[buffer({2})]],\n",
get_type_string(x.dtype()),
xname,
cnt++);
} }
if (add_indices) { if (add_indices) {
os += fmt::format( os << " constant const size_t* in_strides [[buffer(" << cnt++
" constant const size_t* in_strides [[buffer({0})]],\n", cnt++); << ")]],\n";
} }
// Add the output arguments // Add the output arguments
for (auto& x : outputs) { for (auto& x : outputs) {
os += fmt::format( os << " device " << get_type_string(x.dtype()) << "* "
" device {0}* {1} [[buffer({2})]],\n", << namer.get_name(x) << " [[buffer(" << cnt++ << ")]],\n";
get_type_string(x.dtype()),
namer.get_name(x),
cnt++);
} }
// Add output strides and shape to extract the indices. // Add output strides and shape to extract the indices.
if (!contiguous) { if (!contiguous) {
os += fmt::format( os << " constant const size_t* output_strides [[buffer(" << cnt++
" constant const size_t* output_strides [[buffer({0})]],\n", cnt++); << ")]],\n"
os += fmt::format( << " constant const int* output_shape [[buffer(" << cnt++ << ")]],\n";
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
} }
if (dynamic_dims) { if (dynamic_dims) {
os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++); os << " constant const int& ndim [[buffer(" << cnt++ << ")]],\n";
} }
// The thread index in the whole grid // The thread index in the whole grid
os += " uint3 pos [[thread_position_in_grid]],\n"; os << " uint3 pos [[thread_position_in_grid]],\n"
os += " uint3 grid [[threads_per_grid]]) {\n"; << " uint3 grid [[threads_per_grid]]) {\n";
std::string idx_type = use_big_index ? "size_t" : "uint"; if (use_big_index) {
if (contiguous && use_big_index) {
// This is only used for contiguous kernels which don't have // This is only used for contiguous kernels which don't have
// a third grid dimension // a third grid dimension
os += " size_t index = pos.x + grid.x * size_t(pos.y);\n"; os << " size_t index = pos.x + grid.x * size_t(pos.y);\n";
} else if (work_per_thread > 1) { } else if (work_per_thread > 1) {
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread); os << " constexpr int N_ = " << std::to_string(work_per_thread) << ";\n"
os += fmt::format( << " int xshape = output_shape["
" int xshape = output_shape[{0}];\n", << (dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)) << "];\n"
dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)); << " size_t index = N_ * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n";
os += fmt::format(
" {0} index = N_ * pos.x + xshape * (pos.y + {0}(grid.y) * pos.z);\n",
idx_type);
} else { } else {
os += fmt::format( os << " size_t index = pos.x + grid.x * (pos.y + size_t(grid.y) * pos.z);\n";
" {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
idx_type);
} }
// Read constant / contiguous inputs in tmps // Read constant / contiguous inputs in tmps
@@ -120,19 +109,16 @@ inline void build_kernel(
if (is_constant(x)) { if (is_constant(x)) {
auto type_str = get_type_string(x.dtype()); auto type_str = get_type_string(x.dtype());
std::ostringstream ss; os << " auto tmp_" << xname << " = static_cast<"
print_constant(ss, x); << get_type_string(x.dtype()) << ">(";
os += fmt::format( print_constant(os, x);
" auto tmp_{0} = static_cast<{1}>({2});\n", os << ");\n";
xname,
get_type_string(x.dtype()),
ss.str());
} else if (is_scalar(x)) { } else if (is_scalar(x)) {
os += fmt::format( os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
" {0} tmp_{1} = {1}[0];\n", get_type_string(x.dtype()), xname); << xname << "[0];\n";
} else if (contiguous) { } else if (contiguous) {
os += fmt::format( os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
" {0} tmp_{1} = {1}[index];\n", get_type_string(x.dtype()), xname); << xname << "[index];\n";
} else { } else {
nc_inputs.push_back(x); nc_inputs.push_back(x);
} }
@@ -141,98 +127,83 @@ inline void build_kernel(
// Initialize the indices for non-contiguous inputs // Initialize the indices for non-contiguous inputs
for (int i = 0; i < nc_inputs.size(); ++i) { for (int i = 0; i < nc_inputs.size(); ++i) {
auto& xname = namer.get_name(nc_inputs[i]); auto& xname = namer.get_name(nc_inputs[i]);
os += fmt::format(" {0} index_{1} = ", idx_type, xname);
if (ndim == 1) { if (ndim == 1) {
int offset = i * ndim; int offset = i * ndim;
os += fmt::format( os << " size_t index_" << xname << " = elem_to_loc_1(pos.x, "
"elem_to_loc_1<size_t, uint>(pos.x, in_strides[{0}]);\n", offset); << "in_strides[" << offset << "]);\n";
} else if (ndim == 2) { } else if (ndim == 2) {
int offset = i * ndim; int offset = i * ndim;
os += fmt::format( os << " size_t index_" << xname << " = elem_to_loc_2({pos.x, pos.y}, "
"elem_to_loc_2<size_t, {0}>({{pos.x, pos.y}}, in_strides + {1});\n", << "in_strides + " << offset << ");\n";
idx_type,
offset);
} else if (ndim == 3) { } else if (ndim == 3) {
int offset = i * ndim; int offset = i * ndim;
os += fmt::format( os << " size_t index_" << xname << " = elem_to_loc_3(pos, "
"elem_to_loc_3<size_t, {0}>(pos, in_strides + {1});\n", << "in_strides + " << offset << ");\n";
idx_type,
offset);
} else if (!dynamic_dims) { } else if (!dynamic_dims) {
int offset = (i + 1) * ndim; int offset = i * ndim;
os += fmt::format( os << " size_t index_" << xname << " = N_ * pos.x * in_strides["
"N_ * pos.x * {0}(in_strides[{1}]) + pos.y * {0}(in_strides[{2}]);\n", << offset + ndim - 1 << "]"
idx_type, << " + pos.y * in_strides[" << offset + ndim - 2 << "];\n";
offset - 1,
offset - 2);
} else { } else {
os += fmt::format( os << " size_t index_" << xname << " = N_ * pos.x * in_strides[ndim * "
"N_ * pos.x * {0}(in_strides[ndim * {1} + ndim - 1]) + pos.y * {0}(in_strides[ndim * {1} + ndim - 2]);\n", << i << " + ndim - 1]"
idx_type, << " + pos.y * in_strides[ndim * " << i << " + ndim - 2];\n";
i);
} }
} }
if (!nc_inputs.empty() && (ndim > 3 || dynamic_dims)) { if (!nc_inputs.empty() && (ndim > 3 || dynamic_dims)) {
os += " uint zpos = pos.z;\n"; os << " uint zpos = pos.z;\n";
if (dynamic_dims) { if (dynamic_dims) {
os += " for (int d = ndim - 3; d >= 0; --d) {\n"; os << " for (int d = ndim - 3; d >= 0; --d) {\n";
} else { } else {
os += fmt::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3); os << " for (int d = " << ndim - 3 << "; d >= 0; --d) {\n";
} }
os += " uint l = zpos % output_shape[d];\n"; os << " uint l = zpos % output_shape[d];\n";
for (int i = 0; i < nc_inputs.size(); ++i) { for (int i = 0; i < nc_inputs.size(); ++i) {
auto& xname = namer.get_name(nc_inputs[i]); auto& xname = namer.get_name(nc_inputs[i]);
os += fmt::format(" index_{0} += ", xname); os << " index_" << xname << " += ";
if (dynamic_dims) { if (dynamic_dims) {
os += os << "l * in_strides[" << i << " * ndim + d];\n";
fmt::format("l * {0}(in_strides[{1} * ndim + d]);\n", idx_type, i);
} else { } else {
os += os << "l * in_strides[" << i * ndim << " + d];\n";
fmt::format("l * {0}(in_strides[{1} + d]);\n", idx_type, i * ndim);
} }
} }
os += " zpos /= output_shape[d];\n }\n"; os << " zpos /= output_shape[d];\n }\n";
} }
// Open per-thread loop // Open per-thread loop
if (work_per_thread > 1) { if (work_per_thread > 1) {
os += os << " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
" for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
} }
// Read non-contiguous inputs into tmps // Read non-contiguous inputs into tmps
for (int i = 0; i < nc_inputs.size(); ++i) { for (int i = 0; i < nc_inputs.size(); ++i) {
auto& x = nc_inputs[i]; auto& x = nc_inputs[i];
auto& xname = namer.get_name(x); auto& xname = namer.get_name(x);
os += fmt::format( os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
" {0} tmp_{1} = {1}[index_{1}];\n", get_type_string(x.dtype()), xname); << xname << "[index_" << xname << "];\n";
} }
// Actually write the computation // Actually write the computation
for (auto& x : tape) { for (auto& x : tape) {
os += fmt::format( os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
" {0} tmp_{1} = ", get_type_string(x.dtype()), namer.get_name(x)); << " = ";
if (is_static_cast(x.primitive())) { if (is_static_cast(x.primitive())) {
os += fmt::format( os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
"static_cast<{0}>(tmp_{1});\n", << namer.get_name(x.inputs()[0]) << ");\n";
get_type_string(x.dtype()),
namer.get_name(x.inputs()[0]));
} else { } else {
std::ostringstream ss; x.primitive().print(os);
x.primitive().print(ss); os << "()(";
os += ss.str();
os += "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) { for (int i = 0; i < x.inputs().size() - 1; i++) {
os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i])); os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
} }
os += fmt::format("tmp_{0});\n", namer.get_name(x.inputs().back())); os << "tmp_" << namer.get_name(x.inputs().back()) << ");\n";
} }
} }
// Write the outputs from tmps // Write the outputs from tmps
for (auto& x : outputs) { for (auto& x : outputs) {
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x)); os << " " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x)
<< ";\n";
} }
// Increment indices and close per thread loop // Increment indices and close per thread loop
if (work_per_thread > 1) { if (work_per_thread > 1) {
@@ -240,18 +211,18 @@ inline void build_kernel(
auto& x = nc_inputs[i]; auto& x = nc_inputs[i];
auto& xname = namer.get_name(x); auto& xname = namer.get_name(x);
if (!dynamic_dims) { if (!dynamic_dims) {
os += fmt::format( os << " index_" << xname << " += "
" index_{0} += in_strides[{1}];\n", xname, i * ndim + ndim - 1); << "in_strides[" << i * ndim + ndim - 1 << "];\n";
} else { } else {
os += fmt::format( os << " index_" << xname << " += "
" index_{0} += in_strides[{1} * ndim + ndim - 1];\n", xname, i); << "in_strides[" << i << " * ndim + ndim - 1];\n";
} }
} }
os += " index++;\n }\n"; os << " index++;\n }\n";
} }
// Finish the kernel // Finish the kernel
os += "}\n"; os << "}\n";
if (cnt > 31) { if (cnt > 31) {
std::ostringstream msg; std::ostringstream msg;
@@ -275,9 +246,9 @@ void Compiled::eval_gpu(
auto& s = stream(); auto& s = stream();
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
auto lib = d.get_library(kernel_lib_, [&]() { auto lib = d.get_library(kernel_lib_, [&]() {
std::string kernel = metal::utils(); std::ostringstream kernel;
concatenate( kernel << metal::utils() << metal::unary_ops() << metal::binary_ops()
kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops()); << metal::ternary_ops();
build_kernel( build_kernel(
kernel, kernel,
kernel_lib_ + "_contiguous", kernel_lib_ + "_contiguous",
@@ -290,7 +261,7 @@ void Compiled::eval_gpu(
/* dynamic_dims = */ false); /* dynamic_dims = */ false);
build_kernel( build_kernel(
kernel, kernel,
kernel_lib_ + "_contiguous_large", kernel_lib_ + "_contiguous_big",
inputs_, inputs_,
outputs_, outputs_,
tape_, tape_,
@@ -311,21 +282,7 @@ void Compiled::eval_gpu(
/* ndim = */ i, /* ndim = */ i,
/* dynamic_dims = */ false, /* dynamic_dims = */ false,
/* use_big_index = */ false, /* use_big_index = */ false,
/* work_per_thread = */ i > 3 ? 2 : 1); /* work_per_thread = */ i > 3 ? WORK_PER_THREAD : 1);
if (i > 1) {
build_kernel(
kernel,
kernel_lib_ + "_strided_" + std::to_string(i) + "_large",
inputs_,
outputs_,
tape_,
constant_ids_,
/* contiguous = */ false,
/* ndim = */ i,
/* dynamic_dims = */ false,
/* use_big_index = */ true,
/* work_per_thread = */ i > 3 ? 4 : 1);
}
} }
build_kernel( build_kernel(
kernel, kernel,
@@ -338,25 +295,13 @@ void Compiled::eval_gpu(
/* ndim = */ 0, /* ndim = */ 0,
/* dynamic_dims = */ true, /* dynamic_dims = */ true,
/* use_big_index = */ false, /* use_big_index = */ false,
/* work_per_thread = */ 2); /* work_per_thread = */ WORK_PER_THREAD);
build_kernel( return kernel.str();
kernel,
kernel_lib_ + "_strided_dynamic_large",
inputs_,
outputs_,
tape_,
constant_ids_,
/* contiguous = */ false,
/* ndim = */ 0,
/* dynamic_dims = */ true,
/* use_big_index = */ true,
/* work_per_thread = */ 4);
return kernel;
}); });
// Figure out which kernel we are using // Figure out which kernel we are using
auto& output_shape = outputs[0].shape(); auto& output_shape = outputs[0].shape();
auto contiguous = compiled_check_contiguity(inputs, output_shape); bool contiguous = compiled_check_contiguity(inputs, output_shape);
// Collapse contiguous dims to route to a faster kernel if possible. Also // Collapse contiguous dims to route to a faster kernel if possible. Also
// handle all broadcasting. // handle all broadcasting.
@@ -404,19 +349,13 @@ void Compiled::eval_gpu(
collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX); collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX);
} }
bool large; bool use_2d = false;
if (contiguous) { if (contiguous) {
size_t max_size = 0; size_t max_size = 0;
for (auto& in : inputs) { for (auto& in : inputs) {
max_size = std::max(max_size, in.data_size()); max_size = std::max(max_size, in.data_size());
} }
large = (max_size > UINT32_MAX); use_2d = (max_size > UINT32_MAX);
} else {
size_t max_size = 0;
for (auto& o : outputs) {
max_size = std::max(max_size, o.size());
}
large = (max_size > UINT32_MAX);
} }
// Get the kernel from the lib // Get the kernel from the lib
@@ -429,13 +368,12 @@ void Compiled::eval_gpu(
} else { } else {
kernel_name += std::to_string(shape.size()); kernel_name += std::to_string(shape.size());
} }
} } else if (use_2d) {
if (large) { kernel_name += "_big";
kernel_name += "_large";
} }
auto kernel = d.get_kernel(kernel_name, lib); auto kernel = d.get_kernel(kernel_name, lib);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
// Put the inputs in // Put the inputs in
int cnt = 0; int cnt = 0;
@@ -456,7 +394,8 @@ void Compiled::eval_gpu(
} }
} }
if (!in_strides.empty()) { if (!in_strides.empty()) {
compute_encoder.set_vector_bytes(in_strides, cnt++); compute_encoder->setBytes(
in_strides.data(), in_strides.size() * sizeof(size_t), cnt++);
} }
compiled_allocate_outputs( compiled_allocate_outputs(
@@ -469,13 +408,14 @@ void Compiled::eval_gpu(
// Put the output shape and strides in // Put the output shape and strides in
if (!contiguous) { if (!contiguous) {
compute_encoder.set_vector_bytes(strides[0], cnt++); compute_encoder->setBytes(
compute_encoder.set_vector_bytes(shape, cnt++); strides[0].data(), strides[0].size() * sizeof(size_t), cnt++);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), cnt++);
} }
// Put the number of dims in if it is dynamic // Put the number of dims in if it is dynamic
if (dynamic) { if (dynamic) {
compute_encoder.set_bytes(ndim, cnt++); compute_encoder->setBytes(&ndim, sizeof(int), cnt++);
} }
// Launch the kernel // Launch the kernel
@@ -484,15 +424,15 @@ void Compiled::eval_gpu(
MTL::Size group_dims( MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1); std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
MTL::Size grid_dims = large MTL::Size grid_dims = use_2d
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides()) ? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
: MTL::Size(nthreads, 1, 1); : MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} else { } else {
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1; size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = outputs[0].size() / (dim0 * dim1); size_t rest = outputs[0].size() / (dim0 * dim1);
int work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1; int work_per_thread = ndim > 3 ? WORK_PER_THREAD : 1;
dim0 = (dim0 + work_per_thread - 1) / work_per_thread; dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
int pow2; int pow2;
@@ -505,7 +445,7 @@ void Compiled::eval_gpu(
} }
auto group_dims = get_block_dims(dim0, dim1, rest, pow2); auto group_dims = get_block_dims(dim0, dim1, rest, pow2);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
} }
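Illustrative sketch (not taken from the diff): the compiled.cpp changes above move JIT source generation from an std::ostringstream onto a plain std::string built up with fmt::format, and thread an index type ("uint" or "size_t") through the generated kernel so the "_large" variants can index past 2^32 elements. The helper below is a made-up reduction of that style (it only emits a kernel signature and the index line); it assumes the fmt library and is not MLX's actual build_kernel.

#include <fmt/format.h>

#include <string>
#include <utility>
#include <vector>

// Emit a tiny Metal kernel signature by appending fmt::format pieces to a
// std::string, in the style of the rewritten build_kernel above.
std::string build_kernel_signature(
    const std::string& kernel_name,
    const std::vector<std::pair<std::string, std::string>>& args, // {type, name}
    bool use_big_index) {
  std::string os;
  os += fmt::format(
      "[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);
  int cnt = 0;
  for (auto& [type, name] : args) {
    os += fmt::format(
        "  device const {0}* {1} [[buffer({2})]],\n", type, name, cnt++);
  }
  os += "  uint3 pos [[thread_position_in_grid]],\n";
  os += "  uint3 grid [[threads_per_grid]]) {\n";
  // The index type widens to size_t for the "large" kernel variants.
  std::string idx_type = use_big_index ? "size_t" : "uint";
  os += fmt::format("  {0} index = pos.x + grid.x * {0}(pos.y);\n", idx_type);
  os += "}\n";
  return os;
}

int main() {
  fmt::print("{}", build_kernel_signature(
      "add_contiguous_large", {{"float", "a"}, {"float", "b"}}, true));
  return 0;
}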


@@ -44,24 +44,23 @@ void explicit_gemm_conv_ND_gpu(
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N; kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(in_unfolded, 1); compute_encoder.set_output_array(in_unfolded, 1);
compute_encoder.set_bytes(conv_params, 2); compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
// Launch unfolding kernel // Launch unfolding kernel
size_t tgp_x = std::min(conv_params.C, 64); int tgp_x = std::min(conv_params.C, 64);
tgp_x = 32 * ((tgp_x + 32 - 1) / 32); tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
size_t tgp_y = 256 / tgp_x; int tgp_y = 256 / tgp_x;
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
MTL::Size grid_dims = MTL::Size( MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]); conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
MTL::Size group_dims = MTL::Size(
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
// Reshape weight // Reshape weight
std::vector<int> wt_reshape{implicit_K, implicit_N}; std::vector<int> wt_reshape{implicit_K, implicit_N};
@@ -123,24 +122,23 @@ void explicit_gemm_conv_group_ND_gpu(
<< N; << N;
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(in_unfolded, 1); compute_encoder.set_output_array(in_unfolded, 1);
compute_encoder.set_bytes(conv_params, 2); compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
// Launch unfolding kernel // Launch unfolding kernel
size_t tgp_x = std::min(conv_params.C, 64); int tgp_x = std::min(conv_params.C, 64);
tgp_x = 32 * ((tgp_x + 32 - 1) / 32); tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
size_t tgp_y = 256 / tgp_x; int tgp_y = 256 / tgp_x;
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
MTL::Size grid_dims = MTL::Size( MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]); conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
MTL::Size group_dims = MTL::Size(
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
// Transpose kernel weights so that we can slice them by contiguous chunks // Transpose kernel weights so that we can slice them by contiguous chunks
// of channel groups. // of channel groups.
@@ -239,7 +237,7 @@ void slow_conv_2D_gpu(
// Encode and dispatch kernel // Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
size_t n_pixels = conv_params.oS[0] * conv_params.oS[1]; size_t n_pixels = conv_params.oS[0] * conv_params.oS[1];
@@ -254,8 +252,8 @@ void slow_conv_2D_gpu(
compute_encoder.set_input_array(wt, 1); compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2); compute_encoder.set_output_array(out, 2);
compute_encoder.set_bytes(conv_params, 3); compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
} }
void implicit_gemm_conv_2D_gpu( void implicit_gemm_conv_2D_gpu(
@@ -354,7 +352,7 @@ void implicit_gemm_conv_2D_gpu(
wn, wn,
n_channel_specialization, n_channel_specialization,
small_filter); small_filter);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
// Deduce grid launch dimensions // Deduce grid launch dimensions
int tile = 1 << swizzle_log; int tile = 1 << swizzle_log;
@@ -370,11 +368,11 @@ void implicit_gemm_conv_2D_gpu(
compute_encoder.set_output_array(out, 2); compute_encoder.set_output_array(out, 2);
// Encode params // Encode params
compute_encoder.set_bytes(conv_params, 3); compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder.set_bytes(gemm_params, 4); compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
// Launch kernel // Launch kernel
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
} }
void implicit_gemm_conv_2D_general_gpu( void implicit_gemm_conv_2D_general_gpu(
@@ -508,7 +506,7 @@ void implicit_gemm_conv_2D_general_gpu(
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = auto kernel =
get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn); get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
// Deduce grid launch dimensions // Deduce grid launch dimensions
int tile = 1 << swizzle_log; int tile = 1 << swizzle_log;
@@ -525,15 +523,17 @@ void implicit_gemm_conv_2D_general_gpu(
compute_encoder.set_output_array(out, 2); compute_encoder.set_output_array(out, 2);
// Encode params // Encode params
compute_encoder.set_bytes(conv_params, 3); compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder.set_bytes(gemm_params, 4); compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
compute_encoder.set_bytes(jump_params, 5); compute_encoder->setBytes(&jump_params, sizeof(Conv2DGeneralJumpParams), 5);
compute_encoder.set_vector_bytes(base_h, 6); compute_encoder->setBytes(
compute_encoder.set_vector_bytes(base_w, 7); base_h.data(), sizeof(Conv2DGeneralBaseInfo) * base_h.size(), 6);
compute_encoder->setBytes(
base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7);
// Launch kernel // Launch kernel
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
} }
void winograd_conv_2D_gpu( void winograd_conv_2D_gpu(
@@ -622,18 +622,18 @@ void winograd_conv_2D_gpu(
<< bc; << bc;
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(wt, 0); compute_encoder.set_input_array(wt, 0);
compute_encoder.set_output_array(filt_wg, 1); compute_encoder.set_output_array(filt_wg, 1);
compute_encoder.set_bytes(C_c, 2); compute_encoder->setBytes(&C_c, sizeof(int), 2);
compute_encoder.set_bytes(O_c, 3); compute_encoder->setBytes(&O_c, sizeof(int), 3);
MTL::Size group_dims = MTL::Size(32, bo, 1); MTL::Size group_dims = MTL::Size(32, bo, 1);
MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1); MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
} }
// Do input transform // Do input transform
@@ -650,17 +650,18 @@ void winograd_conv_2D_gpu(
<< bc; << bc;
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in_padded, 0); compute_encoder.set_input_array(in_padded, 0);
compute_encoder.set_output_array(inp_wg, 1); compute_encoder.set_output_array(inp_wg, 1);
compute_encoder.set_bytes(conv_params_updated, 2); compute_encoder->setBytes(
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n); MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
} }
// Do batched gemm // Do batched gemm
@@ -697,17 +698,18 @@ void winograd_conv_2D_gpu(
<< bc; << bc;
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(out_wg, 0); compute_encoder.set_input_array(out_wg, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(conv_params_updated, 2); compute_encoder->setBytes(
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n); MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
} }
} }
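Illustrative sketch (not taken from the diff): the unfold launches above (and the custom-kernel launch further down) now clamp the requested threadgroup shape against the grid shape, so dispatch_threads is never handed more threads per group than the grid has along a dimension. A tiny self-contained version of that clamp, with invented Size3 and clamp_group_to_grid names:

#include <algorithm>
#include <cstdio>

struct Size3 {
  unsigned long width, height, depth;
};

// Mirror of the std::min(tgp_x, grid_dims.width) pattern in the hunks above.
Size3 clamp_group_to_grid(Size3 desired, Size3 grid) {
  return {std::min(desired.width, grid.width),
          std::min(desired.height, grid.height),
          std::min(desired.depth, grid.depth)};
}

int main() {
  Size3 grid{3, 100, 1};    // e.g. only 3 channels along the fastest axis
  Size3 desired{32, 8, 1};  // 32 threads would overshoot the 3-wide grid
  Size3 group = clamp_group_to_grid(desired, grid);
  std::printf("group dims: %lu %lu %lu\n", group.width, group.height, group.depth);
  return 0;
}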


@@ -74,46 +74,44 @@ void copy_gpu_inplace(
}; };
auto [shape, strides_in_, strides_out_] = maybe_collapse(); auto [shape, strides_in_, strides_out_] = maybe_collapse();
int ndim = shape.size(); int ndim = shape.size();
bool large;
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) { bool use_2d = out.data_size() > UINT32_MAX;
// Allow for negative strides
large = out.data_size() > INT32_MAX;
} else {
large = out.data_size() > UINT32_MAX;
}
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
int work_per_thread = 1; int work_per_thread = 1;
std::string kernel_name; std::string kernel_name;
switch (ctype) { {
case CopyType::Scalar: std::ostringstream kname;
kernel_name = (large ? "s2" : "s"); switch (ctype) {
break; case CopyType::Scalar:
case CopyType::Vector: kname << (use_2d ? "s2" : "s");
kernel_name = (large ? "v2" : "v"); break;
break; case CopyType::Vector:
case CopyType::General: kname << (use_2d ? "v2" : "v");
kernel_name = "g"; break;
break; case CopyType::General:
case CopyType::GeneralGeneral: kname << "g";
kernel_name = "gg"; break;
break; case CopyType::GeneralGeneral:
} kname << "gg";
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) { break;
if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
kernel_name += std::to_string(shape.size());
} else {
work_per_thread = large ? 4 : 2;
concatenate(kernel_name, "n", std::to_string(work_per_thread));
} }
if (large) { if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
kernel_name += "large"; if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
kname << shape.size();
} else {
work_per_thread = 4;
kname << "n4";
}
} }
kname << "_copy";
kname << type_to_name(in) << type_to_name(out);
kernel_name = kname.str();
} }
concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
auto kernel = get_copy_kernel(d, kernel_name, in, out); auto kernel = get_copy_kernel(d, kernel_name, in, out);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
bool donate_in = in.data_shared_ptr() == nullptr; bool donate_in = in.data_shared_ptr() == nullptr;
inp_offset *= size_of(in.dtype()); inp_offset *= size_of(in.dtype());
@@ -127,11 +125,11 @@ void copy_gpu_inplace(
std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()}; std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()}; std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
if (ndim > 3) { if (ndim > 3) {
compute_encoder.set_vector_bytes(shape, ndim, 2); set_vector_bytes(compute_encoder, shape, ndim, 2);
} }
compute_encoder.set_vector_bytes(strides_in, ndim, 3); set_vector_bytes(compute_encoder, strides_in, ndim, 3);
if (ctype == CopyType::GeneralGeneral) { if (ctype == CopyType::GeneralGeneral) {
compute_encoder.set_vector_bytes(strides_out, ndim, 4); set_vector_bytes(compute_encoder, strides_out, ndim, 4);
} }
int dim0 = ndim > 0 ? shape[ndim - 1] : 1; int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
@@ -143,7 +141,7 @@ void copy_gpu_inplace(
int rest = data_size / (dim0 * dim1); int rest = data_size / (dim0 * dim1);
if (ndim > MAX_COPY_SPECIALIZED_DIMS) { if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
compute_encoder.set_bytes(ndim, 5); compute_encoder->setBytes(&ndim, sizeof(int), 5);
dim0 = (dim0 + work_per_thread - 1) / work_per_thread; dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
} }
@@ -154,16 +152,16 @@ void copy_gpu_inplace(
auto group_dims = get_block_dims(dim0, dim1, rest); auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} else { } else {
size_t nthreads = out.data_size(); size_t nthreads = out.data_size();
if (thread_group_size > nthreads) { if (thread_group_size > nthreads) {
thread_group_size = nthreads; thread_group_size = nthreads;
} }
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1); : MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
} }
@@ -195,13 +193,13 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
return; return;
} }
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
bool large = out.data_size() > UINT32_MAX; bool use_2d = out.data_size() > UINT32_MAX;
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" + std::string kernel_name = std::string(use_2d ? "s2" : "s") + "_copy" +
type_to_name(val) + type_to_name(out); type_to_name(val) + type_to_name(out);
auto kernel = get_copy_kernel(d, kernel_name, val, out); auto kernel = get_copy_kernel(d, kernel_name, val, out);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(val, 0); compute_encoder.set_input_array(val, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
@@ -212,9 +210,9 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
thread_group_size = nthreads; thread_group_size = nthreads;
} }
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1); : MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
} // namespace mlx::core } // namespace mlx::core
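Illustrative sketch (not taken from the diff): the copy.cpp hunk drops the std::ostringstream name building in favor of plain concatenation and adds the work-per-thread and "large" kernel-name variants. The helper below reproduces that naming logic in isolation; the enum, the function name, and the 3-dimension cutoff are stand-ins rather than MLX's exact definitions.

#include <cstdio>
#include <string>

enum class CopyType { Scalar, Vector, General, GeneralGeneral };

// Build a specialized copy-kernel name: base tag, optional ndim or
// work-per-thread suffix, optional "large" suffix, then "_copy" plus the
// input/output dtype tags.
std::string copy_kernel_name(
    CopyType ctype, int ndim, bool large,
    const std::string& in_tag, const std::string& out_tag) {
  constexpr int kMaxSpecializedDims = 3; // stand-in for MAX_COPY_SPECIALIZED_DIMS
  std::string name;
  switch (ctype) {
    case CopyType::Scalar:         name = large ? "s2" : "s"; break;
    case CopyType::Vector:         name = large ? "v2" : "v"; break;
    case CopyType::General:        name = "g";  break;
    case CopyType::GeneralGeneral: name = "gg"; break;
  }
  if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
    if (ndim <= kMaxSpecializedDims) {
      name += std::to_string(ndim);
    } else {
      name += "n" + std::to_string(large ? 4 : 2); // work per thread
    }
    if (large) {
      name += "large";
    }
  }
  return name + "_copy" + in_tag + out_tag;
}

int main() {
  std::printf("%s\n",
              copy_kernel_name(CopyType::GeneralGeneral, 5, true,
                               "float32", "float32").c_str());
  // prints: ggn4large_copyfloat32float32
  return 0;
}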


@@ -43,7 +43,7 @@ void CustomKernel::eval_gpu(
d.get_library(lib_name, [this] { return metal::utils() + source_; }); d.get_library(lib_name, [this] { return metal::utils() + source_; });
auto kernel = d.get_kernel(name_, lib); auto kernel = d.get_kernel(name_, lib);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
int index = 0; int index = 0;
for (int i = 0; i < checked_inputs.size(); i++) { for (int i = 0; i < checked_inputs.size(); i++) {
const array& in = checked_inputs[i]; const array& in = checked_inputs[i];
@@ -53,15 +53,15 @@ void CustomKernel::eval_gpu(
if (in.ndim() > 0) { if (in.ndim() > 0) {
int ndim = in.ndim(); int ndim = in.ndim();
if (shape_info.shape) { if (shape_info.shape) {
compute_encoder.set_vector_bytes(in.shape(), ndim, index); set_vector_bytes(compute_encoder, in.shape(), ndim, index);
index++; index++;
} }
if (shape_info.strides) { if (shape_info.strides) {
compute_encoder.set_vector_bytes(in.strides(), ndim, index); set_vector_bytes(compute_encoder, in.strides(), ndim, index);
index++; index++;
} }
if (shape_info.ndim) { if (shape_info.ndim) {
compute_encoder.set_bytes(ndim, index); compute_encoder->setBytes(&ndim, sizeof(int), index);
index++; index++;
} }
} }
@@ -72,11 +72,10 @@ void CustomKernel::eval_gpu(
} }
const auto [tx, ty, tz] = threadgroup_; const auto [tx, ty, tz] = threadgroup_;
MTL::Size group_dims = MTL::Size(tx, ty, tz);
const auto [gx, gy, gz] = grid_; const auto [gx, gy, gz] = grid_;
MTL::Size group_dims =
MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
MTL::Size grid_dims = MTL::Size(gx, gy, gz); MTL::Size grid_dims = MTL::Size(gx, gy, gz);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder->dispatchThreads(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
} }

View File

@@ -23,18 +23,14 @@ constexpr int MAX_BUFFERS_PER_QUEUE = 12;
constexpr const char* default_mtllib_path = METAL_PATH; constexpr const char* default_mtllib_path = METAL_PATH;
auto get_metal_version() { constexpr auto get_metal_version() {
auto get_metal_version_ = []() { #if (MLX_METAL_VERSION >= 320)
if (__builtin_available(macOS 15, iOS 18, tvOS 18, visionOS 2, *)) { return MTL::LanguageVersion3_2;
return MTL::LanguageVersion3_2; #elif (MLX_METAL_VERSION >= 310)
} else if (__builtin_available(macOS 14, iOS 17, tvOS 17, visionOS 1, *)) { return MTL::LanguageVersion3_1;
return MTL::LanguageVersion3_1; #else
} else { return MTL::LanguageVersion3_0;
return MTL::LanguageVersion3_0; #endif
}
};
static auto metal_version_ = get_metal_version_();
return metal_version_;
} }
auto load_device() { auto load_device() {
@@ -175,14 +171,14 @@ void CommandEncoder::maybeInsertBarrier() {
next_outputs_.clear(); next_outputs_.clear();
} }
void CommandEncoder::dispatch_threadgroups( void CommandEncoder::dispatchThreadgroups(
MTL::Size grid_dims, MTL::Size grid_dims,
MTL::Size group_dims) { MTL::Size group_dims) {
maybeInsertBarrier(); maybeInsertBarrier();
enc_->dispatchThreadgroups(grid_dims, group_dims); enc_->dispatchThreadgroups(grid_dims, group_dims);
} }
void CommandEncoder::dispatch_threads( void CommandEncoder::dispatchThreads(
MTL::Size grid_dims, MTL::Size grid_dims,
MTL::Size group_dims) { MTL::Size group_dims) {
maybeInsertBarrier(); maybeInsertBarrier();
@@ -302,7 +298,7 @@ void Device::end_encoding(int index) {
if (auto it = stream.outputs.find(in); it != stream.outputs.end()) { if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
// If we've already waited on a fence, don't wait on it again. // If we've already waited on a fence, don't wait on it again.
if (waiting_on.find(it->second) == waiting_on.end()) { if (waiting_on.find(it->second) == waiting_on.end()) {
enc.wait_for_fence(it->second->fence); enc->waitForFence(it->second->fence);
waiting_on.insert(it->second); waiting_on.insert(it->second);
} }
} }
@@ -311,7 +307,7 @@ void Device::end_encoding(int index) {
stream.outputs[out] = stream.fence; stream.outputs[out] = stream.fence;
} }
} }
enc.update_fence(stream.fence->fence); enc->updateFence(stream.fence->fence);
stream.buffer->addCompletedHandler( stream.buffer->addCompletedHandler(
[&stream, [&stream,
waiting_on = std::move(waiting_on), waiting_on = std::move(waiting_on),
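Illustrative sketch (not taken from the diff): device.cpp switches get_metal_version() from a preprocessor choice driven by MLX_METAL_VERSION at build time to a runtime __builtin_available check cached in a function-local static. The standalone version below uses plain ints instead of MTL::LanguageVersion and only checks macOS and iOS (the real code also lists tvOS and visionOS); the availability builtin assumes Apple's Clang.

#include <cstdio>

// Pick a Metal language version at runtime based on the running OS and cache
// the answer, as the rewritten get_metal_version does.
int detected_metal_language_version() {
  auto detect = []() {
#if defined(__APPLE__)
    if (__builtin_available(macOS 15, iOS 18, *)) {
      return 320; // Metal 3.2
    } else if (__builtin_available(macOS 14, iOS 17, *)) {
      return 310; // Metal 3.1
    }
#endif
    return 300; // Metal 3.0 fallback
  };
  static int version = detect();
  return version;
}

int main() {
  std::printf("Metal language version: %d\n", detected_metal_language_version());
  return 0;
}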


@@ -58,43 +58,16 @@ struct CommandEncoder {
CommandEncoder& enc; CommandEncoder& enc;
}; };
MTL::ComputeCommandEncoder* operator->() {
return enc_;
}
void set_input_array(const array& a, int idx, int64_t offset = 0); void set_input_array(const array& a, int idx, int64_t offset = 0);
void set_output_array(array& a, int idx, int64_t offset = 0); void set_output_array(array& a, int idx, int64_t offset = 0);
void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims); void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims); void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
void maybeInsertBarrier(); void maybeInsertBarrier();
void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) {
enc_->setComputePipelineState(kernel);
}
void wait_for_fence(MTL::Fence* fence) {
enc_->waitForFence(fence);
}
void update_fence(MTL::Fence* fence) {
enc_->updateFence(fence);
}
template <typename T>
void set_vector_bytes(const std::vector<T>& vec, size_t nelems, int idx) {
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
}
template <typename T>
void set_vector_bytes(const std::vector<T>& vec, int idx) {
return set_vector_bytes(vec, vec.size(), idx);
}
template <typename T>
void set_bytes(const T* v, int n, int idx) {
return enc_->setBytes(v, n * sizeof(T), idx);
}
template <typename T>
void set_bytes(const T& v, int idx) {
return enc_->setBytes(&v, sizeof(T), idx);
}
ConcurrentContext start_concurrent() { ConcurrentContext start_concurrent() {
return ConcurrentContext(*this); return ConcurrentContext(*this);
} }
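Illustrative sketch (not taken from the diff): this device.h hunk removes the raw operator-> escape hatch and adds named, typed wrappers (set_bytes, set_vector_bytes, set_compute_pipeline_state, and the fence helpers), which is what lets call sites elsewhere in the diff shrink from compute_encoder->setBytes(&x, sizeof(x), idx) to compute_encoder.set_bytes(x, idx). A dependency-free sketch of that wrapper idea over an invented FakeEncoder:

#include <cstddef>
#include <cstdio>
#include <vector>

// Stand-in for MTL::ComputeCommandEncoder: just records the call.
struct FakeEncoder {
  void setBytes(const void* /*ptr*/, unsigned long length, int index) {
    std::printf("setBytes: %lu bytes -> buffer %d\n", length, index);
  }
};

// Thin typed wrappers in the spirit of the new CommandEncoder methods: the
// templates compute the byte length, so call sites pass values directly.
struct CommandEncoderSketch {
  FakeEncoder* enc_;

  template <typename T>
  void set_vector_bytes(const std::vector<T>& vec, size_t nelems, int idx) {
    enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
  }
  template <typename T>
  void set_vector_bytes(const std::vector<T>& vec, int idx) {
    set_vector_bytes(vec, vec.size(), idx);
  }
  template <typename T>
  void set_bytes(const T& v, int idx) {
    enc_->setBytes(&v, sizeof(T), idx);
  }
};

int main() {
  FakeEncoder raw;
  CommandEncoderSketch enc{&raw};
  int ndim = 3;
  std::vector<long> strides{100, 10, 1};
  enc.set_bytes(ndim, 5);            // was enc_->setBytes(&ndim, sizeof(int), 5)
  enc.set_vector_bytes(strides, 3);  // was a setBytes with an explicit length
  return 0;
}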


@@ -699,7 +699,7 @@ void fft_op(
auto kernel = auto kernel =
get_fft_kernel(d, base_name, hash_name, func_consts, template_def); get_fft_kernel(d, base_name, hash_name, func_consts, template_def);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in_contiguous, 0); compute_encoder.set_input_array(in_contiguous, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
@@ -711,9 +711,9 @@ void fft_op(
compute_encoder.set_input_array(w_q, 2); // w_q compute_encoder.set_input_array(w_q, 2); // w_q
compute_encoder.set_input_array(w_k, 3); // w_k compute_encoder.set_input_array(w_k, 3); // w_k
compute_encoder.set_bytes(n, 4); compute_encoder->setBytes(&n, sizeof(int), 4);
compute_encoder.set_bytes(plan.bluestein_n, 5); compute_encoder->setBytes(&plan.bluestein_n, sizeof(int), 5);
compute_encoder.set_bytes(total_batch_size, 6); compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
} else if (plan.rader_n > 1) { } else if (plan.rader_n > 1) {
auto [b_q, g_q, g_minus_q] = compute_raders_constants(plan.rader_n, s); auto [b_q, g_q, g_minus_q] = compute_raders_constants(plan.rader_n, s);
copies.push_back(b_q); copies.push_back(b_q);
@@ -723,22 +723,22 @@ void fft_op(
compute_encoder.set_input_array(b_q, 2); compute_encoder.set_input_array(b_q, 2);
compute_encoder.set_input_array(g_q, 3); compute_encoder.set_input_array(g_q, 3);
compute_encoder.set_input_array(g_minus_q, 4); compute_encoder.set_input_array(g_minus_q, 4);
compute_encoder.set_bytes(n, 5); compute_encoder->setBytes(&n, sizeof(int), 5);
compute_encoder.set_bytes(total_batch_size, 6); compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
compute_encoder.set_bytes(plan.rader_n, 7); compute_encoder->setBytes(&plan.rader_n, sizeof(int), 7);
} else if (four_step_params.required) { } else if (four_step_params.required) {
compute_encoder.set_bytes(four_step_params.n1, 2); compute_encoder->setBytes(&four_step_params.n1, sizeof(int), 2);
compute_encoder.set_bytes(four_step_params.n2, 3); compute_encoder->setBytes(&four_step_params.n2, sizeof(int), 3);
compute_encoder.set_bytes(total_batch_size, 4); compute_encoder->setBytes(&total_batch_size, sizeof(int), 4);
} else { } else {
compute_encoder.set_bytes(n, 2); compute_encoder->setBytes(&n, sizeof(int), 2);
compute_encoder.set_bytes(total_batch_size, 3); compute_encoder->setBytes(&total_batch_size, sizeof(int), 3);
} }
auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft); auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft);
auto grid_dims = auto grid_dims =
MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft); MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder->dispatchThreads(grid_dims, group_dims);
} }
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);

View File

@@ -137,14 +137,14 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = d.get_kernel(kernel_name, lib); auto kernel = d.get_kernel(kernel_name, lib);
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup()); assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(scale, 2); compute_encoder->setBytes(&scale, sizeof(float), 2);
MTL::Size group_dims = MTL::Size(1, threads_per, 1); MTL::Size group_dims = MTL::Size(1, threads_per, 1);
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1); MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder->dispatchThreads(grid_dims, group_dims);
}; };
if (m > 1) { if (m > 1) {

View File

@@ -53,31 +53,27 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
int idx_ndim = nidx ? inputs[1].ndim() : 0; int idx_ndim = nidx ? inputs[1].ndim() : 0;
size_t ndim = src.ndim(); size_t ndim = src.ndim();
bool large_index = nidx && inputs[1].size() > UINT32_MAX; std::string lib_name;
bool large_src = src.size() > UINT32_MAX; std::string kernel_name;
bool large_out = out.size() > UINT32_MAX;
bool large = large_index || large_src || large_out;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : ""; std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
std::string kernel_name = fmt::format( {
"gather{0}{1}_{2}_{3}_{4}", std::ostringstream kname;
type_to_name(out), kname << "gather" << type_to_name(out) << idx_type_name << "_" << nidx
idx_type_name, << "_" << idx_ndim;
nidx, lib_name = kname.str();
idx_ndim, kernel_name = lib_name;
large ? "size_t" : "uint"); }
std::string lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() { auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils(); std::ostringstream kernel_source;
kernel_source += metal::gather(); kernel_source << metal::utils() << metal::gather();
std::string out_type_str = get_type_string(out.dtype()); std::string out_type_str = get_type_string(out.dtype());
std::string idx_type_str = std::string idx_type_str =
nidx ? get_type_string(inputs[1].dtype()) : "bool"; nidx ? get_type_string(inputs[1].dtype()) : "bool";
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx); auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
// Index dimension specializations // Index dimension specializations
kernel_source += fmt::format( kernel_source << fmt::format(
gather_kernels, gather_kernels,
type_to_name(out) + idx_type_name, type_to_name(out) + idx_type_name,
out_type_str, out_type_str,
@@ -85,14 +81,13 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
nidx, nidx,
idx_args, idx_args,
idx_arr, idx_arr,
idx_ndim, idx_ndim);
large ? "size_t" : "uint"); return kernel_source.str();
return kernel_source;
}); });
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kernel_name, lib); auto kernel = d.get_kernel(kernel_name, lib);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
size_t slice_size = 1; size_t slice_size = 1;
for (auto s : slice_sizes_) { for (auto s : slice_sizes_) {
@@ -136,20 +131,20 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
// Set source info // Set source info
compute_encoder.set_vector_bytes(src.shape(), 2); set_vector_bytes(compute_encoder, src.shape(), 2);
compute_encoder.set_vector_bytes(src.strides(), 3); set_vector_bytes(compute_encoder, src.strides(), 3);
compute_encoder.set_bytes(ndim, 4); compute_encoder->setBytes(&ndim, sizeof(size_t), 4);
compute_encoder.set_vector_bytes(slice_sizes_, 5); set_vector_bytes(compute_encoder, slice_sizes_, 5);
compute_encoder.set_vector_bytes(axes_, 6); set_vector_bytes(compute_encoder, axes_, 6);
// Set index info // Set index info
// //
// We don't need to check for empty idx_shapes because gather has a // We don't need to check for empty idx_shapes because gather has a
// idx_ndim == 0 specialization // idx_ndim == 0 specialization
compute_encoder.set_vector_bytes(idx_shapes, 7); set_vector_bytes(compute_encoder, idx_shapes, 7);
compute_encoder.set_vector_bytes(idx_strides, 8); set_vector_bytes(compute_encoder, idx_strides, 8);
compute_encoder.set_vector_bytes(idx_contigs, 9); set_vector_bytes(compute_encoder, idx_contigs, 9);
compute_encoder.set_bytes(idx_ndim, 10); compute_encoder->setBytes(&idx_ndim, sizeof(int), 10);
// Set index buffers // Set index buffers
for (int i = 0; i < nidx; ++i) { for (int i = 0; i < nidx; ++i) {
@@ -157,7 +152,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
// Launch grid // Launch grid
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) { void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -214,6 +209,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
nwork = 32; nwork = 32;
} }
std::string lib_name;
std::string kernel_name;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : ""; std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
std::string op_name; std::string op_name;
switch (reduce_type_) { switch (reduce_type_) {
@@ -234,24 +231,18 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
break; break;
} }
auto upd_contig = upd.flags().row_contiguous; auto upd_contig = upd.flags().row_contiguous;
bool large_out = out.size() > UINT32_MAX; {
bool large_idx = nidx && (inputs[1].size() > UINT32_MAX); std::ostringstream kname;
bool large_upd = upd.size() > UINT32_MAX; kname << "scatter" << type_to_name(out) << idx_type_name;
bool large = large_out || large_idx || large_upd; kname << "_" << op_name << "_" << nidx << "_"
std::string kernel_name = fmt::format( << (upd_contig ? "updc_true" : "updc_false") << "_nwork" << nwork;
"scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}", lib_name = kname.str();
type_to_name(out), kernel_name = kname.str();
idx_type_name, }
op_name,
nidx,
upd_contig ? "updc_true" : "updc_false",
nwork,
large ? "size_t" : "uint");
std::string lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() { auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils(); std::ostringstream kernel_source;
concatenate(kernel_source, metal::reduce_utils(), metal::scatter()); kernel_source << metal::utils() << metal::reduce_utils()
<< metal::scatter();
std::string out_type_str = get_type_string(out.dtype()); std::string out_type_str = get_type_string(out.dtype());
std::string idx_type_str = std::string idx_type_str =
@@ -279,7 +270,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx); auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
kernel_source += fmt::format( kernel_source << fmt::format(
scatter_kernels, scatter_kernels,
type_to_name(out) + idx_type_name + "_" + op_name, type_to_name(out) + idx_type_name + "_" + op_name,
out_type_str, out_type_str,
@@ -289,9 +280,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
idx_args, idx_args,
idx_arr, idx_arr,
upd_contig, upd_contig,
nwork, nwork);
large ? "size_t" : "uint"); return kernel_source.str();
return kernel_source;
}); });
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
@@ -299,7 +289,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
size_t nthreads = upd.size(); size_t nthreads = upd.size();
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
// Set all the buffers // Set all the buffers
compute_encoder.set_input_array(upd, 1); compute_encoder.set_input_array(upd, 1);
@@ -333,14 +323,14 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
// Need placeholders so Metal doesn't compalain // Need placeholders so Metal doesn't compalain
int shape_ = 0; int shape_ = 0;
size_t stride_ = 0; size_t stride_ = 0;
compute_encoder.set_bytes(shape_, 3); compute_encoder->setBytes(&shape_, sizeof(int), 3);
compute_encoder.set_bytes(stride_, 4); compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
} else { } else {
compute_encoder.set_vector_bytes(upd.shape(), 3); set_vector_bytes(compute_encoder, upd.shape(), 3);
compute_encoder.set_vector_bytes(upd.strides(), 4); set_vector_bytes(compute_encoder, upd.strides(), 4);
} }
compute_encoder.set_bytes(upd_ndim, 5); compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
compute_encoder.set_bytes(upd_size, 6); compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
// Set output info // Set output info
size_t out_ndim = out.ndim(); size_t out_ndim = out.ndim();
@@ -348,14 +338,14 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
// Need placeholders so Metal doesn't compalain // Need placeholders so Metal doesn't compalain
int shape_ = 0; int shape_ = 0;
size_t stride_ = 0; size_t stride_ = 0;
compute_encoder.set_bytes(shape_, 7); compute_encoder->setBytes(&shape_, sizeof(int), 7);
compute_encoder.set_bytes(stride_, 8); compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
} else { } else {
compute_encoder.set_vector_bytes(out.shape(), 7); set_vector_bytes(compute_encoder, out.shape(), 7);
compute_encoder.set_vector_bytes(out.strides(), 8); set_vector_bytes(compute_encoder, out.strides(), 8);
} }
compute_encoder.set_bytes(out_ndim, 9); compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
compute_encoder.set_vector_bytes(axes_, 10); compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
// Set index info // Set index info
if (idx_ndim == 0) { if (idx_ndim == 0) {
@@ -365,11 +355,11 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
idx_strides.push_back(0); idx_strides.push_back(0);
idx_contigs.push_back(false); idx_contigs.push_back(false);
} }
compute_encoder.set_vector_bytes(idx_shapes, 11); set_vector_bytes(compute_encoder, idx_shapes, 11);
compute_encoder.set_vector_bytes(idx_strides, 12); set_vector_bytes(compute_encoder, idx_strides, 12);
compute_encoder.set_vector_bytes(idx_contigs, 13); set_vector_bytes(compute_encoder, idx_contigs, 13);
compute_encoder.set_bytes(idx_ndim, 14); compute_encoder->setBytes(&idx_ndim, sizeof(int), 14);
compute_encoder.set_bytes(idx_size, 15); compute_encoder->setBytes(&idx_size, sizeof(size_t), 15);
// Set index buffers // Set index buffers
for (int i = 0; i < nidx; ++i) { for (int i = 0; i < nidx; ++i) {
@@ -385,7 +375,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error("[Scatter::eval_gpu] Invalid number of threads"); throw std::runtime_error("[Scatter::eval_gpu] Invalid number of threads");
} }
MTL::Size group_dims = get_block_dims(upd_size, grid_y, 1); MTL::Size group_dims = get_block_dims(upd_size, grid_y, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
} // namespace mlx::core } // namespace mlx::core
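Illustrative sketch (not taken from the diff): the Gather/Scatter hunks above fold a new index-arithmetic type ("uint" or "size_t", chosen by whether any participating array has more than UINT32_MAX elements) into the JIT kernel name and template instantiation. A small fmt::format reconstruction of that name building; the function and tag names are invented, only the "gather{0}{1}_{2}_{3}_{4}" shape mirrors the diff.

#include <fmt/format.h>

#include <cstdint>
#include <string>

// Build a gather kernel name: output dtype tag, index dtype tag, number of
// index arrays, index ndim, and the index arithmetic type picked from the
// largest array involved.
std::string gather_kernel_name(
    const std::string& out_tag,
    const std::string& idx_tag,
    int nidx,
    int idx_ndim,
    uint64_t max_total_elements) {
  bool large = max_total_elements > UINT32_MAX;
  return fmt::format(
      "gather{0}{1}_{2}_{3}_{4}",
      out_tag, idx_tag, nidx, idx_ndim, large ? "size_t" : "uint");
}

int main() {
  fmt::print("{}\n",
             gather_kernel_name("float32", "uint32", 1, 2, uint64_t(1) << 33));
  // prints: gatherfloat32uint32_1_2_size_t
  return 0;
}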


@@ -1,7 +1,7 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
constexpr std::string_view gather_kernels = R"( constexpr std::string_view gather_kernels = R"(
[[kernel]] void gather{0}_{3}_{6}_{7}( [[kernel]] void gather{0}_{3}_{6}(
const device {1}* src [[buffer(0)]], const device {1}* src [[buffer(0)]],
device {1}* out [[buffer(1)]], device {1}* out [[buffer(1)]],
const constant int* src_shape [[buffer(2)]], const constant int* src_shape [[buffer(2)]],
@@ -19,7 +19,7 @@ constexpr std::string_view gather_kernels = R"(
Indices<{2}, {3}> idxs{{ Indices<{2}, {3}> idxs{{
{{ {5} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}}; {{ {5} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
return gather_impl<{1}, {2}, {3}, {6}, {7}>( return gather_impl<{1}, {2}, {3}, {6}>(
src, src,
out, out,
src_shape, src_shape,
@@ -34,7 +34,7 @@ constexpr std::string_view gather_kernels = R"(
)"; )";
constexpr std::string_view scatter_kernels = R"( constexpr std::string_view scatter_kernels = R"(
[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}_{9}( [[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}(
const device {1}* updates [[buffer(1)]], const device {1}* updates [[buffer(1)]],
device mlx_atomic<{1}>* out [[buffer(2)]], device mlx_atomic<{1}>* out [[buffer(2)]],
const constant int* upd_shape [[buffer(3)]], const constant int* upd_shape [[buffer(3)]],
@@ -54,7 +54,7 @@ constexpr std::string_view scatter_kernels = R"(
uint2 gid [[thread_position_in_grid]]) {{ uint2 gid [[thread_position_in_grid]]) {{
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}}; Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}, {9}>( return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}>(
updates, updates,
out, out,
upd_shape, upd_shape,

View File

@@ -1,4 +1,5 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h" #include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/gemv_masked.h" #include "mlx/backend/metal/jit/gemv_masked.h"
@@ -45,27 +46,25 @@ MTL::ComputePipelineState* get_unary_kernel(
auto lib = d.get_library(lib_name, [&]() { auto lib = d.get_library(lib_name, [&]() {
auto in_t = get_type_string(in_type); auto in_t = get_type_string(in_type);
auto out_t = get_type_string(out_type); auto out_t = get_type_string(out_type);
std::string kernel_source = metal::utils(); std::ostringstream kernel_source;
concatenate(kernel_source, metal::unary_ops(), metal::unary()); kernel_source << metal::utils() << metal::unary_ops() << metal::unary();
kernel_source += kernel_source << get_template_definition(
get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op); "v_" + lib_name, "unary_v", in_t, out_t, op);
kernel_source += kernel_source << get_template_definition(
get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op); "v2_" + lib_name, "unary_v2", in_t, out_t, op);
kernel_source += get_template_definition( kernel_source << get_template_definition(
"gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "uint"); "gn4_" + lib_name, "unary_g", in_t, out_t, op, 4);
kernel_source += get_template_definition( return kernel_source.str();
"gn4large_" + lib_name, "unary_g", in_t, out_t, op, 4);
return kernel_source;
}); });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
void append_binary_kernels( void add_binary_kernels(
const std::string lib_name, const std::string lib_name,
Dtype in_type, Dtype in_type,
Dtype out_type, Dtype out_type,
const std::string op, const std::string op,
std::string& kernel_source) { std::ostringstream& kernel_source) {
const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{ const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{
{"ss", "binary_ss"}, {"ss", "binary_ss"},
{"vs", "binary_vs"}, {"vs", "binary_vs"},
@@ -75,24 +74,26 @@ void append_binary_kernels(
{"sv2", "binary_sv2"}, {"sv2", "binary_sv2"},
{"vv2", "binary_vv2"}, {"vv2", "binary_vv2"},
{"g1", "binary_g_nd1"}, {"g1", "binary_g_nd1"},
{"g2large", "binary_g_nd2"}, {"g2", "binary_g_nd2"},
{"g3large", "binary_g_nd3"}, {"g3", "binary_g_nd3"},
}}; }};
auto in_t = get_type_string(in_type);
auto out_t = get_type_string(out_type);
for (auto& [name, func] : kernel_types) { for (auto& [name, func] : kernel_types) {
kernel_source += std::string template_def;
get_template_definition(name + "_" + lib_name, func, in_t, out_t, op); template_def = get_template_definition(
name + "_" + lib_name,
func,
get_type_string(in_type),
get_type_string(out_type),
op);
kernel_source << template_def;
} }
kernel_source += get_template_definition( kernel_source << get_template_definition(
"g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "uint"); "gn4_" + lib_name,
kernel_source += get_template_definition( "binary_g",
"g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "uint"); get_type_string(in_type),
kernel_source += get_template_definition( get_type_string(out_type),
"gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "uint"); op,
kernel_source += get_template_definition( 4);
"gn4large_" + lib_name, "binary_g", in_t, out_t, op, 4);
} }
MTL::ComputePipelineState* get_binary_kernel( MTL::ComputePipelineState* get_binary_kernel(
@@ -103,11 +104,10 @@ MTL::ComputePipelineState* get_binary_kernel(
const std::string op) { const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() { auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source; std::ostringstream kernel_source;
kernel_source = metal::utils(); kernel_source << metal::utils() << metal::binary_ops() << metal::binary();
concatenate(kernel_source, metal::binary_ops(), metal::binary()); add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
append_binary_kernels(lib_name, in_type, out_type, op, kernel_source); return kernel_source.str();
return kernel_source;
}); });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -120,10 +120,11 @@ MTL::ComputePipelineState* get_binary_two_kernel(
const std::string op) { const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() { auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils(); std::ostringstream kernel_source;
concatenate(kernel_source, metal::binary_ops(), metal::binary_two()); kernel_source << metal::utils() << metal::binary_ops()
append_binary_kernels(lib_name, in_type, out_type, op, kernel_source); << metal::binary_two();
return kernel_source; add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
return kernel_source.str();
}); });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -135,29 +136,24 @@ MTL::ComputePipelineState* get_ternary_kernel(
const std::string op) { const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() { auto lib = d.get_library(lib_name, [&]() {
auto t_str = get_type_string(type); std::ostringstream kernel_source;
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{ const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
{"v", "ternary_v"}, {"v", "ternary_v"},
{"v2", "ternary_v2"}, {"v2", "ternary_v2"},
{"g1", "ternary_g_nd1"}, {"g1", "ternary_g_nd1"},
{"g2large", "ternary_g_nd2"}, {"g2", "ternary_g_nd2"},
{"g3large", "ternary_g_nd3"}, {"g3", "ternary_g_nd3"},
}}; }};
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary();
for (auto& [name, func] : kernel_types) { for (auto& [name, func] : kernel_types) {
kernel_source += std::string template_def;
get_template_definition(name + "_" + lib_name, func, t_str, op); template_def = get_template_definition(
name + "_" + lib_name, func, get_type_string(type), op);
kernel_source << template_def;
} }
kernel_source += get_template_definition( kernel_source << get_template_definition(
"g2_" + lib_name, "ternary_g_nd2", t_str, op, "uint"); "gn4_" + lib_name, "ternary_g", get_type_string(type), op, 4);
kernel_source += get_template_definition( return kernel_source.str();
"g3_" + lib_name, "ternary_g_nd3", t_str, op, "uint");
kernel_source += get_template_definition(
"gn2_" + lib_name, "ternary_g", t_str, op, 2, "uint");
kernel_source += get_template_definition(
"gn4large_" + lib_name, "ternary_g", t_str, op, 4);
return kernel_source;
}); });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -169,43 +165,31 @@ MTL::ComputePipelineState* get_copy_kernel(
const array& out) { const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() { auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils(); std::ostringstream kernel_source;
kernel_source += metal::copy();
auto in_type = get_type_string(in.dtype()); auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype()); auto out_type = get_type_string(out.dtype());
kernel_source += kernel_source << metal::utils() << metal::copy()
get_template_definition("s_" + lib_name, "copy_s", in_type, out_type); << get_template_definition(
kernel_source += "s_" + lib_name, "copy_s", in_type, out_type)
get_template_definition("v_" + lib_name, "copy_v", in_type, out_type); << get_template_definition(
kernel_source += get_template_definition( "v_" + lib_name, "copy_v", in_type, out_type)
"g1_" + lib_name, "copy_g_nd1", in_type, out_type); << get_template_definition(
kernel_source += get_template_definition( "g1_" + lib_name, "copy_g_nd1", in_type, out_type)
"g2_" + lib_name, "copy_g_nd2", in_type, out_type, "int"); << get_template_definition(
kernel_source += get_template_definition( "g2_" + lib_name, "copy_g_nd2", in_type, out_type)
"g3_" + lib_name, "copy_g_nd3", in_type, out_type, "int"); << get_template_definition(
kernel_source += get_template_definition( "g3_" + lib_name, "copy_g_nd3", in_type, out_type)
"gn2_" + lib_name, "copy_g", in_type, out_type, 2, "int"); << get_template_definition(
kernel_source += get_template_definition( "gn4_" + lib_name, "copy_g", in_type, out_type, 4)
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type); << get_template_definition(
kernel_source += get_template_definition( "gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type, "int"); << get_template_definition(
kernel_source += get_template_definition( "gg2_" + lib_name, "copy_gg_nd2", in_type, out_type)
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type, "int"); << get_template_definition(
kernel_source += get_template_definition( "gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
"ggn2_" + lib_name, "copy_gg", in_type, out_type, 2, "int"); << get_template_definition(
kernel_source += get_template_definition( "ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
"g2large_" + lib_name, "copy_g_nd2", in_type, out_type); return kernel_source.str();
kernel_source += get_template_definition(
"g3large_" + lib_name, "copy_g_nd3", in_type, out_type);
kernel_source += get_template_definition(
"gn4large_" + lib_name, "copy_g", in_type, out_type, 4);
kernel_source += get_template_definition(
"gg2large_" + lib_name, "copy_gg_nd2", in_type, out_type);
kernel_source += get_template_definition(
"gg3large_" + lib_name, "copy_gg_nd3", in_type, out_type);
kernel_source += get_template_definition(
"ggn4large_" + lib_name, "copy_gg", in_type, out_type, 4);
return kernel_source;
}); });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
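The kernel builders in this file all follow one pattern: concatenate the shared Metal header sources with one explicit template instantiation per kernel variant, then hand the resulting string to the device to compile as a library and fetch the pipeline by name. A simplified sketch of a get_template_definition-style helper in plain C++; the exact signature and emitted text in the library may differ, so treat this as illustrative only.

#include <sstream>
#include <string>

// Emit "template [[host_name(NAME)]] [[kernel]] decltype(FUNC<ARGS...>) FUNC<ARGS...>;"
// so the instantiation can later be looked up in the compiled library by NAME.
template <typename... Args>
std::string make_template_definition(
    const std::string& name,
    const std::string& func,
    Args&&... args) {
  std::ostringstream os;
  bool first = true;
  // Join the template arguments (type names, op names, integers) with commas.
  ((os << (first ? "" : ", ") << args, first = false), ...);
  std::string inst = func + "<" + os.str() + ">";
  return "template [[host_name(\"" + name + "\")]] [[kernel]] decltype(" +
      inst + ") " + inst + ";\n";
}

int main() {
  std::string source = "/* metal::utils() + metal::copy() sources go here */\n";
  source += make_template_definition("s_copy_f32", "copy_s", "float", "float");
  source += make_template_definition("gn4large_copy_f32", "copy_g", "float", "float", 4);
  // `source` would then be compiled into a Metal library and the pipeline
  // state fetched by its host_name, as the builders above do.
  return 0;
}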
@@ -337,17 +321,17 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
const std::string& kernel_name, const std::string& kernel_name,
const std::string& func_name, const std::string& func_name,
const std::string& op_name, const std::string& op_name,
const Dtype& out_type) { const array& out) {
auto lib = d.get_library(kernel_name, [&]() { auto lib = d.get_library(kernel_name, [&]() {
std::ostringstream kernel_source;
std::string op_type = op_name; std::string op_type = op_name;
op_type[0] = std::toupper(op_name[0]); op_type[0] = std::toupper(op_name[0]);
auto out_t = get_type_string(out_type); auto out_type = get_type_string(out.dtype());
std::string op = op_type + "<" + out_t + ">"; std::string op = op_type + "<" + out_type + ">";
std::string kernel_source = metal::utils(); kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
kernel_source += metal::reduce_utils(); kernel_source << get_template_definition(
kernel_source += metal::reduce(); kernel_name, func_name, out_type, op);
kernel_source += get_template_definition(kernel_name, func_name, out_t, op); return kernel_source.str();
return kernel_source;
}); });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -357,31 +341,30 @@ MTL::ComputePipelineState* get_reduce_kernel(
const std::string& kernel_name, const std::string& kernel_name,
const std::string& func_name, const std::string& func_name,
const std::string& op_name, const std::string& op_name,
const Dtype& in_type, const array& in,
const Dtype& out_type, const array& out,
const std::string& idx_t,
int ndim /* = -1 */, int ndim /* = -1 */,
int bm /* = -1 */, int bm /* = -1 */,
int bn /* = -1 */) { int bn /* = -1 */) {
auto lib = d.get_library(kernel_name, [&]() { auto lib = d.get_library(kernel_name, [&]() {
std::string op_type = op_name; std::string op_type = op_name;
op_type[0] = std::toupper(op_name[0]); op_type[0] = std::toupper(op_name[0]);
auto in_t = get_type_string(in_type); std::ostringstream kernel_source;
auto out_t = get_type_string(out_type); auto in_type = get_type_string(in.dtype());
std::string op = op_type + "<" + out_t + ">"; auto out_type = get_type_string(out.dtype());
std::string kernel_source = metal::utils(); std::string op = op_type + "<" + out_type + ">";
concatenate(kernel_source, metal::reduce_utils(), metal::reduce()); kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
if (bm >= 0) { if (bm >= 0) {
kernel_source += get_template_definition( kernel_source << get_template_definition(
kernel_name, func_name, in_t, out_t, op, idx_t, ndim, bm, bn); kernel_name, func_name, in_type, out_type, op, ndim, bm, bn);
} else if (ndim >= 0) { } else if (ndim >= 0) {
kernel_source += get_template_definition( kernel_source << get_template_definition(
kernel_name, func_name, in_t, out_t, op, idx_t, ndim); kernel_name, func_name, in_type, out_type, op, ndim);
} else { } else {
kernel_source += get_template_definition( kernel_source << get_template_definition(
kernel_name, func_name, in_t, out_t, op, idx_t); kernel_name, func_name, in_type, out_type, op);
} }
return kernel_source; return kernel_source.str();
}); });
auto st = d.get_kernel(kernel_name, lib); auto st = d.get_kernel(kernel_name, lib);
return st; return st;


@@ -81,16 +81,15 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
const std::string& kernel_name, const std::string& kernel_name,
const std::string& func_name, const std::string& func_name,
const std::string& op_name, const std::string& op_name,
const Dtype& out_type); const array& out);
MTL::ComputePipelineState* get_reduce_kernel( MTL::ComputePipelineState* get_reduce_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
const std::string& func_name, const std::string& func_name,
const std::string& op_name, const std::string& op_name,
const Dtype& in_type, const array& in,
const Dtype& out_type, const array& out,
const std::string& idx_t,
int ndim = -1, int ndim = -1,
int bm = -1, int bm = -1,
int bn = -1); int bn = -1);


@@ -1,27 +1,13 @@
set(BASE_HEADERS set(BASE_HEADERS bf16.h bf16_math.h complex.h defines.h expm1f.h utils.h)
metal_3_1/bf16.h
metal_3_0/bf16.h
bf16_math.h
complex.h
defines.h
expm1f.h
utils.h)
function(build_kernel_base TARGET SRCFILE DEPS) function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math) set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
if(MLX_METAL_DEBUG) if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources) set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
endif() endif()
if(MLX_METAL_VERSION GREATER_EQUAL 310)
set(VERSION_INCLUDES
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_1)
else()
set(VERSION_INCLUDES
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_0)
endif()
add_custom_command( add_custom_command(
COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE} COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
-I${PROJECT_SOURCE_DIR} -I${VERSION_INCLUDES} -o ${TARGET}.air -I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS} DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
OUTPUT ${TARGET}.air OUTPUT ${TARGET}.air
COMMENT "Building ${TARGET}.air" COMMENT "Building ${TARGET}.air"
@@ -44,7 +30,9 @@ build_kernel(layer_norm)
build_kernel(random) build_kernel(random)
build_kernel(rms_norm) build_kernel(rms_norm)
build_kernel(rope) build_kernel(rope)
build_kernel(scaled_dot_product_attention sdpa_vector.h) build_kernel(
scaled_dot_product_attention scaled_dot_product_attention_params.h
sdpa_vector.h steel/defines.h steel/gemm/transforms.h steel/utils.h)
set(STEEL_HEADERS set(STEEL_HEADERS
steel/defines.h steel/defines.h
@@ -66,24 +54,6 @@ set(STEEL_HEADERS
steel/utils/type_traits.h steel/utils/type_traits.h
steel/utils/integral_constant.h) steel/utils/integral_constant.h)
set(STEEL_ATTN_HEADERS
steel/defines.h
steel/utils.h
steel/gemm/gemm.h
steel/gemm/mma.h
steel/gemm/loader.h
steel/gemm/transforms.h
steel/utils/type_traits.h
steel/utils/integral_constant.h
steel/attn/attn.h
steel/attn/loader.h
steel/attn/mma.h
steel/attn/params.h
steel/attn/transforms.h
steel/attn/kernels/steel_attention.h)
build_kernel(steel/attn/kernels/steel_attention ${STEEL_ATTN_HEADERS})
if(NOT MLX_METAL_JIT) if(NOT MLX_METAL_JIT)
build_kernel(arange arange.h) build_kernel(arange arange.h)
build_kernel(binary binary.h binary_ops.h) build_kernel(binary binary.h binary_ops.h)


@@ -1,7 +1,7 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
// clang-format off // clang-format off
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/arange.h" #include "mlx/backend/metal/kernels/arange.h"
#define instantiate_arange(tname, type) \ #define instantiate_arange(tname, type) \


@@ -6,6 +6,12 @@
using namespace metal; using namespace metal;
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
typedef bfloat bfloat16_t;
#else
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Helpers // Helpers
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
@@ -305,10 +311,7 @@ METAL_FUNC bool isnan(_MLX_BFloat16 x) {
} // namespace metal } // namespace metal
#pragma METAL internals : disable #pragma METAL internals : disable
inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
return x.bits_;
}
inline bfloat16_t uint16_to_bfloat16(const uint16_t x) { #endif
return _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat());
} #include "mlx/backend/metal/kernels/bf16_math.h"
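Both branches above express the same fact: a bfloat16 value is the upper 16 bits of an IEEE-754 float32, so converting to and from uint16 is pure bit movement (a reinterpret where bfloat is a native type, a struct bit field in the fallback). A small host-side C++ sketch of that relationship, with illustrative names; real converters may round rather than truncate.

#include <cstdint>
#include <cstring>

// bfloat16 keeps the sign, the 8 exponent bits and the top 7 mantissa bits
// of a float32, so its bit pattern is simply the high half of the 32-bit word.
inline uint16_t float_to_bfloat16_bits(float x) {
  uint32_t u;
  std::memcpy(&u, &x, sizeof(u));
  return static_cast<uint16_t>(u >> 16);  // truncating conversion
}

inline float bfloat16_bits_to_float(uint16_t b) {
  uint32_t u = static_cast<uint32_t>(b) << 16;  // low mantissa bits become zero
  float x;
  std::memcpy(&x, &u, sizeof(x));
  return x;
}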


@@ -2,6 +2,8 @@
#pragma once #pragma once
#include "mlx/backend/metal/kernels/bf16.h"
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Metal math for bfloat16 // Metal math for bfloat16
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@@ -367,6 +369,18 @@ instantiate_metal_math_funcs(
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \ return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
} }
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
#else
#define bfloat16_to_uint16(x) x.bits_
#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())
#endif
namespace metal { namespace metal {
instantiate_metal_simd_comm_funcs( instantiate_metal_simd_comm_funcs(


@@ -77,12 +77,12 @@ template <typename T, typename U, typename Op>
constant const size_t& a_stride, constant const size_t& a_stride,
constant const size_t& b_stride, constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_stride); auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_stride); auto b_idx = elem_to_loc_1(index, b_stride);
c[index] = Op()(a[a_idx], b[b_idx]); c[index] = Op()(a[a_idx], b[b_idx]);
} }
template <typename T, typename U, typename Op, typename IdxT = size_t> template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd2( [[kernel]] void binary_g_nd2(
device const T* a, device const T* a,
device const T* b, device const T* b,
@@ -91,13 +91,13 @@ template <typename T, typename U, typename Op, typename IdxT = size_t>
constant const size_t b_strides[2], constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides); auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides); auto b_idx = elem_to_loc_2(index, b_strides);
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y; size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
c[out_idx] = Op()(a[a_idx], b[b_idx]); c[out_idx] = Op()(a[a_idx], b[b_idx]);
} }
template <typename T, typename U, typename Op, typename IdxT = size_t> template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd3( [[kernel]] void binary_g_nd3(
device const T* a, device const T* a,
device const T* b, device const T* b,
@@ -106,18 +106,14 @@ template <typename T, typename U, typename Op, typename IdxT = size_t>
constant const size_t b_strides[3], constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]], uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides); auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides); auto b_idx = elem_to_loc_3(index, b_strides);
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); size_t out_idx =
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
c[out_idx] = Op()(a[a_idx], b[b_idx]); c[out_idx] = Op()(a[a_idx], b[b_idx]);
} }
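The nd1/nd2/nd3 kernels above and the general binary_g below all reduce to the same address computation: an elem_to_loc-style helper walks the shape from the innermost dimension outward and accumulates stride offsets, with the accumulator type (uint versus a 64-bit type) supplied by the caller. A simplified host-side version of that mapping; it is not the exact Metal helper.

#include <cstdint>

// Convert a flat row-major element index into a strided memory offset.
// IdxT mirrors the IdxT template parameter threaded through the kernels:
// uint32_t when every offset fits in 32 bits, a 64-bit type otherwise.
template <typename IdxT>
IdxT elem_to_loc(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
  IdxT loc = 0;
  for (int i = ndim - 1; i >= 0; --i) {
    IdxT dim = static_cast<IdxT>(shape[i]);
    loc += (elem % dim) * static_cast<IdxT>(strides[i]);
    elem /= dim;
  }
  return loc;
}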
template < template <typename T, typename U, typename Op, int N = 1>
typename T,
typename U,
typename Op,
int N = 1,
typename IdxT = size_t>
[[kernel]] void binary_g( [[kernel]] void binary_g(
device const T* a, device const T* a,
device const T* b, device const T* b,
@@ -128,12 +124,13 @@ template <
constant const int& ndim, constant const int& ndim,
uint3 index [[thread_position_in_grid]], uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<size_t, IdxT>( auto idx = elem_to_loc_2_nd(
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim); {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
auto xshape = shape[ndim - 1]; auto xshape = shape[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); size_t out_idx =
IdxT a_xstride = a_strides[ndim - 1]; N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
IdxT b_xstride = b_strides[ndim - 1]; auto a_xstride = a_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
c[out_idx++] = Op()(a[idx.x], b[idx.y]); c[out_idx++] = Op()(a[idx.x], b[idx.y]);
idx.x += a_xstride; idx.x += a_xstride;


@@ -9,21 +9,18 @@
#include "mlx/backend/metal/kernels/binary_ops.h" #include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary.h" #include "mlx/backend/metal/kernels/binary.h"
#define instantiate_binary_all(op, tname, itype, otype) \ #define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \ instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \ instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \ instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \ instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \ instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \ instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \ instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \ instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
#define instantiate_binary_integer(op) \ #define instantiate_binary_integer(op) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \ instantiate_binary_all(op, uint8, uint8_t, uint8_t) \


@@ -99,14 +99,14 @@ template <typename T, typename U, typename Op>
constant const size_t& a_stride, constant const size_t& a_stride,
constant const size_t& b_stride, constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_stride); auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_stride); auto b_idx = elem_to_loc_1(index, b_stride);
auto out = Op()(a[a_idx], b[b_idx]); auto out = Op()(a[a_idx], b[b_idx]);
c[index] = out[0]; c[index] = out[0];
d[index] = out[1]; d[index] = out[1];
} }
template <typename T, typename U, typename Op, typename IdxT = size_t> template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd2( [[kernel]] void binary_g_nd2(
device const T* a, device const T* a,
device const T* b, device const T* b,
@@ -116,15 +116,15 @@ template <typename T, typename U, typename Op, typename IdxT = size_t>
constant const size_t b_strides[2], constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides); auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides); auto b_idx = elem_to_loc_2(index, b_strides);
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y; size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
auto out = Op()(a[a_idx], b[b_idx]); auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0]; c[out_idx] = out[0];
d[out_idx] = out[1]; d[out_idx] = out[1];
} }
template <typename T, typename U, typename Op, typename IdxT = size_t> template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd3( [[kernel]] void binary_g_nd3(
device const T* a, device const T* a,
device const T* b, device const T* b,
@@ -134,20 +134,16 @@ template <typename T, typename U, typename Op, typename IdxT = size_t>
constant const size_t b_strides[3], constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]], uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides); auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides); auto b_idx = elem_to_loc_3(index, b_strides);
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); size_t out_idx =
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
auto out = Op()(a[a_idx], b[b_idx]); auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0]; c[out_idx] = out[0];
d[out_idx] = out[1]; d[out_idx] = out[1];
} }
template < template <typename T, typename U, typename Op, int N = 1>
typename T,
typename U,
typename Op,
int N = 1,
typename IdxT = size_t>
[[kernel]] void binary_g( [[kernel]] void binary_g(
device const T* a, device const T* a,
device const T* b, device const T* b,
@@ -159,12 +155,13 @@ template <
constant const int& ndim, constant const int& ndim,
uint3 index [[thread_position_in_grid]], uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<size_t, IdxT>( auto idx = elem_to_loc_2_nd(
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim); {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
auto xshape = shape[ndim - 1]; auto xshape = shape[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); size_t out_idx =
IdxT a_xstride = a_strides[ndim - 1]; N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
IdxT b_xstride = b_strides[ndim - 1]; auto a_xstride = a_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
auto out = Op()(a[idx.x], b[idx.y]); auto out = Op()(a[idx.x], b[idx.y]);
c[out_idx] = out[0]; c[out_idx] = out[0];


@@ -7,21 +7,18 @@
#include "mlx/backend/metal/kernels/binary_ops.h" #include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary_two.h" #include "mlx/backend/metal/kernels/binary_two.h"
#define instantiate_binary_all(op, tname, itype, otype) \ #define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \ instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \ instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \ instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \ instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \ instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \ instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \ instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \ instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
#define instantiate_binary_float(op) \ #define instantiate_binary_float(op) \
instantiate_binary_all(op, float16, half, half) \ instantiate_binary_all(op, float16, half, half) \


@@ -4,8 +4,8 @@
#include <metal_simdgroup_matrix> #include <metal_simdgroup_matrix>
#include <metal_stdlib> #include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h" #include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/utils.h"
#define MLX_MTL_CONST static constant constexpr const #define MLX_MTL_CONST static constant constexpr const


@@ -42,36 +42,36 @@ template <typename T, typename U>
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]], constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1<int64_t, int>(index, src_stride); auto src_idx = elem_to_loc_1(index, src_stride);
dst[index] = static_cast<U>(src[src_idx]); dst[index] = static_cast<U>(src[src_idx]);
} }
template <typename T, typename U, typename IdxT = int64_t> template <typename T, typename U>
[[kernel]] void copy_g_nd2( [[kernel]] void copy_g_nd2(
device const T* src [[buffer(0)]], device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]], constant const int64_t* src_strides [[buffer(3)]],
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_2<int64_t, IdxT>(index, src_strides); auto src_idx = elem_to_loc_2(index, src_strides);
IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y; int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
dst[dst_idx] = static_cast<U>(src[src_idx]); dst[dst_idx] = static_cast<U>(src[src_idx]);
} }
template <typename T, typename U, typename IdxT = int64_t> template <typename T, typename U>
[[kernel]] void copy_g_nd3( [[kernel]] void copy_g_nd3(
device const T* src [[buffer(0)]], device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]], constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]], uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_3<int64_t, IdxT>(index, src_strides); auto src_idx = elem_to_loc_3(index, src_strides);
IdxT dst_idx = int64_t dst_idx =
index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z); index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]); dst[dst_idx] = static_cast<U>(src[src_idx]);
} }
template <typename T, typename U, int N = 1, typename IdxT = int64_t> template <typename T, typename U, int N = 1>
[[kernel]] void copy_g( [[kernel]] void copy_g(
device const T* src [[buffer(0)]], device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
@@ -80,16 +80,17 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
constant const int& ndim [[buffer(5)]], constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]], uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc<int64_t, IdxT>( auto src_idx = elem_to_loc(
{N * index.x, index.y, index.z}, src_shape, src_strides, ndim); {N * index.x, index.y, index.z}, src_shape, src_strides, ndim);
if (N == 1) { if (N == 1) {
IdxT dst_idx = int64_t dst_idx =
index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); index.x + grid_dim.x * (index.y + int64_t(grid_dim.y) * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]); dst[dst_idx] = static_cast<U>(src[src_idx]);
return; return;
} }
auto xshape = src_shape[ndim - 1]; auto xshape = src_shape[ndim - 1];
IdxT dst_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); int64_t dst_idx =
N * index.x + xshape * (index.y + int64_t(grid_dim.y) * index.z);
auto src_xstride = src_strides[ndim - 1]; auto src_xstride = src_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
dst[dst_idx + i] = static_cast<U>(src[src_idx]); dst[dst_idx + i] = static_cast<U>(src[src_idx]);
@@ -104,36 +105,36 @@ template <typename T, typename U>
constant const int64_t& src_stride [[buffer(3)]], constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]], constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1<int64_t, int>(index, src_stride); auto src_idx = elem_to_loc_1(index, src_stride);
auto dst_idx = elem_to_loc_1<int64_t, int>(index, dst_stride); auto dst_idx = elem_to_loc_1(index, dst_stride);
dst[dst_idx] = static_cast<U>(src[src_idx]); dst[dst_idx] = static_cast<U>(src[src_idx]);
} }
template <typename T, typename U, typename IdxT = int64_t> template <typename T, typename U>
[[kernel]] void copy_gg_nd2( [[kernel]] void copy_gg_nd2(
device const T* src [[buffer(0)]], device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]], constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]], constant const int64_t* dst_strides [[buffer(4)]],
uint2 index [[thread_position_in_grid]]) { uint2 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_2<int64_t, IdxT>(index, src_strides); auto src_idx = elem_to_loc_2(index, src_strides);
auto dst_idx = elem_to_loc_2<int64_t, IdxT>(index, dst_strides); auto dst_idx = elem_to_loc_2(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]); dst[dst_idx] = static_cast<U>(src[src_idx]);
} }
template <typename T, typename U, typename IdxT = int64_t> template <typename T, typename U>
[[kernel]] void copy_gg_nd3( [[kernel]] void copy_gg_nd3(
device const T* src [[buffer(0)]], device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]], constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]], constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) { uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_3<int64_t, IdxT>(index, src_strides); auto src_idx = elem_to_loc_3(index, src_strides);
auto dst_idx = elem_to_loc_3<int64_t, IdxT>(index, dst_strides); auto dst_idx = elem_to_loc_3(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]); dst[dst_idx] = static_cast<U>(src[src_idx]);
} }
template <typename T, typename U, int N = 1, typename IdxT = int64_t> template <typename T, typename U, int N = 1>
[[kernel]] void copy_gg( [[kernel]] void copy_gg(
device const T* src [[buffer(0)]], device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
@@ -142,7 +143,7 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
constant const int64_t* dst_strides [[buffer(4)]], constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]], constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]) { uint3 index [[thread_position_in_grid]]) {
auto idx = elem_to_loc_2_nd<int64_t, IdxT>( auto idx = elem_to_loc_2_nd(
{N * index.x, index.y, index.z}, {N * index.x, index.y, index.z},
src_shape, src_shape,
src_strides, src_strides,
@@ -152,8 +153,8 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
dst[idx.y] = static_cast<U>(src[idx.x]); dst[idx.y] = static_cast<U>(src[idx.x]);
return; return;
} }
IdxT src_xstride = src_strides[ndim - 1]; auto src_xstride = src_strides[ndim - 1];
IdxT dst_xstride = dst_strides[ndim - 1]; auto dst_xstride = dst_strides[ndim - 1];
auto xshape = src_shape[ndim - 1]; auto xshape = src_shape[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
dst[idx.y] = static_cast<U>(src[idx.x]); dst[idx.y] = static_cast<U>(src[idx.x]);


@@ -2,27 +2,22 @@
// clang-format off // clang-format off
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/copy.h" #include "mlx/backend/metal/kernels/copy.h"
#define instantiate_copy_all(tname, itype, otype) \ #define instantiate_copy_all(tname, itype, otype) \
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \ instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \ instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \ instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \ instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \ instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \ instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype) \
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \ instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype) \
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \ instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype, int) \ instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype, int) \ instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \ instantiate_kernel("gn4_copy" #tname, copy_g, itype, otype, 4) \
instantiate_kernel("ggn2_copy" #tname, copy_gg, itype, otype, 2, int) \ instantiate_kernel("ggn4_copy" #tname, copy_gg, itype, otype, 4)
instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, itype, otype) \
instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, itype, otype) \
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) \
instantiate_kernel("ggn4large_copy" #tname, copy_gg, itype, otype, 4)
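The instantiation list above carries two families per copy kernel: 32-bit indexed variants (g1/g2/g3, gn2, gg*) and 64-bit "large" variants (g2large, g3large, gn4large, ...). A sketch of how a host dispatcher could choose between them; the threshold and exact rule are assumptions, only the kernel names follow the instantiations above.

#include <cstdint>
#include <string>

// Pick the cheap 32-bit-indexed copy variant whenever every element offset
// fits in a uint32_t, and fall back to the "large" 64-bit variants otherwise.
// Illustrative only: the library's real dispatch logic is not shown here.
std::string pick_copy_variant(uint64_t max_offset, int ndim) {
  const bool large = max_offset > UINT32_MAX;
  if (ndim == 1) return "g1_copy";
  if (ndim == 2) return large ? "g2large_copy" : "g2_copy";
  if (ndim == 3) return large ? "g3large_copy" : "g3_copy";
  return large ? "gn4large_copy" : "gn2_copy";  // general N-d path
}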
#define instantiate_copy_itype(itname, itype) \ #define instantiate_copy_itype(itname, itype) \
instantiate_copy_all(itname ##bool_, itype, bool) \ instantiate_copy_all(itname ##bool_, itype, bool) \


@@ -4,7 +4,7 @@
#include "mlx/backend/metal/kernels/indexing.h" #include "mlx/backend/metal/kernels/indexing.h"
template <typename T, typename IdxT, int NIDX, int IDX_NDIM, typename LocT> template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
METAL_FUNC void gather_impl( METAL_FUNC void gather_impl(
const device T* src [[buffer(0)]], const device T* src [[buffer(0)]],
device T* out [[buffer(1)]], device T* out [[buffer(1)]],
@@ -16,18 +16,18 @@ METAL_FUNC void gather_impl(
const thread Indices<IdxT, NIDX>& indices, const thread Indices<IdxT, NIDX>& indices,
uint3 index [[thread_position_in_grid]], uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
LocT src_idx = 0; size_t src_idx = 0;
for (int i = 0; i < NIDX; ++i) { for (int i = 0; i < NIDX; ++i) {
LocT idx_loc; size_t idx_loc;
if (IDX_NDIM == 0) { if (IDX_NDIM == 0) {
idx_loc = 0; idx_loc = 0;
} else if (IDX_NDIM == 1) { } else if (IDX_NDIM == 1) {
idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]); idx_loc = index.x * indices.strides[indices.ndim * i];
} else { } else {
idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]); idx_loc = index.x * indices.strides[indices.ndim * i];
idx_loc += indices.row_contiguous[i] idx_loc += indices.row_contiguous[i]
? index.y ? index.y
: elem_to_loc<size_t, LocT>( : elem_to_loc(
index.y, index.y,
&indices.shapes[indices.ndim * i + 1], &indices.shapes[indices.ndim * i + 1],
&indices.strides[indices.ndim * i + 1], &indices.strides[indices.ndim * i + 1],
@@ -35,17 +35,17 @@ METAL_FUNC void gather_impl(
} }
auto ax = axes[i]; auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]); auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
src_idx += static_cast<LocT>(idx_val) * static_cast<LocT>(src_strides[ax]); src_idx += idx_val * src_strides[ax];
} }
auto src_offset = auto src_offset = elem_to_loc(index.z, slice_sizes, src_strides, src_ndim);
elem_to_loc<size_t, LocT>(index.z, slice_sizes, src_strides, src_ndim);
LocT out_idx = index.z; size_t out_idx = index.z;
if (IDX_NDIM == 1) { if (IDX_NDIM == 1) {
out_idx += static_cast<LocT>(grid_dim.z) * index.x; out_idx += static_cast<size_t>(grid_dim.z) * index.x;
} else if (IDX_NDIM >= 2) { } else if (IDX_NDIM >= 2) {
out_idx += grid_dim.z * (index.x * static_cast<LocT>(grid_dim.y) + index.y); out_idx +=
grid_dim.z * (index.x * static_cast<size_t>(grid_dim.y) + index.y);
} }
out[out_idx] = src[src_offset + src_idx]; out[out_idx] = src[src_offset + src_idx];
} }
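gather_impl above builds the source offset from one index array per gathered axis, passing each fetched index through offset_neg_idx so that negative values wrap to the end of the axis. A host-side sketch of that wrap; the unsigned pass-through matches the helper declared further below, but the body of its signed branch is not shown in this diff, so the wrap rule here is the conventional one and should be read as an assumption.

#include <cstddef>
#include <type_traits>

// Unsigned index types cannot be negative and pass through unchanged;
// signed ones wrap Python-style, so -1 addresses the last element of the axis.
template <typename IdxT>
size_t offset_neg_idx(IdxT idx, int axis_size) {
  if constexpr (std::is_unsigned_v<IdxT>) {
    return static_cast<size_t>(idx);
  } else {
    return static_cast<size_t>(idx < 0 ? idx + axis_size : idx);
  }
}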


@@ -3,6 +3,8 @@
#include <metal_simdgroup> #include <metal_simdgroup>
#include <metal_stdlib> #include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/utils.h" #include "mlx/backend/metal/kernels/steel/utils.h"
@@ -910,4 +912,4 @@ template <
// clang-format off // clang-format off
instantiate_gemv_t_bs_blocks(float32, float); instantiate_gemv_t_bs_blocks(float32, float);
instantiate_gemv_t_bs_blocks(float16, half); instantiate_gemv_t_bs_blocks(float16, half);
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on


@@ -4,6 +4,8 @@
#include <metal_simdgroup> #include <metal_simdgroup>
#include <metal_stdlib> #include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/gemv_masked.h" #include "mlx/backend/metal/kernels/gemv_masked.h"


@@ -14,7 +14,7 @@ struct Indices {
}; };
template <typename IdxT> template <typename IdxT>
METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) { METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) {
if (is_unsigned_v<IdxT>) { if (is_unsigned_v<IdxT>) {
return idx; return idx;
} else { } else {


@@ -1,16 +0,0 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#define jit_if #if
#define jit_else #else
#define jit_endif #endif
jit_if (__METAL_VERSION__ >= 310)
#include "mlx/backend/metal/kernels/metal_3_1/bf16.h"
jit_else
#include "mlx/backend/metal/kernels/metal_3_0/bf16.h"
jit_endif // clang-format on


@@ -3,6 +3,8 @@
#include <metal_common> #include <metal_common>
#include <metal_simdgroup> #include <metal_simdgroup>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
using namespace metal; using namespace metal;


@@ -1,16 +0,0 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <metal_stdlib>
using namespace metal;
typedef bfloat bfloat16_t;
inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
return as_type<uint16_t>(x);
}
inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
return as_type<bfloat16_t>(x);
}


@@ -13,8 +13,8 @@ MLX_MTL_CONST int QUAD_SIZE = 4;
template <typename T, typename U, int values_per_thread, int bits> template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector(const device T* x, thread U* x_thread) { inline U load_vector(const device T* x, thread U* x_thread) {
static_assert( static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}"); "Template undefined for bits not in {2, 4, 8}");
U sum = 0; U sum = 0;
@@ -28,21 +28,6 @@ inline U load_vector(const device T* x, thread U* x_thread) {
} }
} }
else if (bits == 3) {
for (int i = 0; i < values_per_thread; i += 8) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
x[i + 6] + x[i + 7];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 8.0f;
x_thread[i + 2] = x[i + 2] / 64.0f;
x_thread[i + 3] = x[i + 3] / 2.0f;
x_thread[i + 4] = x[i + 4] / 16.0f;
x_thread[i + 5] = x[i + 5] / 128.0f;
x_thread[i + 6] = x[i + 6] / 4.0f;
x_thread[i + 7] = x[i + 7] / 32.0f;
}
}
else if (bits == 4) { else if (bits == 4) {
for (int i = 0; i < values_per_thread; i += 4) { for (int i = 0; i < values_per_thread; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
@@ -53,16 +38,6 @@ inline U load_vector(const device T* x, thread U* x_thread) {
} }
} }
else if (bits == 6) {
for (int i = 0; i < values_per_thread; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 64.0f;
x_thread[i + 2] = x[i + 2] / 16.0f;
x_thread[i + 3] = x[i + 3] / 4.0f;
}
}
else if (bits == 8) { else if (bits == 8) {
for (int i = 0; i < values_per_thread; i++) { for (int i = 0; i < values_per_thread; i++) {
sum += x[i]; sum += x[i];
@@ -76,8 +51,8 @@ inline U load_vector(const device T* x, thread U* x_thread) {
template <typename T, typename U, int values_per_thread, int bits> template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
static_assert( static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}"); "Template undefined for bits not in {2, 4, 8}");
U sum = 0; U sum = 0;
@@ -89,21 +64,8 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
x_thread[i + 2] = x[i + 2] / 16.0f; x_thread[i + 2] = x[i + 2] / 16.0f;
x_thread[i + 3] = x[i + 3] / 64.0f; x_thread[i + 3] = x[i + 3] / 64.0f;
} }
} for (int i = N; i < values_per_thread; i++) {
x_thread[i] = 0;
else if (bits == 3) {
for (int i = 0; i < N; i += 8) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
x[i + 6] + x[i + 7];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 8.0f;
x_thread[i + 2] = x[i + 2] / 64.0f;
x_thread[i + 3] = x[i + 3] / 2.0f;
x_thread[i + 4] = x[i + 4] / 16.0f;
x_thread[i + 5] = x[i + 5] / 128.0f;
x_thread[i + 6] = x[i + 6] / 4.0f;
x_thread[i + 7] = x[i + 7] / 32.0f;
} }
} }
@@ -115,15 +77,8 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
x_thread[i + 2] = x[i + 2] / 256.0f; x_thread[i + 2] = x[i + 2] / 256.0f;
x_thread[i + 3] = x[i + 3] / 4096.0f; x_thread[i + 3] = x[i + 3] / 4096.0f;
} }
} for (int i = N; i < values_per_thread; i++) {
x_thread[i] = 0;
else if (bits == 6) {
for (int i = 0; i < N; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 64.0f;
x_thread[i + 2] = x[i + 2] / 16.0f;
x_thread[i + 3] = x[i + 3] / 4.0f;
} }
} }
@@ -132,10 +87,9 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
sum += x[i]; sum += x[i];
x_thread[i] = x[i]; x_thread[i] = x[i];
} }
} for (int i = N; i < values_per_thread; i++) {
x_thread[i] = 0;
for (int i = N; i < values_per_thread; i++) { }
x_thread[i] = 0;
} }
return sum; return sum;
@@ -149,8 +103,8 @@ inline U qdot(
U bias, U bias,
U sum) { U sum) {
static_assert( static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}"); "Template undefined for bits not in {2, 4, 8}");
U accum = 0; U accum = 0;
@@ -164,26 +118,6 @@ inline U qdot(
} }
} }
else if (bits == 3) {
for (int i = 0; i < (values_per_thread / 8); i++) {
x_thread += 8 * i;
w += 3 * i;
accum += (w[0] & 0x07) * x_thread[0];
accum += (w[0] & 0x38) * x_thread[1];
accum += (w[0] & 0xc0) * x_thread[2];
accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);
accum += (w[1] & 0x0e) * x_thread[3];
accum += (w[1] & 0x70) * x_thread[4];
accum += (w[1] & 0x80) * x_thread[5];
accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);
accum += (w[2] & 0x1c) * x_thread[6];
accum += (w[2] & 0xe0) * x_thread[7];
}
}
else if (bits == 4) { else if (bits == 4) {
const device uint16_t* ws = (const device uint16_t*)w; const device uint16_t* ws = (const device uint16_t*)w;
for (int i = 0; i < (values_per_thread / 4); i++) { for (int i = 0; i < (values_per_thread / 4); i++) {
@@ -195,23 +129,6 @@ inline U qdot(
} }
} }
else if (bits == 6) {
for (int i = 0; i < (values_per_thread / 4); i++) {
x_thread += 4 * i;
w += 3 * i;
accum += (w[0] & 0x3f) * x_thread[0];
accum += (w[0] & 0xc0) * x_thread[1];
accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);
accum += (w[1] & 0xf0) * x_thread[2];
accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);
accum += (w[2] & 0xfc) * x_thread[3];
}
}
else if (bits == 8) { else if (bits == 8) {
for (int i = 0; i < values_per_thread; i++) { for (int i = 0; i < values_per_thread; i++) {
accum += x_thread[i] * w[i]; accum += x_thread[i] * w[i];
@@ -230,8 +147,8 @@ inline U qdot_safe(
U sum, U sum,
int N) { int N) {
static_assert( static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}"); "Template undefined for bits not in {2, 4, 8}");
U accum = 0; U accum = 0;
@@ -245,26 +162,6 @@ inline U qdot_safe(
} }
} }
else if (bits == 3) {
for (int i = 0; i < (N / 8); i++) {
x_thread += 8 * i;
w += 3 * i;
accum += (w[0] & 0x07) * x_thread[0];
accum += (w[0] & 0x38) * x_thread[1];
accum += (w[0] & 0xc0) * x_thread[2];
accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);
accum += (w[1] & 0x0e) * x_thread[3];
accum += (w[1] & 0x70) * x_thread[4];
accum += (w[1] & 0x80) * x_thread[5];
accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);
accum += (w[2] & 0x1c) * x_thread[6];
accum += (w[2] & 0xe0) * x_thread[7];
}
}
else if (bits == 4) { else if (bits == 4) {
const device uint16_t* ws = (const device uint16_t*)w; const device uint16_t* ws = (const device uint16_t*)w;
for (int i = 0; i < (N / 4); i++) { for (int i = 0; i < (N / 4); i++) {
@@ -276,23 +173,6 @@ inline U qdot_safe(
} }
} }
else if (bits == 6) {
for (int i = 0; i < (N / 4); i++) {
x_thread += 4 * i;
w += 3 * i;
accum += (w[0] & 0x3f) * x_thread[0];
accum += (w[0] & 0xc0) * x_thread[1];
accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);
accum += (w[1] & 0xf0) * x_thread[2];
accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);
accum += (w[2] & 0xfc) * x_thread[3];
}
}
else if (bits == 8) { else if (bits == 8) {
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
accum += x_thread[i] * w[i]; accum += x_thread[i] * w[i];
@@ -306,8 +186,8 @@ template <typename U, int values_per_thread, int bits>
inline void inline void
qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
static_assert( static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}"); "Template undefined for bits not in {2, 4, 8}");
if (bits == 2) { if (bits == 2) {
U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
@@ -319,45 +199,12 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
} }
} }
else if (bits == 3) {
for (int i = 0; i < (values_per_thread / 8); i++) {
uint8_t w0 = w[3 * i];
uint8_t w1 = w[3 * i + 1];
uint8_t w2 = w[3 * i + 2];
result[8 * i] += x * ((w0 & 0x7) * scale + bias);
result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias);
result[8 * i + 2] +=
x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias);
result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias);
result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias);
result[8 * i + 5] +=
x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias);
result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias);
result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias);
}
}
else if (bits == 4) { else if (bits == 4) {
U s[2] = {scale, scale / 16.0f}; U s[2] = {scale, scale / 16.0f};
for (int i = 0; i < (values_per_thread / 2); i++) { for (int i = 0; i < (values_per_thread / 2); i++) {
result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias); result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias); result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
} }
} else if (bits == 6) {
for (int i = 0; i < (values_per_thread / 4); i++) {
uint8_t w0 = w[3 * i];
uint8_t w1 = w[3 * i + 1];
uint8_t w2 = w[3 * i + 2];
result[4 * i] += x * ((w0 & 0x3f) * scale + bias);
result[4 * i + 1] +=
x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias);
result[4 * i + 2] +=
x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias);
result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias);
}
} }
else if (bits == 8) { else if (bits == 8) {
@@ -371,8 +218,8 @@ template <typename U, int N, int bits>
inline void inline void
dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
static_assert( static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}"); "Template undefined for bits not in {2, 4, 8}");
if (bits == 2) { if (bits == 2) {
U s[4] = { U s[4] = {
@@ -388,22 +235,6 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
} }
} }
else if (bits == 3) {
for (int i = 0; i < (N / 8); i++) {
w_local += 8 * i;
w += 3 * i;
w_local[0] = (w[0] & 0x7) * scale + bias;
w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias;
w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias;
w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias;
w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
}
}
else if (bits == 4) { else if (bits == 4) {
U s[2] = {scale, scale / static_cast<U>(16.0f)}; U s[2] = {scale, scale / static_cast<U>(16.0f)};
for (int i = 0; i < (N / 2); i++) { for (int i = 0; i < (N / 2); i++) {
@@ -412,18 +243,6 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
} }
} }
else if (bits == 6) {
for (int i = 0; i < (N / 4); i++) {
w_local += 4 * i;
w += 3 * i;
w_local[0] = (w[0] & 0x3f) * scale + bias;
w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
}
}
else if (bits == 8) { else if (bits == 8) {
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
w_local[i] = scale * w[i] + bias; w_local[i] = scale * w[i] + bias;
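Both columns of the hunks above implement the same affine dequantization, w = scale * q + bias; they differ only in which bit widths they accept and in how the packed bytes are addressed. A minimal host-side sketch of the 4-bit case (the helper name and the std::vector interface are illustrative, not part of MLX):

```cpp
#include <cstdint>
#include <vector>

// Unpack two 4-bit codes per byte and apply the affine map w = scale * q + bias.
// Shifting the high nibble down is equivalent to the kernel's scale / 16 trick.
std::vector<float> dequantize4_reference(const std::vector<uint8_t>& packed,
                                         float scale, float bias) {
  std::vector<float> out;
  out.reserve(packed.size() * 2);
  for (uint8_t byte : packed) {
    uint8_t lo = byte & 0x0f;         // first code: low nibble
    uint8_t hi = (byte >> 4) & 0x0f;  // second code: high nibble
    out.push_back(scale * lo + bias);
    out.push_back(scale * hi + bias);
  }
  return out;
}
```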
@@ -448,11 +267,10 @@ struct QuantizedBlockLoader {
group_size % BCOLS == 0, group_size % BCOLS == 0,
"The group size should be divisible by the columns"); "The group size should be divisible by the columns");
static_assert( static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}"); "Template undefined for bits not in {2, 4, 8}");
MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; MLX_MTL_CONST short pack_factor = 32 / bits;
MLX_MTL_CONST short bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
MLX_MTL_CONST short n_reads = MLX_MTL_CONST short n_reads =
(BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
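The pack_factor / bytes_per_pack constants above encode how many quantized codes share one storage unit: one column counts codes per byte (8 / bits, with 3-byte groups for 3- and 6-bit), the other per 32-bit word (32 / bits, power-of-two widths only). The sketch below uses the word-sized convention, as in the qmv kernels further down; the invariant either way is values_per_pack * bits == 8 * bytes_per_pack. Helper names are illustrative:

```cpp
#include <cstdio>

// Values per pack and pack width for each bit count that appears in this diff.
constexpr int values_per_pack(int bits) {
  return bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
}
constexpr int bytes_per_pack(int bits) {
  return (bits == 3 || bits == 6) ? 3 : 4;
}

int main() {
  for (int bits : {2, 3, 4, 6, 8}) {
    std::printf("bits=%d -> %d values per %d-byte pack (%d bits)\n",
                bits, values_per_pack(bits), bytes_per_pack(bits),
                values_per_pack(bits) * bits);
  }
  return 0;
}
```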
@@ -468,12 +286,12 @@ struct QuantizedBlockLoader {
const short bj; const short bj;
threadgroup T* dst; threadgroup T* dst;
const device uint8_t* src; const device uint32_t* src;
const device T* scales; const device T* scales;
const device T* biases; const device T* biases;
QuantizedBlockLoader( QuantizedBlockLoader(
const device uint8_t* src_, const device uint32_t* src_,
const device T* scales_, const device T* scales_,
const device T* biases_, const device T* biases_,
const int src_ld_, const int src_ld_,
@@ -482,16 +300,14 @@ struct QuantizedBlockLoader {
ushort simd_lane_id [[thread_index_in_simdgroup]]) ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_), : src_ld(src_ld_),
tile_stride( tile_stride(
reduction_dim ? BCOLS_PACKED * bytes_per_pack reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor),
: BROWS * src_ld * bytes_per_pack / pack_factor),
group_step_cnt(0), group_step_cnt(0),
group_stride(BROWS * src_ld / group_size), group_stride(BROWS * src_ld / group_size),
thread_idx(simd_group_id * 32 + simd_lane_id), thread_idx(simd_group_id * 32 + simd_lane_id),
bi(n_reads * thread_idx / BCOLS_PACKED), bi(n_reads * thread_idx / BCOLS_PACKED),
bj((n_reads * thread_idx) % BCOLS_PACKED), bj((n_reads * thread_idx) % BCOLS_PACKED),
dst(dst_ + bi * dst_ld + bj * pack_factor), dst(dst_ + bi * dst_ld + bj * pack_factor),
src(src_ + bi * src_ld * bytes_per_pack / pack_factor + src(src_ + bi * src_ld / pack_factor + bj),
bj * bytes_per_pack),
scales(scales_ + bi * src_ld / group_size), scales(scales_ + bi * src_ld / group_size),
biases(biases_ + bi * src_ld / group_size) {} biases(biases_ + bi * src_ld / group_size) {}
@@ -504,7 +320,7 @@ struct QuantizedBlockLoader {
T bias = *biases; T bias = *biases;
for (int i = 0; i < n_reads; i++) { for (int i = 0; i < n_reads; i++) {
dequantize<T, pack_factor, bits>( dequantize<T, pack_factor, bits>(
src + i * bytes_per_pack, scale, bias, dst + i * pack_factor); (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
} }
} }
@@ -531,10 +347,7 @@ struct QuantizedBlockLoader {
T bias = *biases; T bias = *biases;
for (int i = 0; i < n_reads; i++) { for (int i = 0; i < n_reads; i++) {
dequantize<T, pack_factor, bits>( dequantize<T, pack_factor, bits>(
(device uint8_t*)(src + i * bytes_per_pack), (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
scale,
bias,
dst + i * pack_factor);
} }
} }
@@ -597,7 +410,8 @@ METAL_FUNC void qmv_quad_impl(
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread); U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; row < results_per_quadgroup; row++) { for (int row = 0; row < results_per_quadgroup; row++) {
auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd); const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
const device T* sl = scales + row * in_vec_size_g * quads_per_simd; const device T* sl = scales + row * in_vec_size_g * quads_per_simd;
const device T* bl = biases + row * in_vec_size_g * quads_per_simd; const device T* bl = biases + row * in_vec_size_g * quads_per_simd;
@@ -628,30 +442,25 @@ METAL_FUNC void qmv_fast_impl(
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int packs_per_thread = bits > 2 ? 2 : 1;
constexpr int packs_per_thread = bits == 2 ? 1 : 2;
constexpr int num_simdgroups = 2; constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4; constexpr int results_per_simdgroup = 4;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; constexpr int pack_factor = 32 / bits;
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int scale_step_per_thread = group_size / values_per_thread; constexpr int scale_step_per_thread = group_size / values_per_thread;
const device uint8_t* ws = (const device uint8_t*)w;
typedef float U; typedef float U;
thread U x_thread[values_per_thread]; thread U x_thread[values_per_thread];
thread U result[results_per_simdgroup] = {0}; thread U result[results_per_simdgroup] = {0};
// Adjust positions // Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_g = in_vec_size / group_size; const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) + const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
simd_gid * results_per_simdgroup; simd_gid * results_per_simdgroup;
w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
x += tid.y * in_vec_size + simd_lid * values_per_thread; x += tid.y * in_vec_size + simd_lid * values_per_thread;
@@ -661,7 +470,8 @@ METAL_FUNC void qmv_fast_impl(
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread); U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; row < results_per_simdgroup; row++) { for (int row = 0; row < results_per_simdgroup; row++) {
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g; const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g;
@@ -670,7 +480,7 @@ METAL_FUNC void qmv_fast_impl(
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum); result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
} }
ws += block_size * bytes_per_pack / pack_factor; w += block_size / pack_factor;
scales += block_size / group_size; scales += block_size / group_size;
biases += block_size / group_size; biases += block_size / group_size;
x += block_size; x += block_size;
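The reason load_vector also returns the running sum of x: with affine-quantized weights the dot product factors as sum_i x_i * (scale * q_i + bias) = scale * sum_i x_i * q_i + bias * sum_i x_i, so the bias term costs one multiply per group instead of one per element. A standalone check of that identity (this is the algebra the qdot call appears to rely on, not its implementation):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const float scale = 0.05f, bias = -0.4f;
  const uint8_t q[4] = {1, 7, 12, 3};
  const float x[4] = {0.5f, -1.0f, 2.0f, 0.25f};

  float direct = 0.0f, accum = 0.0f, xsum = 0.0f;
  for (int i = 0; i < 4; i++) {
    direct += x[i] * (scale * q[i] + bias);  // dequantize then multiply
    accum += x[i] * q[i];                    // integer-weight accumulation
    xsum += x[i];                            // the `sum` that gets passed around
  }
  const float folded = scale * accum + bias * xsum;
  std::printf("direct=%f folded=%f\n", direct, folded);  // equal up to rounding
  return 0;
}
```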
@@ -696,25 +506,21 @@ METAL_FUNC void qmv_impl(
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int num_simdgroups = 2; constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4; constexpr int results_per_simdgroup = 4;
constexpr int packs_per_thread = 1; constexpr int packs_per_thread = 1;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; constexpr int pack_factor = 32 / bits;
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int scale_step_per_thread = group_size / values_per_thread; constexpr int scale_step_per_thread = group_size / values_per_thread;
const device uint8_t* ws = (const device uint8_t*)w;
typedef float U; typedef float U;
thread U x_thread[values_per_thread]; thread U x_thread[values_per_thread];
thread U result[results_per_simdgroup] = {0}; thread U result[results_per_simdgroup] = {0};
// Adjust positions // Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_g = in_vec_size / group_size; const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) + const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
simd_gid * results_per_simdgroup; simd_gid * results_per_simdgroup;
@@ -727,8 +533,7 @@ METAL_FUNC void qmv_impl(
// In this case we need to properly guard all our reads because there isn't // In this case we need to properly guard all our reads because there isn't
// even 1 tile in the matrix // even 1 tile in the matrix
if (out_vec_size < (num_simdgroups * results_per_simdgroup)) { if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
ws += w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
x += tid.y * in_vec_size + simd_lid * values_per_thread; x += tid.y * in_vec_size + simd_lid * values_per_thread;
@@ -739,7 +544,8 @@ METAL_FUNC void qmv_impl(
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread); U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; out_row + row < out_vec_size; row++) { for (int row = 0; out_row + row < out_vec_size; row++) {
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g; const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g;
@@ -749,7 +555,7 @@ METAL_FUNC void qmv_impl(
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum); qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
} }
ws += block_size * bytes_per_pack / pack_factor; w += block_size / pack_factor;
scales += block_size / group_size; scales += block_size / group_size;
biases += block_size / group_size; biases += block_size / group_size;
x += block_size; x += block_size;
@@ -758,20 +564,18 @@ METAL_FUNC void qmv_impl(
static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
0, 0,
values_per_thread); values_per_thread);
if (remaining > 0) {
U sum = load_vector_safe<T, U, values_per_thread, bits>( U sum =
x, x_thread, remaining); load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
for (int row = 0; out_row + row < out_vec_size; row++) { for (int row = 0; out_row + row < out_vec_size; row++) {
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g; const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g;
U s = sl[0]; U s = sl[0];
U b = bl[0]; U b = bl[0];
result[row] += result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
} }
}
for (int row = 0; out_row + row < out_vec_size; row++) { for (int row = 0; out_row + row < out_vec_size; row++) {
@@ -784,8 +588,7 @@ METAL_FUNC void qmv_impl(
// In this case the last tile is moved back to redo some output values // In this case the last tile is moved back to redo some output values
else { else {
ws += used_out_row * in_vec_size_w + w += used_out_row * in_vec_size_w + simd_lid * packs_per_thread;
simd_lid * packs_per_thread * bytes_per_pack;
scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
x += tid.y * in_vec_size + simd_lid * values_per_thread; x += tid.y * in_vec_size + simd_lid * values_per_thread;
@@ -796,7 +599,8 @@ METAL_FUNC void qmv_impl(
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread); U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; row < results_per_simdgroup; row++) { for (int row = 0; row < results_per_simdgroup; row++) {
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g; const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g;
@@ -806,7 +610,7 @@ METAL_FUNC void qmv_impl(
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum); qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
} }
ws += block_size * bytes_per_pack / pack_factor; w += block_size / pack_factor;
scales += block_size / group_size; scales += block_size / group_size;
biases += block_size / group_size; biases += block_size / group_size;
x += block_size; x += block_size;
@@ -815,21 +619,21 @@ METAL_FUNC void qmv_impl(
static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
0, 0,
values_per_thread); values_per_thread);
if (remaining > 0) {
U sum = load_vector_safe<T, U, values_per_thread, bits>( U sum =
x, x_thread, remaining); load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
for (int row = 0; row < results_per_simdgroup; row++) { for (int row = 0; row < results_per_simdgroup; row++) {
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g; const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g;
U s = sl[0]; U s = sl[0];
U b = bl[0]; U b = bl[0];
result[row] += qdot_safe<U, values_per_thread, bits>( result[row] += qdot_safe<U, values_per_thread, bits>(
wl, x_thread, s, b, sum, remaining); wl, x_thread, s, b, sum, remaining);
} }
}
for (int row = 0; row < results_per_simdgroup; row++) { for (int row = 0; row < results_per_simdgroup; row++) {
result[row] = simd_sum(result[row]); result[row] = simd_sum(result[row]);
if (simd_lid == 0) { if (simd_lid == 0) {
@@ -851,18 +655,14 @@ METAL_FUNC void qvm_impl(
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int num_simdgroups = 2; constexpr int num_simdgroups = 2;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; constexpr int pack_factor = 32 / bits;
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
constexpr int tn = 32 / pack_factor; constexpr int tn = 32 / pack_factor;
constexpr int block_size = SIMD_SIZE; constexpr int blocksize = SIMD_SIZE;
const device uint8_t* ws = (const device uint8_t*)w;
typedef float U; typedef float U;
typedef struct { typedef struct {
uint8_t wi[tn * bytes_per_pack]; uint32_t wi[tn];
} vec_w; } vec_w;
thread vec_w w_local; thread vec_w w_local;
@@ -872,10 +672,11 @@ METAL_FUNC void qvm_impl(
thread U x_local = 0; thread U x_local = 0;
// Adjust positions // Adjust positions
const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor; const int out_vec_size_w = out_vec_size / pack_factor;
const int out_vec_size_g = out_vec_size / group_size; const int out_vec_size_g = out_vec_size / group_size;
int out_col = pack_factor * tn * (tid.x * num_simdgroups + simd_gid); int out_col =
ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w; tid.x * (num_simdgroups * pack_factor * tn) + simd_gid * pack_factor * tn;
w += out_col / pack_factor + simd_lid * out_vec_size_w;
scales += out_col / group_size + simd_lid * out_vec_size_g; scales += out_col / group_size + simd_lid * out_vec_size_g;
biases += out_col / group_size + simd_lid * out_vec_size_g; biases += out_col / group_size + simd_lid * out_vec_size_g;
x += tid.y * in_vec_size + simd_lid; x += tid.y * in_vec_size + simd_lid;
@@ -885,42 +686,43 @@ METAL_FUNC void qvm_impl(
return; return;
} }
// Loop over in_vec in blocks of block_size // Loop over in_vec in blocks of blocksize
int remaining = in_vec_size % block_size; int remaining = in_vec_size % blocksize;
if (remaining == 0) { if (remaining == 0) {
for (int i = 0; i < in_vec_size; i += block_size) { for (int i = 0; i < in_vec_size; i += blocksize) {
x_local = *x; x_local = *x;
scale = *scales; scale = *scales;
bias = *biases; bias = *biases;
w_local = *((device vec_w*)ws); w_local = *((device vec_w*)w);
qouter<U, tn * pack_factor, bits>( qouter<U, tn * pack_factor, bits>(
(thread uint8_t*)&w_local, x_local, scale, bias, result); (thread uint8_t*)&w_local, x_local, scale, bias, result);
x += block_size; x += blocksize;
scales += block_size * out_vec_size_g; scales += blocksize * out_vec_size_g;
biases += block_size * out_vec_size_g; biases += blocksize * out_vec_size_g;
ws += block_size * out_vec_size_w; w += blocksize * out_vec_size_w;
} }
} else { } else {
for (int i = block_size; i < in_vec_size; i += block_size) { for (int i = blocksize; i < in_vec_size; i += blocksize) {
x_local = *x; x_local = *x;
scale = *scales; scale = *scales;
bias = *biases; bias = *biases;
w_local = *((device vec_w*)ws); w_local = *((device vec_w*)w);
qouter<U, tn * pack_factor, bits>( qouter<U, tn * pack_factor, bits>(
(thread uint8_t*)&w_local, x_local, scale, bias, result); (thread uint8_t*)&w_local, x_local, scale, bias, result);
x += block_size; x += blocksize;
scales += block_size * out_vec_size_g; scales += blocksize * out_vec_size_g;
biases += block_size * out_vec_size_g; biases += blocksize * out_vec_size_g;
ws += block_size * out_vec_size_w; w += blocksize * out_vec_size_w;
} }
if (static_cast<int>(simd_lid) < remaining) { if (static_cast<int>(simd_lid) < remaining) {
x_local = *x; x_local = *x;
scale = *scales; scale = *scales;
bias = *biases; bias = *biases;
w_local = *((device vec_w*)ws); w_local = *((device vec_w*)w);
} else { } else {
x_local = 0; x_local = 0;
scale = 0; scale = 0;
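The two qvm loops above are the usual "full blocks, then guarded tail" split: whole multiples of the block size are consumed unconditionally, and only the first `remaining` lanes touch the leftover elements. A scalar sketch of the same pattern with illustrative sizes:

```cpp
#include <cstdio>

int main() {
  const int in_vec_size = 70;   // illustrative, not an MLX constant
  const int blocksize = 32;     // stands in for SIMD_SIZE
  const int remaining = in_vec_size % blocksize;

  int consumed = 0;
  for (int i = 0; i + blocksize <= in_vec_size; i += blocksize) {
    consumed += blocksize;      // full block: every lane participates
  }
  if (remaining > 0) {
    consumed += remaining;      // tail: only lanes with simd_lid < remaining read
  }
  std::printf("consumed %d of %d\n", consumed, in_vec_size);
  return 0;
}
```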
@@ -975,9 +777,8 @@ METAL_FUNC void qmm_t_impl(
constexpr int WM = 2; constexpr int WM = 2;
constexpr int WN = 2; constexpr int WN = 2;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; constexpr int pack_factor = 32 / bits;
constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
// Instantiate the appropriate BlockMMA and Loader // Instantiate the appropriate BlockMMA and Loader
using mma_t = mlx::steel:: using mma_t = mlx::steel::
@@ -995,15 +796,13 @@ METAL_FUNC void qmm_t_impl(
bits>; bits>;
// Set the block // Set the block
const int K_w = K * bytes_per_pack / pack_factor; const int K_w = K / pack_factor;
const int K_g = K / group_size; const int K_g = K / group_size;
const int y_row = tid.y * BM; const int y_row = tid.y * BM;
const int y_col = tid.x * BN; const int y_col = tid.x * BN;
auto wl = (const device uint8_t*)w;
x += y_row * K; x += y_row * K;
wl += y_col * K_w; w += y_col * K_w;
scales += y_col * K_g; scales += y_col * K_g;
biases += y_col * K_g; biases += y_col * K_g;
y += y_row * N + y_col; y += y_row * N + y_col;
@@ -1012,7 +811,7 @@ METAL_FUNC void qmm_t_impl(
const short num_els = min(BM, M - y_row); const short num_els = min(BM, M - y_row);
const short num_outs = min(BN, N - y_col); const short num_outs = min(BN, N - y_col);
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid); loader_w_t loader_w(w, scales, biases, K, Ws, simd_gid, simd_lid);
mma_t mma_op(simd_gid, simd_lid); mma_t mma_op(simd_gid, simd_lid);
if (num_els < BM) { if (num_els < BM) {
@@ -1054,7 +853,6 @@ METAL_FUNC void qmm_t_impl(
loader_x.load_unsafe(); loader_x.load_unsafe();
loader_w.load_unsafe(); loader_w.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws); mma_op.mma(Xs, Ws);
loader_x.next(); loader_x.next();
loader_w.next(); loader_w.next();
@@ -1100,11 +898,9 @@ METAL_FUNC void qmm_n_impl(
constexpr int WM = 2; constexpr int WM = 2;
constexpr int WN = 2; constexpr int WN = 2;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; constexpr int pack_factor = 32 / bits;
constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int BN_padded = (BN + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T));
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
// Instantiate the appropriate BlockMMA and Loader // Instantiate the appropriate BlockMMA and Loader
using mma_t = mlx::steel:: using mma_t = mlx::steel::
@@ -1121,13 +917,11 @@ METAL_FUNC void qmm_n_impl(
group_size, group_size,
bits>; bits>;
auto wl = (const device uint8_t*)w;
// Set the block // Set the block
const int y_row = tid.y * BM; const int y_row = tid.y * BM;
const int y_col = tid.x * BN; const int y_col = tid.x * BN;
x += y_row * K; x += y_row * K;
wl += y_col * bytes_per_pack / pack_factor; w += y_col / pack_factor;
scales += y_col / group_size; scales += y_col / group_size;
biases += y_col / group_size; biases += y_col / group_size;
y += y_row * N + y_col; y += y_row * N + y_col;
@@ -1135,7 +929,7 @@ METAL_FUNC void qmm_n_impl(
// Make the x loader and mma operation // Make the x loader and mma operation
const short num_els = min(BM, M - y_row); const short num_els = min(BM, M - y_row);
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
loader_w_t loader_w(wl, scales, biases, N, Ws, simd_gid, simd_lid); loader_w_t loader_w(w, scales, biases, N, Ws, simd_gid, simd_lid);
mma_t mma_op(simd_gid, simd_lid); mma_t mma_op(simd_gid, simd_lid);
if (num_els < BM) { if (num_els < BM) {
@@ -2007,14 +1801,13 @@ template <typename T, const int group_size, const int bits>
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
constexpr T eps = T(1e-7); constexpr T eps = T(1e-7);
constexpr int simd_size = 32; constexpr int simd_size = 32;
constexpr int uint8_bits = 8;
constexpr T n_bins = (1 << bits) - 1; constexpr T n_bins = (1 << bits) - 1;
constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; constexpr int packs_per_int = uint8_bits / bits;
constexpr int values_per_reduce = group_size / simd_size; constexpr int values_per_reduce = group_size / simd_size;
constexpr int writes_per_reduce = packs_per_int / values_per_reduce; constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
constexpr int writes_per_pack = constexpr int writes_per_pack =
writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int; writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
static_assert( static_assert(
group_size % simd_size == 0, group_size % simd_size == 0,
@@ -2022,9 +1815,7 @@ template <typename T, const int group_size, const int bits>
size_t offset = index.x + grid_dim.x * size_t(index.y); size_t offset = index.x + grid_dim.x * size_t(index.y);
size_t in_index = offset * values_per_reduce; size_t in_index = offset * values_per_reduce;
size_t out_index = power_of_2_bits size_t out_index = offset * writes_per_pack;
? offset * writes_per_pack
: offset * bytes_per_pack / writes_per_reduce;
T w_thread[values_per_reduce]; T w_thread[values_per_reduce];
T w_min = Limits<T>::max; T w_min = Limits<T>::max;
@@ -2057,9 +1848,7 @@ template <typename T, const int group_size, const int bits>
biases[gindex] = bias; biases[gindex] = bias;
} }
// We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t uint8_t output = 0;
uint32_t output = 0;
#pragma clang loop unroll(full) #pragma clang loop unroll(full)
for (int i = 0; i < values_per_reduce; i++) { for (int i = 0; i < values_per_reduce; i++) {
uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins); uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
@@ -2075,23 +1864,47 @@ template <typename T, const int group_size, const int bits>
output = 0; output = 0;
} else { } else {
#pragma clang loop unroll(full) #pragma clang loop unroll(full)
for (int j = 1; j < writes_per_reduce; j++) { for (int j = 0; j < writes_per_reduce - 1; j++) {
uint8_t sval = simd_shuffle_down(val, j); uint8_t sval = simd_shuffle_down(val, j + 1);
output += sval << (bits * (j * values_per_reduce + i)); output += sval << (bits * (values_per_reduce + j + i));
} }
} }
} }
if (bits == 3 || bits == 6) { if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
if (in_index % packs_per_int == 0 && out_index % bytes_per_pack == 0) { out[out_index / writes_per_reduce] = output;
out[out_index] = output & 0xff; }
out[out_index + 1] = (output & 0xff00) >> 8; }
out[out_index + 2] = (output & 0xff0000) >> 16;
} template <typename T, const int group_size, const int bits>
} else { [[kernel]] void affine_quantize_scales_biases(
if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) { const device T* w [[buffer(0)]],
out[out_index / writes_per_reduce] = output; const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
device uint8_t* out [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
constexpr int uint8_bits = 8;
constexpr int packs_per_int = uint8_bits / bits;
constexpr T n_bins = (1 << bits) - 1;
size_t offset = index.x + grid_dim.x * size_t(index.y);
size_t in_index = offset * packs_per_int;
size_t gindex = in_index / group_size;
T scale = scales[gindex];
T bias = biases[gindex];
uint8_t output = 0;
#pragma clang loop unroll(full)
for (int i = 0; i < packs_per_int; i++) {
uint8_t val = min(round((w[in_index + i] - bias) / scale), n_bins);
if (bits == 8) {
output = val;
} else {
output += val << (bits * i);
} }
} }
out[offset] = output;
} }
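The `output += val << (bits * i)` accumulation in the quantize kernels above packs packs_per_int codes into a single byte, lowest code in the lowest bits. A quick worked example for 4-bit codes:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int bits = 4;
  constexpr int packs_per_int = 8 / bits;         // 2 codes per byte
  const uint8_t codes[packs_per_int] = {0x3, 0xA};
  uint8_t output = 0;
  for (int i = 0; i < packs_per_int; i++) {
    output += codes[i] << (bits * i);             // low nibble first, then high
  }
  std::printf("packed byte = 0x%02X\n", output);  // prints 0xA3
  return 0;
}
```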
template <typename T, const int group_size, const int bits> template <typename T, const int group_size, const int bits>
@@ -2102,48 +1915,26 @@ template <typename T, const int group_size, const int bits>
device T* out [[buffer(3)]], device T* out [[buffer(3)]],
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; constexpr int uint8_bits = 8;
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int packs_per_int = uint8_bits / bits;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
size_t offset = index.x + grid_dim.x * size_t(index.y); size_t offset = index.x + grid_dim.x * size_t(index.y);
size_t oindex = offset * packs_per_int; size_t oindex = offset * packs_per_int;
size_t gindex = oindex / group_size; size_t gindex = oindex / group_size;
T scale = scales[gindex]; T scale = scales[gindex];
T bias = biases[gindex]; T bias = biases[gindex];
uint val = w[offset];
out += oindex;
if (bits == 3) {
w += offset * bytes_per_pack;
out[0] = (w[0] & 0x7) * scale + bias;
out[1] = ((w[0] & 0x38) >> 3) * scale + bias;
out[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
out[3] = ((w[1] & 0xe) >> 1) * scale + bias;
out[4] = ((w[1] & 0x70) >> 4) * scale + bias;
out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
} else if (bits == 6) {
w += offset * bytes_per_pack;
out[0] = (w[0] & 0x3f) * scale + bias;
out[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
out[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
out[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
} else {
uint val = w[offset];
#pragma clang loop unroll(full) #pragma clang loop unroll(full)
for (int i = 0; i < packs_per_int; i++) { for (int i = 0; i < packs_per_int; i++) {
uint8_t d; uint8_t d;
if (bits == 2) { if (bits == 2) {
d = (val >> (bits * i)) & 0x03; d = (val >> (bits * i)) & 0x03;
} else if (bits == 4) { } else if (bits == 4) {
d = (val >> (bits * i)) & 0x0f; d = (val >> (bits * i)) & 0x0f;
} else if (bits == 8) { } else if (bits == 8) {
d = val; d = val;
}
out[i] = scale * d + bias;
} }
out[oindex + i] = scale * d + bias;
} }
} }
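Taken together, the quantize and dequantize kernels implement a per-group affine code: q = min(round((w - bias) / scale), n_bins) going in, w ≈ scale * q + bias coming out. The sketch below also derives scale and bias from the group min and max, which is an assumption for illustration only; the kernels above receive them precomputed, and the epsilon and edge-case handling is omitted here:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int bits = 4;
  const float n_bins = (1 << bits) - 1;
  const std::vector<float> w = {-0.9f, -0.1f, 0.2f, 0.7f};  // one tiny "group"

  const float lo = *std::min_element(w.begin(), w.end());
  const float hi = *std::max_element(w.begin(), w.end());
  const float scale = (hi - lo) / n_bins;  // assumed derivation, see note above
  const float bias = lo;

  for (float v : w) {
    const float q = std::min(std::round((v - bias) / scale), n_bins);  // quantize
    const float r = scale * q + bias;                                  // dequantize
    std::printf("w=% .3f  q=%2.0f  reconstructed=% .3f\n", v, q, r);
  }
  return 0;
}
```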

View File

@@ -72,6 +72,7 @@
#define instantiate_quantized_all_single(type, group_size, bits) \ #define instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized(affine_quantize, type, group_size, bits) \ instantiate_quantized(affine_quantize, type, group_size, bits) \
instantiate_quantized(affine_quantize_scales_biases, type, group_size, bits) \
instantiate_quantized(affine_dequantize, type, group_size, bits) \ instantiate_quantized(affine_dequantize, type, group_size, bits) \
instantiate_quantized(bs_qmv_fast, type, group_size, bits) \ instantiate_quantized(bs_qmv_fast, type, group_size, bits) \
instantiate_quantized(bs_qmv, type, group_size, bits) \ instantiate_quantized(bs_qmv, type, group_size, bits) \
@@ -115,9 +116,7 @@
#define instantiate_quantized_all() \ #define instantiate_quantized_all() \
instantiate_quantized_groups(2) \ instantiate_quantized_groups(2) \
instantiate_quantized_groups(3) \
instantiate_quantized_groups(4) \ instantiate_quantized_groups(4) \
instantiate_quantized_groups(6) \
instantiate_quantized_groups(8) instantiate_quantized_groups(8)
instantiate_quantized_all() // clang-format on instantiate_quantized_all() // clang-format on

View File

@@ -10,156 +10,186 @@
#include "mlx/backend/metal/kernels/reduction/ops.h" #include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduce.h" #include "mlx/backend/metal/kernels/reduce.h"
#define instantiate_init_reduce(name, tname, type, op) \ #define instantiate_reduce_helper_floats(inst_f, name, op) \
instantiate_kernel("init_reduce_" #name #tname, init_reduce, type, op<type>) inst_f(name, float16, half, op) \
inst_f(name, float32, float, op) \
inst_f(name, bfloat16, bfloat16_t, op)
instantiate_init_reduce(and, bool_, bool, And) #define instantiate_reduce_helper_uints(inst_f, name, op) \
instantiate_init_reduce(or, bool_, bool, Or) inst_f(name, uint8, uint8_t, op) \
inst_f(name, uint16, uint16_t, op) \
inst_f(name, uint32, uint32_t, op)
#define instantiate_init_sum_prod(name, op) \ #define instantiate_reduce_helper_ints(inst_f, name, op) \
instantiate_init_reduce(name, int32, int32_t, op) \ inst_f(name, int8, int8_t, op) \
instantiate_init_reduce(name, int64, int64_t, op) \ inst_f(name, int16, int16_t, op) \
instantiate_init_reduce(name, float16, float16_t, op) \ inst_f(name, int32, int32_t, op)
instantiate_init_reduce(name, bfloat16, bfloat16_t, op) \
instantiate_init_reduce(name, float32, float, op) \
instantiate_init_reduce(name, complex64, complex64_t, op)
instantiate_init_sum_prod(sum, Sum) #define instantiate_reduce_helper_64b(inst_f, name, op) \
instantiate_init_sum_prod(prod, Prod) inst_f(name, int64, int64_t, op) \
inst_f(name, uint64, uint64_t, op) \
inst_f(name, complex64, complex64_t, op)
#define instantiate_init_min_max(name, op) \ #define instantiate_reduce_helper_types(inst_f, name, op) \
instantiate_init_reduce(name, bool_, bool, op) \ instantiate_reduce_helper_floats(inst_f, name, op) \
instantiate_init_reduce(name, int8, int8_t, op) \ instantiate_reduce_helper_uints(inst_f, name, op) \
instantiate_init_reduce(name, int16, int16_t, op) \ instantiate_reduce_helper_ints(inst_f, name, op)
instantiate_init_reduce(name, int32, int32_t, op) \
instantiate_init_reduce(name, int64, int64_t, op) \
instantiate_init_reduce(name, uint8, uint8_t, op) \
instantiate_init_reduce(name, uint16, uint16_t, op) \
instantiate_init_reduce(name, uint32, uint32_t, op) \
instantiate_init_reduce(name, uint64, uint64_t, op) \
instantiate_init_reduce(name, float16, float16_t, op) \
instantiate_init_reduce(name, bfloat16, bfloat16_t, op) \
instantiate_init_reduce(name, float32, float, op) \
instantiate_init_reduce(name, complex64, complex64_t, op)
instantiate_init_min_max(min, Min) #define instantiate_reduce_ops(inst_f, type_f) \
instantiate_init_min_max(max, Max) type_f(inst_f, sum, Sum) \
type_f(inst_f, prod, Prod) \
type_f(inst_f, min, Min) \
type_f(inst_f, max, Max)
// Special case for bool reductions
#define instantiate_reduce_from_types_helper( \
inst_f, name, tname, itype, otype, op) \
inst_f(name##tname, itype, otype, op)
#define instantiate_reduce_from_types(inst_f, name, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, bool_, bool, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint8, uint8_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint16, uint16_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint32, uint32_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint64, uint64_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int8, int8_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int16, int16_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int32, int32_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int64, int64_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, float16, half, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, \
name, \
float32, \
float, \
otype, \
op) \
instantiate_reduce_from_types_helper( \
inst_f, \
name, \
bfloat16, \
bfloat16_t, \
otype, \
op)
#define instantiate_init_reduce(name, otype, op) \
instantiate_kernel("init_reduce_" #name, \
init_reduce, \
otype, op)
#define instantiate_init_reduce_helper(name, tname, type, op) \
instantiate_init_reduce(name##tname, type, op<type>)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)
instantiate_init_reduce(andbool_, bool, And<bool>)
instantiate_init_reduce(orbool_, bool, Or<bool>)
#define instantiate_all_reduce(name, itype, otype, op) \ #define instantiate_all_reduce(name, itype, otype, op) \
instantiate_kernel("all_reduce_" #name, \ instantiate_kernel("all_reduce_" #name, \
all_reduce, \ all_reduce, \
itype, otype, op) itype, otype, op)
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \ #define instantiate_same_all_reduce_helper(name, tname, type, op) \
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \ instantiate_all_reduce(name##tname, type, type, op<type>)
col_reduce_small, \
itype, otype, op, uint, dim) \
instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
col_reduce_longcolumn, \
itype, otype, op, uint, dim) \
instantiate_kernel("col_reduce_small_large_" #dim "_reduce_" #name, \
col_reduce_small, \
itype, otype, op, size_t, dim) \
instantiate_kernel("col_reduce_longcolumn_large_" #dim "_reduce_" #name, \
col_reduce_longcolumn, \
itype, otype, op, size_t, dim)
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \ instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \ instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_64b)
col_reduce_looped, \
itype, otype, op, uint, dim, bm, bn) \
instantiate_kernel("col_reduce_looped_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_looped, \
itype, otype, op, size_t, dim, bm, bn)
#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \ instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And<bool>)
instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \ instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)
col_reduce_2pass, \
itype, otype, op, uint, dim, bm, bn) \ // special case bool with larger output type
instantiate_kernel("col_reduce_2pass_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \ instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
col_reduce_2pass, \
itype, otype, op, size_t, dim, bm, bn) #define instantiate_col_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
col_reduce_small, \
itype, otype, op, dim) \
instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
col_reduce_longcolumn, \
itype, otype, op, dim)
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_looped, \
itype, otype, op, dim, bm, bn)
#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_2pass, \
itype, otype, op, dim, bm, bn)
#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \ #define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) \ instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) \
instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, 32, 32) instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, 32, 32)
#define instantiate_col_reduce_general(name, itype, otype, op) \ #define instantiate_col_reduce_general(name, itype, otype, op) \
instantiate_col_reduce_small(name, itype, otype, op, 0) \
instantiate_col_reduce_small(name, itype, otype, op, 1) \ instantiate_col_reduce_small(name, itype, otype, op, 1) \
instantiate_col_reduce_small(name, itype, otype, op, 2) \ instantiate_col_reduce_small(name, itype, otype, op, 2) \
instantiate_col_reduce_small(name, itype, otype, op, 5) \ instantiate_col_reduce_small(name, itype, otype, op, 3) \
instantiate_col_reduce_small(name, itype, otype, op, 4) \
instantiate_col_reduce_looped(name, itype, otype, op, 0) \
instantiate_col_reduce_looped(name, itype, otype, op, 1) \ instantiate_col_reduce_looped(name, itype, otype, op, 1) \
instantiate_col_reduce_looped(name, itype, otype, op, 2) \ instantiate_col_reduce_looped(name, itype, otype, op, 2) \
instantiate_col_reduce_looped(name, itype, otype, op, 5) instantiate_col_reduce_looped(name, itype, otype, op, 3) \
instantiate_col_reduce_looped(name, itype, otype, op, 4)
#define instantiate_row_reduce_small(name, itype, otype, op, dim) \ #define instantiate_same_col_reduce_helper(name, tname, type, op) \
instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \ instantiate_col_reduce_general(name##tname, type, type, op<type>)
row_reduce_small, \
itype, otype, op, uint, dim) \
instantiate_kernel("row_reduce_small_large_" #dim "_reduce_" #name, \
row_reduce_small, \
itype, otype, op, size_t, dim)
#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \ instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \ instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_64b)
row_reduce_looped, \
itype, otype, op, uint, dim) \ instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_kernel("row_reduce_looped_large_" #dim "_reduce_" #name, \ instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And<bool>)
row_reduce_looped, \ instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or<bool>)
itype, otype, op, size_t, dim)
#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \
row_reduce_small, \
itype, otype, op, dim)
#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \
row_reduce_looped, \
itype, otype, op, dim)
#define instantiate_row_reduce_general(name, itype, otype, op) \ #define instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op, 0) \
instantiate_row_reduce_small(name, itype, otype, op, 1) \ instantiate_row_reduce_small(name, itype, otype, op, 1) \
instantiate_row_reduce_small(name, itype, otype, op, 2) \ instantiate_row_reduce_small(name, itype, otype, op, 2) \
instantiate_row_reduce_small(name, itype, otype, op, 5) \ instantiate_row_reduce_small(name, itype, otype, op, 3) \
instantiate_row_reduce_small(name, itype, otype, op, 4) \
instantiate_row_reduce_looped(name, itype, otype, op, 0) \
instantiate_row_reduce_looped(name, itype, otype, op, 1) \ instantiate_row_reduce_looped(name, itype, otype, op, 1) \
instantiate_row_reduce_looped(name, itype, otype, op, 2) \ instantiate_row_reduce_looped(name, itype, otype, op, 2) \
instantiate_row_reduce_looped(name, itype, otype, op, 5) \ instantiate_row_reduce_looped(name, itype, otype, op, 3) \
instantiate_row_reduce_looped(name, itype, otype, op, 4) \
instantiate_kernel("row_reduce_simple_" #name, \ instantiate_kernel("row_reduce_simple_" #name, \
row_reduce_simple, \ row_reduce_simple, \
itype, otype, op) itype, otype, op)
#define instantiate_reduce_functions(name, tname, itype, otype, op) \ #define instantiate_same_row_reduce_helper(name, tname, type, op) \
instantiate_all_reduce(name##tname, itype, otype, op<otype>) \ instantiate_row_reduce_general(name##tname, type, type, op<type>)
instantiate_row_reduce_general(name##tname, itype, otype, op<otype>) \
instantiate_col_reduce_general(name##tname, itype, otype, op<otype>)
#define instantiate_and_or(name, op) \ instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_functions(name, bool_, bool, bool, op) \ instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_64b)
instantiate_reduce_functions(name, int16, int16_t, bool, op) \
instantiate_reduce_functions(name, int32, int32_t, bool, op) \
instantiate_reduce_functions(name, int64, int64_t, bool, op)
instantiate_and_or(and, And) instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And<bool>)
instantiate_and_or(or, Or) instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or<bool>)
#define instantiate_sum_prod(name, op) \ instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_functions(name, int8, int8_t, int32_t, op) \
instantiate_reduce_functions(name, int16, int16_t, int32_t, op) \
instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \
instantiate_reduce_functions(name, int64, int64_t, int64_t, op) \
instantiate_reduce_functions(name, float16, float16_t, float16_t, op) \
instantiate_reduce_functions(name, bfloat16, bfloat16_t, bfloat16_t, op) \
instantiate_reduce_functions(name, float32, float, float, op) \
instantiate_reduce_functions(name, complex64, complex64_t, complex64_t, op)
instantiate_sum_prod(sum, Sum)
instantiate_sum_prod(prod, Prod)
#define instantiate_min_max(name, op) \
instantiate_reduce_functions(name, int8, int8_t, int8_t, op) \
instantiate_reduce_functions(name, int16, int16_t, int16_t, op) \
instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \
instantiate_reduce_functions(name, int64, int64_t, int64_t, op) \
instantiate_reduce_functions(name, uint8, uint8_t, uint8_t, op) \
instantiate_reduce_functions(name, uint16, uint16_t, uint16_t, op) \
instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op) \
instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op) \
instantiate_reduce_functions(name, float16, float16_t, float16_t, op) \
instantiate_reduce_functions(name, bfloat16, bfloat16_t, bfloat16_t, op) \
instantiate_reduce_functions(name, float32, float, float, op) \
instantiate_reduce_functions(name, complex64, complex64_t, complex64_t, op)
instantiate_min_max(min, Min)
instantiate_min_max(max, Max)
// clang-format on // clang-format on
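Every instantiation above plugs an op type into the same small interface the kernels use: a static `init` identity plus a binary call operator (see `U total = Op::init` and `op(static_cast<U>(in[i]), total)` in the all_reduce kernel further down). A stand-in for that interface, not the MLX definitions:

```cpp
#include <limits>

template <typename U>
struct Sum {
  static constexpr U init = U(0);
  U operator()(U a, U b) const { return a + b; }
};

template <typename U>
struct Max {
  static constexpr U init = std::numeric_limits<U>::lowest();
  U operator()(U a, U b) const { return a > b ? a : b; }
};

// Same accumulation shape as the row/all reduce kernels, minus the threading.
template <typename T, typename U, typename Op>
U reduce_row(const T* in, int n) {
  Op op;
  U total = Op::init;
  for (int i = 0; i < n; i++) {
    total = op(static_cast<U>(in[i]), total);
  }
  return total;  // e.g. reduce_row<float, float, Sum<float>>(data, 8)
}
```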

View File

@@ -1,11 +1,6 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
template < template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
typename T,
typename U,
typename Op,
typename IdxT = int64_t,
int N_READS = REDUCE_N_READS>
[[kernel]] void all_reduce( [[kernel]] void all_reduce(
const device T* in [[buffer(0)]], const device T* in [[buffer(0)]],
device U* out [[buffer(1)]], device U* out [[buffer(1)]],
@@ -21,10 +16,10 @@ template <
threadgroup U shared_vals[simd_size]; threadgroup U shared_vals[simd_size];
U total = Op::init; U total = Op::init;
IdxT start_idx = gid.y * IdxT(row_size); int64_t start_idx = gid.y * row_size;
IdxT actual_row = int64_t actual_row =
(start_idx + row_size <= in_size) ? row_size : in_size - start_idx; (start_idx + row_size <= in_size) ? row_size : in_size - start_idx;
IdxT blocks = actual_row / (lsize.x * N_READS); int64_t blocks = actual_row / (lsize.x * N_READS);
int extra = actual_row - blocks * (lsize.x * N_READS); int extra = actual_row - blocks * (lsize.x * N_READS);
extra -= lid.x * N_READS; extra -= lid.x * N_READS;
start_idx += lid.x * N_READS; start_idx += lid.x * N_READS;
@@ -35,7 +30,7 @@ template <
extra = 0; extra = 0;
} }
for (IdxT b = 0; b < blocks; b++) { for (int64_t b = 0; b < blocks; b++) {
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
total = op(static_cast<U>(in[i]), total); total = op(static_cast<U>(in[i]), total);
} }
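The start_idx / blocks / extra arithmetic above divides one row among lsize.x threads, each reading N_READS elements per pass; whatever does not fill a whole pass becomes `extra` and is trimmed per thread. Worked numbers (sizes are illustrative):

```cpp
#include <cstdio>

int main() {
  const int row_size = 1000;  // elements assigned to this threadgroup
  const int lsize_x = 32;     // threads along x
  const int N_READS = 4;      // elements per thread per pass

  const int stride = lsize_x * N_READS;          // 128 elements per pass
  const int blocks = row_size / stride;          // 7 full passes
  const int extra = row_size - blocks * stride;  // 104 leftover elements
  std::printf("stride=%d blocks=%d extra=%d\n", stride, blocks, extra);
  return 0;
}
```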

View File

@@ -1,6 +1,6 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
template <typename T, typename U, typename Op, typename IdxT, int NDIMS> template <typename T, typename U, typename Op, int NDIMS>
[[kernel]] void col_reduce_small( [[kernel]] void col_reduce_small(
const device T* in [[buffer(0)]], const device T* in [[buffer(0)]],
device U* out [[buffer(1)]], device U* out [[buffer(1)]],
@@ -19,7 +19,7 @@ template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
uint3 lsize [[threads_per_threadgroup]]) { uint3 lsize [[threads_per_threadgroup]]) {
constexpr int n_reads = 4; constexpr int n_reads = 4;
Op op; Op op;
LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim); looped_elem_to_loc<NDIMS> loop;
const device T* row; const device T* row;
U totals[n_reads]; U totals[n_reads];
@@ -27,20 +27,20 @@ template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
totals[i] = Op::init; totals[i] = Op::init;
} }
IdxT column = IdxT(gid.x) * lsize.x * n_reads + lid.x * n_reads; size_t column = size_t(gid.x) * lsize.x * n_reads + lid.x * n_reads;
if (column >= reduction_stride) { if (column >= reduction_stride) {
return; return;
} }
bool safe = column + n_reads <= reduction_stride; bool safe = column + n_reads <= reduction_stride;
IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); size_t out_idx = gid.y + gsize.y * size_t(gid.z);
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim); size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
in += in_idx + column; in += in_idx + column;
IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size); size_t total_rows = non_col_reductions * reduction_size;
loop.next(lid.y, reduce_shape, reduce_strides); loop.next(lid.y, reduce_shape, reduce_strides);
for (IdxT r = lid.y; r < total_rows; r += lsize.y) { for (size_t r = lid.y; r < total_rows; r += lsize.y) {
row = in + loop.location(); row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
if (safe) { if (safe) {
for (int i = 0; i < n_reads; i++) { for (int i = 0; i < n_reads; i++) {
totals[i] = op(static_cast<U>(row[i]), totals[i]); totals[i] = op(static_cast<U>(row[i]), totals[i]);
@@ -80,7 +80,7 @@ template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
} }
if (lid.y == 0) { if (lid.y == 0) {
out += out_idx * IdxT(reduction_stride) + column; out += out_idx * reduction_stride + column;
if (safe) { if (safe) {
for (int i = 0; i < n_reads; i++) { for (int i = 0; i < n_reads; i++) {
out[i] = totals[i]; out[i] = totals[i];
@@ -93,7 +93,7 @@ template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
} }
} }
template <typename T, typename U, typename Op, typename IdxT, int NDIMS> template <typename T, typename U, typename Op, int NDIMS>
[[kernel]] void col_reduce_longcolumn( [[kernel]] void col_reduce_longcolumn(
const device T* in [[buffer(0)]], const device T* in [[buffer(0)]],
device U* out [[buffer(1)]], device U* out [[buffer(1)]],
@@ -112,19 +112,19 @@ template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
uint3 lid [[thread_position_in_threadgroup]], uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]) { uint3 lsize [[threads_per_threadgroup]]) {
Op op; Op op;
LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim); looped_elem_to_loc<NDIMS> loop;
const device T* row; const device T* row;
IdxT out_idx = gid.x + gsize.x * IdxT(gid.y); size_t out_idx = gid.x + gsize.x * size_t(gid.y);
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim); size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
in += in_idx + lid.x; in += in_idx + lid.x;
U total = Op::init; U total = Op::init;
IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size); size_t total_rows = non_col_reductions * reduction_size;
loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides); loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides);
for (IdxT r = gid.z * lsize.y + lid.y; r < total_rows; for (size_t r = gid.z * lsize.y + lid.y; r < total_rows;
r += lsize.y * gsize.z) { r += lsize.y * gsize.z) {
row = in + loop.location(); row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
total = op(static_cast<U>(*row), total); total = op(static_cast<U>(*row), total);
loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides); loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides);
} }
@@ -136,8 +136,7 @@ template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
for (uint i = 1; i < lsize.y; i++) { for (uint i = 1; i < lsize.y; i++) {
total = op(total, shared_vals[i * lsize.x + lid.x]); total = op(total, shared_vals[i * lsize.x + lid.x]);
} }
out[gid.z * IdxT(out_size) + out_idx * IdxT(reduction_stride) + lid.x] = out[gid.z * out_size + out_idx * reduction_stride + lid.x] = total;
total;
} }
} }
@@ -152,14 +151,7 @@ template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
* totals with a loop. * totals with a loop.
* 7. Write them to the output * 7. Write them to the output
*/ */
template < template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
typename T,
typename U,
typename Op,
typename IdxT,
int NDIMS,
int BM,
int BN>
[[kernel]] void col_reduce_looped( [[kernel]] void col_reduce_looped(
const device T* in [[buffer(0)]], const device T* in [[buffer(0)]],
device U* out [[buffer(1)]], device U* out [[buffer(1)]],
@@ -184,7 +176,7 @@ template <
threadgroup U shared_vals[BN * BM]; threadgroup U shared_vals[BN * BM];
U totals[n_reads]; U totals[n_reads];
LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim); looped_elem_to_loc<NDIMS> loop;
const device T* row; const device T* row;
for (int i = 0; i < n_reads; i++) { for (int i = 0; i < n_reads; i++) {
@@ -193,17 +185,17 @@ template <
short lid = simd_group_id * simd_size + simd_lane_id; short lid = simd_group_id * simd_size + simd_lane_id;
short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks); short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
IdxT column = BN * gid.x + offset.x; size_t column = BN * gid.x + offset.x;
bool safe = column + n_reads <= reduction_stride; bool safe = column + n_reads <= reduction_stride;
IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); size_t out_idx = gid.y + gsize.y * size_t(gid.z);
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim); size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
in += in_idx + column; in += in_idx + column;
IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size); size_t total = non_col_reductions * reduction_size;
loop.next(offset.y, reduce_shape, reduce_strides); loop.next(offset.y, reduce_shape, reduce_strides);
for (IdxT r = offset.y; r < total; r += BM) { for (size_t r = offset.y; r < total; r += BM) {
row = in + loop.location(); row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
if (safe) { if (safe) {
for (int i = 0; i < n_reads; i++) { for (int i = 0; i < n_reads; i++) {
@@ -243,8 +235,8 @@ template <
// Write the output. // Write the output.
if (simd_lane_id == 0) { if (simd_lane_id == 0) {
IdxT out_column = BN * gid.x + out_offset.x; size_t out_column = BN * gid.x + out_offset.x;
out += out_idx * IdxT(reduction_stride) + out_column; out += out_idx * reduction_stride + out_column;
if (out_column + n_outputs <= reduction_stride) { if (out_column + n_outputs <= reduction_stride) {
for (int i = 0; i < n_outputs; i++) { for (int i = 0; i < n_outputs; i++) {
out[i] = totals[i]; out[i] = totals[i];
@@ -277,7 +269,7 @@ template <
// Write the output. // Write the output.
if (offset.y == 0) { if (offset.y == 0) {
out += out_idx * IdxT(reduction_stride) + column; out += out_idx * reduction_stride + column;
if (safe) { if (safe) {
for (int i = 0; i < n_reads; i++) { for (int i = 0; i < n_reads; i++) {
out[i] = totals[i]; out[i] = totals[i];
@@ -291,14 +283,7 @@ template <
} }
} }
template < template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
typename T,
typename U,
typename Op,
typename IdxT,
int NDIMS,
int BM,
int BN>
[[kernel]] void col_reduce_2pass( [[kernel]] void col_reduce_2pass(
const device T* in [[buffer(0)]], const device T* in [[buffer(0)]],
device U* out [[buffer(1)]], device U* out [[buffer(1)]],
@@ -327,7 +312,7 @@ template <
threadgroup U shared_vals[BN * BM]; threadgroup U shared_vals[BN * BM];
U totals[n_reads]; U totals[n_reads];
LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim); looped_elem_to_loc<NDIMS> loop;
const device T* row; const device T* row;
for (int i = 0; i < n_reads; i++) { for (int i = 0; i < n_reads; i++) {
@@ -336,19 +321,20 @@ template <
short lid = simd_group_id * simd_size + simd_lane_id; short lid = simd_group_id * simd_size + simd_lane_id;
short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks); short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
IdxT column = BN * gid.x + offset.x; size_t column = BN * gid.x + offset.x;
bool safe = column + n_reads <= reduction_stride; bool safe = column + n_reads <= reduction_stride;
IdxT full_idx = gid.y + gsize.y * IdxT(gid.z); size_t full_idx = gid.y + gsize.y * size_t(gid.z);
IdxT block_idx = full_idx / IdxT(out_size); size_t block_idx = full_idx / out_size;
IdxT out_idx = full_idx % IdxT(out_size); size_t out_idx = full_idx % out_size;
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim); size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
in += in_idx + column; in += in_idx + column;
IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size); size_t total = non_col_reductions * reduction_size;
loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides); loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides);
for (IdxT r = offset.y + block_idx * BM; r < total; r += outer_blocks * BM) { for (size_t r = offset.y + block_idx * BM; r < total; r += outer_blocks * BM) {
row = in + loop.location(); row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
if (safe) { if (safe) {
for (int i = 0; i < n_reads; i++) { for (int i = 0; i < n_reads; i++) {
@@ -383,8 +369,8 @@ template <
// Write the output. // Write the output.
if (simd_lane_id == 0) { if (simd_lane_id == 0) {
IdxT out_column = BN * gid.x + out_offset.x; size_t out_column = BN * gid.x + out_offset.x;
out += full_idx * IdxT(reduction_stride) + out_column; out += full_idx * reduction_stride + out_column;
if (out_column + n_outputs <= reduction_stride) { if (out_column + n_outputs <= reduction_stride) {
for (int i = 0; i < n_outputs; i++) { for (int i = 0; i < n_outputs; i++) {
out[i] = totals[i]; out[i] = totals[i];
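For context on the index arithmetic both columns rely on: elem_to_loc converts a flat, row-major element index into a strided memory offset, and the templated IdxT on the left lets the kernel pick a narrower offset type when the arrays are small. A minimal host-side sketch (illustrative only, not the Metal source; the name elem_to_loc_sketch is ours):
template <typename StrideT, typename IdxT = StrideT>
IdxT elem_to_loc_sketch(IdxT elem, const int* shape, const StrideT* strides, int ndim) {
  // Peel off the fastest-moving dimension first, multiplying each
  // coordinate by its stride to build the memory offset.
  IdxT loc = 0;
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += static_cast<IdxT>(elem % shape[i]) * static_cast<IdxT>(strides[i]);
    elem /= shape[i];
  }
  return loc;
}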

View File

@@ -193,7 +193,6 @@ template <
typename T, typename T,
typename U, typename U,
typename Op, typename Op,
typename IdxT,
int NDIMS, int NDIMS,
int N_READS = REDUCE_N_READS> int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_small( [[kernel]] void row_reduce_small(
@@ -215,20 +214,20 @@ template <
Op op; Op op;
U total_val = Op::init; U total_val = Op::init;
LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim); looped_elem_to_loc<NDIMS> loop;
// Precompute some row reduction numbers // Precompute some row reduction numbers
const device T* row; const device T* row;
int blocks = IdxT(row_size) / N_READS; int blocks = row_size / N_READS;
int extra = IdxT(row_size) % N_READS; int extra = row_size % N_READS;
if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) { if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {
// Simple loop over non_row_reductions and reduce the row in the thread. // Simple loop over non_row_reductions and reduce the row in the thread.
IdxT out_idx = tid.x + tsize.y * IdxT(tid.y); size_t out_idx = tid.x + tsize.y * size_t(tid.y);
in += elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim); in += elem_to_loc(out_idx, shape, strides, ndim);
for (uint r = 0; r < non_row_reductions; r++) { for (uint r = 0; r < non_row_reductions; r++) {
row = in + loop.location(); row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra); thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
loop.next(reduce_shape, reduce_strides); loop.next(reduce_shape, reduce_strides);
} }
@@ -237,13 +236,13 @@ template <
} else { } else {
// Collaboratively reduce over non_row_reductions in the simdgroup. Each // Collaboratively reduce over non_row_reductions in the simdgroup. Each
// thread reduces every 32nd row and then a simple simd reduce. // thread reduces every 32nd row and then a simple simd reduce.
IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); size_t out_idx = gid.y + gsize.y * size_t(gid.z);
in += elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim); in += elem_to_loc(out_idx, shape, strides, ndim);
loop.next(simd_lane_id, reduce_shape, reduce_strides); loop.next(simd_lane_id, reduce_shape, reduce_strides);
for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) { for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) {
row = in + loop.location(); row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra); thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
loop.next(simd_size, reduce_shape, reduce_strides); loop.next(simd_size, reduce_shape, reduce_strides);
} }
@@ -260,7 +259,6 @@ template <
typename T, typename T,
typename U, typename U,
typename Op, typename Op,
typename IdxT = size_t,
int N_READS = REDUCE_N_READS, int N_READS = REDUCE_N_READS,
int N_WRITES = REDUCE_N_WRITES> int N_WRITES = REDUCE_N_WRITES>
[[kernel]] void row_reduce_simple( [[kernel]] void row_reduce_simple(
@@ -279,15 +277,15 @@ template <
U totals[N_WRITES]; U totals[N_WRITES];
// Move to the row // Move to the row
IdxT out_idx = N_WRITES * (gid.y + gsize.y * IdxT(gid.z)); size_t out_idx = N_WRITES * (gid.y + gsize.y * size_t(gid.z));
if (out_idx + N_WRITES > out_size) { if (out_idx + N_WRITES > out_size) {
out_idx = out_size - N_WRITES; out_idx = out_size - N_WRITES;
} }
in += out_idx * IdxT(reduction_size); in += out_idx * reduction_size;
out += out_idx; out += out_idx;
// Each thread reduces across the row // Each thread reduces across the row
int blocks = IdxT(reduction_size) / (lsize.x * N_READS); int blocks = reduction_size / (lsize.x * N_READS);
int extra = reduction_size - blocks * (lsize.x * N_READS); int extra = reduction_size - blocks * (lsize.x * N_READS);
per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>( per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
totals, in, reduction_size, blocks, extra, lsize.x, lid.x); totals, in, reduction_size, blocks, extra, lsize.x, lid.x);
@@ -308,7 +306,6 @@ template <
typename T, typename T,
typename U, typename U,
typename Op, typename Op,
typename IdxT,
int NDIMS, int NDIMS,
int N_READS = REDUCE_N_READS> int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_looped( [[kernel]] void row_reduce_looped(
@@ -333,20 +330,19 @@ template <
threadgroup U shared_vals[simd_size]; threadgroup U shared_vals[simd_size];
U total = Op::init; U total = Op::init;
IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); size_t out_idx = gid.y + gsize.y * size_t(gid.z);
// lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it // lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it
// needs a small refactor. // needs a small refactor.
in += elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim) + lid.x * N_READS; in += elem_to_loc(out_idx, shape, strides, ndim) + lid.x * N_READS;
LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim); looped_elem_to_loc<NDIMS> loop;
const device T* row; const device T* row;
int blocks = IdxT(row_size) / (lsize.x * N_READS); int blocks = row_size / (lsize.x * N_READS);
int extra = row_size - blocks * (lsize.x * N_READS); int extra = row_size - blocks * (lsize.x * N_READS);
for (IdxT i = 0; i < non_row_reductions; i++) { for (size_t i = 0; i < non_row_reductions; i++) {
row = in + loop.location(); row = in + loop.location(i, reduce_shape, reduce_strides, reduce_ndim);
// Each thread reduces across the row // Each thread reduces across the row
U row_total; U row_total;

View File

@@ -3,6 +3,8 @@
#include <metal_common> #include <metal_common>
#include <metal_simdgroup> #include <metal_simdgroup>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
using namespace metal; using namespace metal;
@@ -15,15 +17,12 @@ template <typename T, int N_READS = RMS_N_READS>
constant float& eps, constant float& eps,
constant uint& axis_size, constant uint& axis_size,
constant uint& w_stride, constant uint& w_stride,
threadgroup float* local_inv_mean [[threadgroup(0)]],
threadgroup float* local_sums [[threadgroup(1)]],
uint gid [[threadgroup_position_in_grid]], uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]], uint lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) { uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
constexpr int SIMD_SIZE = 32;
threadgroup float local_inv_mean[1];
threadgroup float local_sums[SIMD_SIZE];
float acc = 0; float acc = 0;
x += gid * size_t(axis_size) + lid * N_READS; x += gid * size_t(axis_size) + lid * N_READS;
w += w_stride * lid * N_READS; w += w_stride * lid * N_READS;
@@ -85,15 +84,13 @@ template <typename T, int N_READS = RMS_N_READS>
constant float& eps, constant float& eps,
constant uint& axis_size, constant uint& axis_size,
constant uint& w_stride, constant uint& w_stride,
threadgroup float* local_inv_mean [[threadgroup(0)]],
threadgroup float* local_sums [[threadgroup(1)]],
uint gid [[threadgroup_position_in_grid]], uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]], uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]], uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) { uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
constexpr int SIMD_SIZE = 32;
threadgroup float local_inv_mean[1];
threadgroup float local_sums[SIMD_SIZE];
float acc = 0; float acc = 0;
x += gid * size_t(axis_size) + lid * N_READS; x += gid * size_t(axis_size) + lid * N_READS;
w += w_stride * lid * N_READS; w += w_stride * lid * N_READS;
@@ -379,6 +376,8 @@ template <typename T, int N_READS = RMS_N_READS>
constant float& eps, \ constant float& eps, \
constant uint& axis_size, \ constant uint& axis_size, \
constant uint& w_stride, \ constant uint& w_stride, \
threadgroup float* local_inv_mean [[threadgroup(0)]], \
threadgroup float* local_sums [[threadgroup(1)]], \
uint gid [[thread_position_in_grid]], \ uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \ uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \ uint simd_lane_id [[thread_index_in_simdgroup]], \
@@ -408,6 +407,8 @@ template <typename T, int N_READS = RMS_N_READS>
constant float& eps, \ constant float& eps, \
constant uint& axis_size, \ constant uint& axis_size, \
constant uint& w_stride, \ constant uint& w_stride, \
threadgroup float* local_inv_mean [[threadgroup(0)]], \
threadgroup float* local_sums [[threadgroup(1)]], \
uint gid [[thread_position_in_grid]], \ uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \ uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \ uint lsize [[threads_per_threadgroup]], \
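For reference, the quantity these rms_norm kernels accumulate (a sum of squares followed by a reciprocal root) corresponds to the usual RMSNorm definition; writing axis_size as d and the scale vector as w (read with stride w_stride):
y_i = w_i \, \frac{x_i}{\sqrt{\tfrac{1}{d} \sum_{j=1}^{d} x_j^2 + \epsilon}}
where \epsilon is the eps constant-buffer argument above.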

View File

@@ -2,6 +2,7 @@
#include <metal_math> #include <metal_math>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
template <typename T, bool traditional, bool forward> template <typename T, bool traditional, bool forward>
void rope_single_impl( void rope_single_impl(

View File

@@ -1,16 +1,946 @@
#include <metal_stdlib> #include <metal_stdlib>
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
#include "mlx/backend/metal/kernels/sdpa_vector.h" #include "mlx/backend/metal/kernels/sdpa_vector.h"
#include "mlx/backend/metal/kernels/steel/defines.h"
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
using namespace metal; using namespace metal;
using namespace mlx::steel;
template <
typename T,
short BROWS,
short BCOLS,
short dst_ld,
short reduction_dim,
short tgp_size,
short alignment = 1,
short n_reads = (BCOLS * BROWS) / (tgp_size),
short TCOLS = BCOLS / n_reads,
short TROWS = tgp_size / TCOLS>
struct BlockLoaderFA {
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
STEEL_CONST short vec_size = n_reads;
// Leading dimension for src
const int src_ld;
const int tile_stride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
struct alignas(alignment * sizeof(T)) ReadVector {
uint8_t v[sizeof(T) * vec_size];
};
/* Constructor */
METAL_FUNC BlockLoaderFA(
const device T* src_,
const int src_ld_,
threadgroup T* dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj) {}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
*((threadgroup ReadVector*)(&dst[i * dst_ld])) =
*((const device ReadVector*)(&src[i * src_ld]));
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(short2 src_tile_dim) const {
src_tile_dim = src_tile_dim - short2(bj, bi);
// Skip loading if thread has no valid reads
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
return;
}
// Use fast thread memory for bound checks
bool tmp_idx[vec_size];
T tmp_val[vec_size];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
// Make sure tmp_idx only contains valid indices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
}
// Read valid indices into tmp_val
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out unneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
}
// Copy values to threadgroup memory
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = tmp_val[j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
src += tile_stride;
}
METAL_FUNC void next(short n) {
src += n * tile_stride;
}
};
template <bool M_aligned, bool N_aligned, bool K_aligned>
struct LoopAlignment {};
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
short lda_tgp,
short ldb_tgp,
typename AccumType = float,
typename Epilogue = TransformNone<U, AccumType>>
struct BlockMMAFA {
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TM_stride = 8 * WM;
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TN_stride = 8 * WN;
// Warp tile size along M
STEEL_CONST short TM = BM / TM_stride;
// Warp tile size along N
STEEL_CONST short TN = BN / TN_stride;
// Strides of A, B along reduction axis
STEEL_CONST short simd_stride_a = {
transpose_a ? TM_stride : TM_stride * lda_tgp};
STEEL_CONST short simd_stride_b = {
transpose_b ? TN_stride * ldb_tgp : TN_stride};
// Jump between elements
STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1};
STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1};
STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8};
STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp};
// Simdgroup matrices
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
simdgroup_matrix<AccumType, 8, 8>(0)};
// Offsets within threadgroup
const short tm;
const short tn;
short sm;
short sn;
ushort sid;
ushort slid;
short As_offset;
short Bs_offset;
/* Constructor */
METAL_FUNC BlockMMAFA(
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
// Determine thread position in simdgroup matrix
short qid = simd_lane_id / 4;
slid = simd_lane_id;
sid = simd_group_id;
sm = (qid & 4) + (simd_lane_id / 2) % 4;
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
// Determine thread and simdgroup offset
As_offset =
transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp);
Bs_offset =
transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn));
}
/* (BM, BK) X (BK, BN) multiply accumulate function */
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
// Adjust for simdgroup and thread location
As += As_offset;
Bs += Bs_offset;
// Iterate over BK in blocks of 8
STEEL_PRAGMA_UNROLL
for (short kk = 0; kk < BK; kk += 8) {
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup A as simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
Asimd[i].thread_elements()[0] =
static_cast<AccumType>(As[i * simd_stride_a + 0]);
Asimd[i].thread_elements()[1] =
static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
}
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup B as simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
Bsimd[j].thread_elements()[0] =
static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
Bsimd[j].thread_elements()[1] =
static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
}
simdgroup_barrier(mem_flags::mem_none);
// Multiply and accumulate into result simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
short j_serp = (i % 2) ? (TN - 1 - j) : j;
simdgroup_multiply_accumulate(
results[i * TN + j_serp],
Asimd[i],
Bsimd[j_serp],
results[i * TN + j_serp]);
}
}
// Progress to next simdgroup tile
As += tile_stride_a;
Bs += tile_stride_b;
}
}
METAL_FUNC void rescale_output(const threadgroup float* Corrections) {
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
short row = sm + tm + i * TM_stride;
float scale_value = Corrections[row];
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto& accum = results[i * TN + j].thread_elements();
// int offset = (i * TM_stride) * ldc + (j * TN_stride);
accum[0] *= scale_value;
accum[1] *= scale_value;
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(device U* C, const int ldc) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + tn + sn;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldc + (j * TN_stride);
// Apply epilogue
U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
// Write out C
C[offset] = outs[0];
C[offset + 1] = outs[1];
}
}
}
METAL_FUNC void store_result_to_tgp_memory(
threadgroup U* C,
const int ldc,
short2 dst_tile_dims) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn);
dst_tile_dims -= short2(tn + sn, sm + tm);
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldc + (j * TN_stride);
// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
C[offset] = Epilogue::apply(accum[0]);
}
if (j * TN_stride + 1 < dst_tile_dims.x) {
C[offset + 1] = Epilogue::apply(accum[1]);
}
}
}
}
}
METAL_FUNC void
store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn);
dst_tile_dims -= short2(tn + sn, sm + tm);
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldc + (j * TN_stride);
// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
C[offset] = Epilogue::apply(accum[0]);
}
if (j * TN_stride + 1 < dst_tile_dims.x) {
C[offset + 1] = Epilogue::apply(accum[1]);
}
}
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
D += (sm + tm) * ldd + tn + sn;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
U outs[2] = {
epilogue_op.apply(accum[0], C[offset_c]),
epilogue_op.apply(accum[1], C[offset_c + fdc])};
// Write out D
D[offset_d] = outs[0];
D[offset_d + 1] = outs[1];
}
}
}
METAL_FUNC void store_result_safe(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
short2 dst_tile_dims,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
D += (sm + tm) * ldd + tn + sn;
dst_tile_dims -= short2(tn + sn, sm + tm);
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
}
if (j * TN_stride + 1 < dst_tile_dims.x) {
D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
}
}
}
}
}
METAL_FUNC void clear_results() {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
results[i * TN + j] = simdgroup_matrix<AccumType, 8, 8>(0);
}
}
}
};
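A quick sanity check on the tile bookkeeping in BlockMMAFA: for the QK^T use with the sizes instantiated further down (BM = BN = 16, WM = WN = 2), TM_stride = 8 * WM = 16 and TN_stride = 8 * WN = 16, so TM = TN = 1 and each of the four simdgroups accumulates exactly one 8x8 results tile of the 16x16 score block.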
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_q,
bool transpose_k,
bool transpose_v,
bool MN_aligned,
bool K_aligned,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<U, AccumType>>
struct FastAttentionKernel {
STEEL_CONST short tgp_padding = 16 / sizeof(T);
STEEL_CONST short float_padding = 16 / sizeof(float);
STEEL_CONST short tgp_mem_size_q =
transpose_q ? BK * (BM + tgp_padding) : BM * (BK + tgp_padding);
STEEL_CONST short tgp_mem_size_k =
transpose_k ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding);
STEEL_CONST short tgp_mem_size_v =
transpose_v ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding);
STEEL_CONST short tgp_mem_size_s = BM * (BN + tgp_padding);
// maxes, rowsums, rescale
STEEL_CONST short tgp_mem_size_corrections =
4 * (BM * sizeof(float) + float_padding);
STEEL_CONST bool share_kv_smem = transpose_k != transpose_v;
STEEL_CONST short tgp_mem_size = share_kv_smem
? tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s +
tgp_mem_size_corrections
: tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s +
tgp_mem_size_corrections + tgp_mem_size_v;
STEEL_CONST short tgp_size = WM * WN * 32;
static_assert(transpose_q == false, "Expected Q not transposed.");
static_assert(transpose_k == true, "Expected K transposed.");
static_assert(transpose_v == false, "Expected V not transposed.");
static_assert(tgp_mem_size <= 32768, "Excessive tgp memory requested.");
using loader_q_t = BlockLoaderFA<
T,
transpose_q ? BK : BM,
transpose_q ? BM : BK,
transpose_q ? BM + tgp_padding : BK + tgp_padding,
!transpose_q,
tgp_size>;
using loader_k_t = BlockLoaderFA<
T,
transpose_k ? BN : BK,
transpose_k ? BK : BN,
transpose_k ? BK + tgp_padding : BN + tgp_padding,
transpose_k,
tgp_size>;
using loader_v_t = BlockLoaderFA<
T,
transpose_v ? BK : BN,
transpose_v ? BN : BK,
transpose_v ? BN + tgp_padding : BK + tgp_padding,
transpose_v,
tgp_size>;
using mma_qk_t = BlockMMAFA<
T,
U,
BM,
BN,
BK,
WM,
WN,
transpose_q,
transpose_k,
transpose_q ? BM + tgp_padding : BK + tgp_padding,
transpose_k ? BK + tgp_padding : BN + tgp_padding,
AccumType,
Epilogue>;
using mma_sv_t = BlockMMAFA<
T,
U,
BM,
BK,
BN,
WM,
WN,
false,
transpose_v,
BN + tgp_padding,
BK + tgp_padding,
AccumType,
Epilogue>;
/* Main kernel function */
template <bool M_aligned, bool N_aligned, bool K_aligned_>
static METAL_FUNC void gemm_loop(
threadgroup T* As [[threadgroup(0)]],
threadgroup T* Bs [[threadgroup(1)]],
const int gemm_k_iterations,
thread loader_k_t& loader_b,
thread mma_qk_t& mma_op,
thread const short& tgp_bm,
thread const short& tgp_bn,
LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
// Appease the compiler
(void)l;
(void)tgp_bm;
short2 tile_dims_B = transpose_k ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
// not valid for gemm_k_iterations > 1 (so, BK == d_k)
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (N_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(tile_dims_B);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
}
}
static METAL_FUNC void initialize_corrections(
threadgroup float* C,
uint simd_lane_id,
uint simd_group_id) {
if (simd_group_id == 0) {
threadgroup float* maxes = C;
threadgroup float* sums = C + (BM + float_padding);
threadgroup float* o_rescale = sums + (BM + float_padding);
threadgroup float* output_rescale = o_rescale + (BM + float_padding);
if (simd_lane_id < BM) {
maxes[simd_lane_id] = -INFINITY; // m_i
sums[simd_lane_id] = 0.f; // l_i
o_rescale[simd_lane_id] = 1.f; // li * exp(mi - mi_new)
output_rescale[simd_lane_id] = 1.f; // 1.0 / l_i
}
}
}
static METAL_FUNC void rescale_ss(
threadgroup T* Ss,
threadgroup float* Corrections,
uint simd_group_id,
uint simd_lane_id,
short2 local_blocks,
float alpha) {
if (simd_group_id == 0) {
short row_offset = BM + float_padding;
threadgroup float* maxes = Corrections;
threadgroup float* sums = Corrections + row_offset;
threadgroup float* o_rescale = sums + row_offset;
threadgroup float* output_scales = o_rescale + row_offset;
if (simd_lane_id < uint(local_blocks.y)) {
float m_i_old = maxes[simd_lane_id];
float l_i_old = sums[simd_lane_id];
float m_i_new = m_i_old;
float l_i_new = l_i_old;
short offset = simd_lane_id * (BN + tgp_padding);
float m_ij = -INFINITY;
for (short j = 0; j < local_blocks.x; j++) {
float val = alpha * float(Ss[offset + j]);
m_ij = max(m_ij, val);
}
m_i_new = max(m_ij, m_i_new);
float rowsum = 0.f; // lij
for (short j = 0; j < local_blocks.x; j++) {
float val = alpha * float(Ss[offset + j]);
float P_i_j = exp(val - m_ij);
rowsum += P_i_j;
P_i_j = P_i_j * exp(m_ij - m_i_new);
Ss[offset + j] = T(P_i_j);
}
l_i_new =
exp(m_i_old - m_i_new) * l_i_old + exp(m_ij - m_i_new) * rowsum;
maxes[simd_lane_id] = m_i_new;
sums[simd_lane_id] = l_i_new;
float rescale = l_i_old * exp(m_i_old - m_i_new);
o_rescale[simd_lane_id] = rescale;
output_scales[simd_lane_id] = 1.0 / l_i_new;
}
}
}
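In equations, rescale_ss maintains the running row statistics of an online softmax. With S the score tile and \alpha the scale, each row i is updated as
m_i^{\mathrm{new}} = \max\big(m_i^{\mathrm{old}},\ \max_j \alpha S_{ij}\big), \qquad \ell_i^{\mathrm{new}} = e^{m_i^{\mathrm{old}} - m_i^{\mathrm{new}}}\,\ell_i^{\mathrm{old}} + \sum_j e^{\alpha S_{ij} - m_i^{\mathrm{new}}},
the stored probabilities become \tilde{P}_{ij} = e^{\alpha S_{ij} - m_i^{\mathrm{new}}}, and the two correction rows hold the output rescale factor \ell_i^{\mathrm{old}} e^{m_i^{\mathrm{old}} - m_i^{\mathrm{new}}} and the final scale 1/\ell_i^{\mathrm{new}}.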
/* Main kernel function */
static METAL_FUNC void run(
const device T* Q [[buffer(0)]],
const device T* K [[buffer(1)]],
const device T* V [[buffer(2)]],
device U* O [[buffer(3)]],
const constant MLXFastAttentionParams* params [[buffer(4)]],
threadgroup T* Qs [[threadgroup(0)]],
threadgroup T* Ks [[threadgroup(1)]],
threadgroup T* Ss [[threadgroup(2)]],
threadgroup T* Vs [[threadgroup(3)]],
threadgroup float* Corrections [[threadgroup(4)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Pacifying compiler
(void)lid;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
threadgroup_barrier(mem_flags::mem_none);
// Find block in Q, O; and head in K, V.
const int c_row = tid_y * BM;
Q += transpose_q ? c_row : c_row * params->ldq;
thread loader_q_t loader_q(Q, params->ldq, Qs, simd_group_id, simd_lane_id);
short tgp_bm = min(BM, params->M - c_row);
short2 tile_dims_Q = transpose_q ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
loader_q.load_safe(tile_dims_Q);
initialize_corrections(Corrections, simd_lane_id, simd_group_id);
O += c_row * params->ldo;
// Prepare threadgroup mma operation
thread mma_qk_t mma_qk_op(simd_group_id, simd_lane_id);
thread mma_sv_t mma_softmax_sv_op(simd_group_id, simd_lane_id);
thread loader_k_t loader_k(K, params->ldk, Ks, simd_group_id, simd_lane_id);
thread loader_v_t loader_v(V, params->ldv, Vs, simd_group_id, simd_lane_id);
for (short n_block = 0; n_block < params->gemm_n_iterations_aligned;
n_block++) {
short c_col = BN;
// Prepare threadgroup loading operations
short gemm_k_iterations = params->gemm_k_iterations_aligned;
short tgp_bn_qk = min(BN, params->N - c_col * n_block);
threadgroup_barrier(mem_flags::mem_none);
///////////////////////////////////////////////////////////////////////////////
{ // Loop over K - unaligned case
if (tgp_bm == BM && tgp_bn_qk == BN) {
gemm_loop<true, true, K_aligned>(
Qs,
Ks,
gemm_k_iterations,
loader_k,
mma_qk_op,
tgp_bm,
tgp_bn_qk);
} else if (tgp_bn_qk == BN) {
gemm_loop<false, true, K_aligned>(
Qs,
Ks,
gemm_k_iterations,
loader_k,
mma_qk_op,
tgp_bm,
tgp_bn_qk);
} else if (tgp_bm == BM) {
gemm_loop<true, false, K_aligned>(
Qs,
Ks,
gemm_k_iterations,
loader_k,
mma_qk_op,
tgp_bm,
tgp_bn_qk);
} else {
gemm_loop<false, false, K_aligned>(
Qs,
Ks,
gemm_k_iterations,
loader_k,
mma_qk_op,
tgp_bm,
tgp_bn_qk);
}
}
mma_qk_op.store_result_to_tgp_memory(
Ss, BN + tgp_padding, short2(BN, BM));
threadgroup_barrier(mem_flags::mem_threadgroup);
rescale_ss(
Ss,
Corrections,
simd_group_id,
simd_lane_id,
short2(tgp_bn_qk, tgp_bm),
params->alpha);
loader_v.load_safe(short2(BK, tgp_bn_qk));
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup float* o_scales = Corrections + 2 * (BM + float_padding);
mma_softmax_sv_op.rescale_output(o_scales);
mma_softmax_sv_op.mma(Ss, Vs);
threadgroup float* final_output_scales =
Corrections + 3 * (BM + float_padding);
mma_softmax_sv_op.rescale_output(final_output_scales);
loader_v.next();
loader_k.next(BN);
mma_qk_op.clear_results();
}
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_softmax_sv_op.store_result_safe(O, params->ldo, short2(BK, tgp_bm));
}
};
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_q,
bool transpose_k,
bool transpose_v,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void attention(
const device T* Q [[buffer(0)]],
const device T* K [[buffer(1)]],
const device T* V [[buffer(2)]],
device T* O [[buffer(3)]],
const constant MLXFastAttentionParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using attention_kernel = FastAttentionKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_q,
transpose_k,
transpose_v,
MN_aligned,
K_aligned>;
// Adjust for batch
if (params->batch_ndim > 1) {
const constant size_t* Q_bstrides = batch_strides;
const constant size_t* KV_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, Q_bstrides, KV_bstrides, params->batch_ndim);
Q += batch_offsets.x;
K += batch_offsets.y;
V += batch_offsets.y;
} else {
Q += params->batch_stride_q * tid.z;
K += params->batch_stride_k * tid.z;
V += params->batch_stride_v * tid.z;
}
// same shape as input
O += params->batch_stride_o * tid.z;
threadgroup T Qs[attention_kernel::tgp_mem_size_q];
threadgroup T Ss[attention_kernel::tgp_mem_size_s];
threadgroup float Corrections[attention_kernel::tgp_mem_size_corrections];
if (attention_kernel::share_kv_smem) {
threadgroup T Ks[attention_kernel::tgp_mem_size_k];
threadgroup T* Vs = Ks; //[attention_kernel::tgp_mem_size_v];
attention_kernel::run(
Q,
K,
V,
O,
params,
Qs,
Ks,
Ss,
Vs,
Corrections,
simd_lane_id,
simd_group_id,
tid,
lid);
} else {
threadgroup T Ks[attention_kernel::tgp_mem_size_k];
threadgroup T Vs[attention_kernel::tgp_mem_size_v];
attention_kernel::run(
Q,
K,
V,
O,
params,
Qs,
Ks,
Ss,
Vs,
Corrections,
simd_lane_id,
simd_group_id,
tid,
lid);
}
}
// clang-format off // clang-format off
// SDPA full instantiations
#define instantiate_fast_inference_self_attention_kernel( \
itype, otype, bm, bn, bk, wm, wn) \
template [[host_name("steel_gemm_attention_bm_" #bm "_bn_" #bn "_bk_" #bk \
"_itype_" #itype)]] [[kernel]] void \
attention<itype, bm, bn, bk, wm, wn, false, true, false, false, true>( \
const device itype* Q [[buffer(0)]], \
const device itype* K [[buffer(1)]], \
const device itype* V [[buffer(2)]], \
device otype* O [[buffer(3)]], \
const constant MLXFastAttentionParams* params [[buffer(4)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
instantiate_fast_inference_self_attention_kernel(
float,
float,
16,
16,
64,
2,
2);
instantiate_fast_inference_self_attention_kernel(
float,
float,
16,
16,
128,
2,
2);
instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 64, 2, 2);
instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2);
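For reference, the macro above registers each variant under a host-visible name of the form steel_gemm_attention_bm_<bm>_bn_<bn>_bk_<bk>_itype_<itype>; the first instantiation, for instance, becomes steel_gemm_attention_bm_16_bn_16_bk_64_itype_float.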
// SDPA vector instantiations // SDPA vector instantiations
#define instantiate_sdpa_vector(type, head_dim) \ #define instantiate_sdpa_vector(type, head_dim) \
instantiate_kernel("sdpa_vector_" #type "_" #head_dim, sdpa_vector, type, head_dim) \ template [[host_name("sdpa_vector_" #type "_" #head_dim)]] \
instantiate_kernel("sdpa_vector_2pass_1_" #type "_" #head_dim, sdpa_vector_2pass_1, type, head_dim) \ [[kernel]] void sdpa_vector<type, head_dim>( \
instantiate_kernel("sdpa_vector_2pass_2_" #type "_" #head_dim, sdpa_vector_2pass_2, type, head_dim) const device type* queries [[buffer(0)]], \
const device type* keys [[buffer(1)]], \
const device type* values [[buffer(2)]], \
device type* out [[buffer(3)]], \
const constant int& gqa_factor, \
const constant int& N, \
const constant size_t& k_stride, \
const constant size_t& v_stride, \
const constant float& scale, \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_sdpa_vector_heads(type) \ #define instantiate_sdpa_vector_heads(type) \
instantiate_sdpa_vector(type, 64) \ instantiate_sdpa_vector(type, 64) \

View File

@@ -0,0 +1,42 @@
//
// scaled_dot_product_attention_params.h
// mlx
#pragma once
struct MLXFastAttentionParams {
const int M;
const int N;
const int K;
const int ldq; // ldq == ldo
const int ldk;
const int ldv;
const int lds;
const int ldo;
const int tiles_n;
const int tiles_m;
const int batch_stride_q;
const int batch_stride_k;
const int batch_stride_v;
const int batch_stride_o;
const int swizzle_log;
const int gemm_n_iterations_aligned;
const int gemm_k_iterations_aligned;
const int gemm_sv_m_block_iterations;
const int batch_ndim;
const float alpha;
};
struct MLXScaledDotProductAttentionParams {
// Associated dimensions & transposition information
const uint QUERY_SEQUENCE_LENGTH = 1;
const uint N_Q_HEADS = 32;
const uint N_KV_HEADS = 32;
const uint KV_TILES = 1;
const float INV_ALPHA = 0.08838834764831843f;
};
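Note that the default INV_ALPHA above equals 1/\sqrt{128} \approx 0.0883883476, i.e. the conventional attention scaling 1/\sqrt{d_k} for a head dimension of 128.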

View File

@@ -10,8 +10,7 @@ template <
typename Op, typename Op,
int NIDX, int NIDX,
bool UPD_ROW_CONTIG, bool UPD_ROW_CONTIG,
int NWORK, int NWORK>
typename LocT>
METAL_FUNC void scatter_impl( METAL_FUNC void scatter_impl(
const device T* updates, const device T* updates,
device mlx_atomic<T>* out, device mlx_atomic<T>* out,
@@ -29,31 +28,29 @@ METAL_FUNC void scatter_impl(
Op op; Op op;
auto ind_idx = gid.y * NWORK; auto ind_idx = gid.y * NWORK;
LocT out_offset = 0; size_t out_offset = 0;
if (upd_size > 1) { if (upd_size > 1) {
out_offset = elem_to_loc<size_t, LocT>(gid.x, upd_shape + indices.ndim, out_strides, out_ndim); out_offset = elem_to_loc(gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
} }
for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) { for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) {
LocT out_idx = out_offset; size_t out_idx = out_offset;
for (int i = 0; i < NIDX; ++i) { for (int i = 0; i < NIDX; ++i) {
auto idx_loc = indices.row_contiguous[i] auto idx_loc = indices.row_contiguous[i]
? ind_idx ? ind_idx
: elem_to_loc<size_t, LocT>( : elem_to_loc(
ind_idx, ind_idx,
&indices.shapes[indices.ndim * i], &indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i], &indices.strides[indices.ndim * i],
indices.ndim); indices.ndim);
auto ax = axes[i]; auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]); auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
out_idx += static_cast<LocT>(idx_val) * static_cast<LocT>(out_strides[ax]); out_idx += idx_val * out_strides[ax];
} }
auto upd_idx = ind_idx * static_cast<LocT>(upd_size) + gid.x; auto upd_idx = ind_idx * upd_size + gid.x;
if constexpr (!UPD_ROW_CONTIG) { if constexpr (!UPD_ROW_CONTIG) {
upd_idx = elem_to_loc<size_t, LocT>(upd_idx, upd_shape, upd_strides, upd_ndim); upd_idx = elem_to_loc(upd_idx, upd_shape, upd_strides, upd_ndim);
} }
op.atomic_update(out, updates[upd_idx], out_idx); op.atomic_update(out, updates[upd_idx], out_idx);
} }

View File

@@ -21,7 +21,8 @@ template <typename T, int D>
constexpr int BN = 32; constexpr int BN = 32;
constexpr int BD = 32; constexpr int BD = 32;
constexpr int elem_per_thread = D / BD; constexpr int elem_per_thread = D / BD;
constexpr int stride = BN * D;
const int stride = BN * D;
typedef float U; typedef float U;
@@ -83,6 +84,7 @@ template <typename T, int D>
keys += stride; keys += stride;
values += stride; values += stride;
} }
threadgroup_barrier(mem_flags::mem_threadgroup);
// Each thread has a partial part of the output so we need to combine them. // Each thread has a partial part of the output so we need to combine them.
@@ -112,181 +114,3 @@ template <typename T, int D>
} }
} }
} }
template <typename T, int D>
[[kernel]] void sdpa_vector_2pass_1(
const device T* queries [[buffer(0)]],
const device T* keys [[buffer(1)]],
const device T* values [[buffer(2)]],
device float* out [[buffer(3)]],
device float* sums [[buffer(4)]],
device float* maxs [[buffer(5)]],
const constant int& gqa_factor,
const constant int& N,
const constant size_t& k_stride,
const constant size_t& v_stride,
const constant float& scale,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int BN = 8;
constexpr int BD = 32;
constexpr int elem_per_thread = D / BD;
constexpr int stride = BN * D;
constexpr int blocks = 32;
typedef float U;
thread U q[elem_per_thread];
thread U k[elem_per_thread];
thread U o[elem_per_thread];
threadgroup U outputs[BN * BD];
threadgroup U max_scores[BN];
threadgroup U sum_exp_scores[BN];
// Adjust positions
const int block_idx = tid.z;
const int head_idx = tid.y;
const int kv_head_idx = head_idx / gqa_factor;
queries += head_idx * D + simd_lid * elem_per_thread;
keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D +
simd_lid * elem_per_thread;
values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D +
simd_lid * elem_per_thread;
out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread;
sums += head_idx * blocks + block_idx;
maxs += head_idx * blocks + block_idx;
// Read the query and zero the output accumulator
for (int i = 0; i < elem_per_thread; i++) {
q[i] = static_cast<U>(scale) * queries[i];
}
for (int i = 0; i < elem_per_thread; i++) {
o[i] = 0;
}
U max_score = -1e9;
U sum_exp_score = 0;
// For each key
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
// Read the key
for (int i = 0; i < elem_per_thread; i++) {
k[i] = keys[i];
}
// Compute the i-th score
U score = 0;
for (int i = 0; i < elem_per_thread; i++) {
score += q[i] * k[i];
}
score = simd_sum(score);
// Update the accumulators
U new_max = max(max_score, score);
U factor = fast::exp(max_score - new_max);
U exp_score = fast::exp(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
// Update the output accumulator
for (int i = 0; i < elem_per_thread; i++) {
o[i] = o[i] * factor + exp_score * values[i];
}
// Move the pointers to the next kv
keys += blocks * stride;
values += blocks * stride;
}
// Each thread has a partial part of the output so we need to combine them.
// First let's communicate the max and sum_exp
if (simd_lid == 0) {
max_scores[simd_gid] = max_score;
sum_exp_scores[simd_gid] = sum_exp_score;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
U new_max = simd_max(max_score);
U factor = fast::exp(max_score - new_max);
sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
sum_exp_score = simd_sum(sum_exp_score * factor);
// Write the sum and new max
if (simd_gid == 0) {
sums[0] = sum_exp_score;
maxs[0] = new_max;
}
// Now we need to aggregate all the outputs
for (int i = 0; i < elem_per_thread; i++) {
outputs[simd_lid * BN + simd_gid] =
o[i] * fast::exp(max_scores[simd_gid] - new_max);
threadgroup_barrier(mem_flags::mem_threadgroup);
// And write the output
if (simd_gid == 0) {
U output = outputs[simd_lid * BN];
for (int j = 1; j < BN; j++) {
output += outputs[simd_lid * BN + j];
}
out[i] = static_cast<T>(output);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
}
template <typename T, int D>
[[kernel]] void sdpa_vector_2pass_2(
const device float* partials [[buffer(0)]],
const device float* sums [[buffer(1)]],
const device float* maxs [[buffer(2)]],
device T* out [[buffer(3)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int BN = 32;
constexpr int BD = 32;
constexpr int elem_per_thread = D / BD;
constexpr int blocks = 32;
typedef float U;
thread U o[elem_per_thread];
threadgroup U outputs[BN * BD];
// Adjust positions
const int head_idx = tid.y;
partials += head_idx * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
sums += head_idx * blocks;
maxs += head_idx * blocks;
out += head_idx * D + simd_gid * elem_per_thread;
// First everybody reads the max and sum_exp
U max_score = maxs[simd_lid];
U new_max = simd_max(max_score);
U factor = fast::exp(max_score - new_max);
U sum_exp_score = simd_sum(sums[simd_lid] * factor);
// Now read the block into registers and then use shared memory to transpose
// it
for (int i = 0; i < elem_per_thread; i++) {
o[i] = partials[i];
}
for (int i = 0; i < elem_per_thread; i++) {
outputs[simd_lid * BD + simd_gid] = o[i];
threadgroup_barrier(mem_flags::mem_threadgroup);
o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// And write the output
if (simd_lid == 0) {
for (int i = 0; i < elem_per_thread; i++) {
out[i] = static_cast<T>(o[i]);
}
}
}
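Putting the two passes together: pass 1 leaves, for every block b of keys, an unnormalized partial output o_b = \sum_{i \in b} e^{s_i - m_b} v_i together with \ell_b = \sum_{i \in b} e^{s_i - m_b} and the block maximum m_b; pass 2 then combines them as
m = \max_b m_b, \qquad \mathrm{out} = \frac{\sum_b e^{m_b - m}\, o_b}{\sum_b e^{m_b - m}\, \ell_b},
which is the numerically stable softmax average over the full key sequence.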

View File

@@ -6,6 +6,8 @@
using namespace metal; using namespace metal;
// clang-format off // clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/softmax.h" #include "mlx/backend/metal/kernels/softmax.h"

View File

@@ -3,6 +3,8 @@
#include <metal_stdlib> #include <metal_stdlib>
// clang-format off // clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/sort.h" #include "mlx/backend/metal/kernels/sort.h"

View File

@@ -1,296 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/attn/loader.h"
#include "mlx/backend/metal/kernels/steel/attn/mma.h"
#include "mlx/backend/metal/kernels/steel/attn/params.h"
#include "mlx/backend/metal/kernels/steel/attn/transforms.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel class
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <bool M_aligned, bool N_aligned, bool K_aligned>
struct LoopAlignment {};
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<U, AccumType>>
struct GEMMKernel {
STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
STEEL_CONST short tgp_padding_b = 16 / sizeof(T);
STEEL_CONST short tgp_mem_size_a =
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
STEEL_CONST short tgp_mem_size_b =
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
STEEL_CONST short tgp_size = WM * WN * 32;
using loader_a_t = BlockLoader<
T,
transpose_a ? BK : BM,
transpose_a ? BM : BK,
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
!transpose_a,
tgp_size>;
using loader_b_t = BlockLoader<
T,
transpose_b ? BN : BK,
transpose_b ? BK : BN,
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
transpose_b,
tgp_size>;
using mma_t = BlockMMA<
T,
U,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
AccumType,
Epilogue>;
/* Main kernel function */
template <bool M_aligned, bool N_aligned, bool K_aligned_>
static METAL_FUNC void gemm_loop(
threadgroup T* As [[threadgroup(0)]],
threadgroup T* Bs [[threadgroup(1)]],
const int gemm_k_iterations,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
thread mma_t& mma_op,
thread const short& tgp_bm,
thread const short& tgp_bn,
thread const short& lbk,
LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
// Appease the compiler
(void)l;
short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
if (M_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(tile_dims_A);
}
if (N_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(tile_dims_B);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
if (!K_aligned_) {
threadgroup_barrier(mem_flags::mem_threadgroup);
short2 tile_dims_A_last =
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
short2 tile_dims_B_last =
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
loader_a.load_safe(tile_dims_A_last);
loader_b.load_safe(tile_dims_B_last);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
}
/* Main kernel function */
static METAL_FUNC void run(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device U* D [[buffer(2)]],
const constant GEMMParams* params [[buffer(3)]],
threadgroup T* As [[threadgroup(0)]],
threadgroup T* Bs [[threadgroup(1)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Pacifying compiler
(void)lid;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Loop tail
if (!K_aligned) {
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
// Store results to device memory
mma_op.store_result(D, params->ldd);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
if (tgp_bm == BM && tgp_bn == BN) {
gemm_loop<true, true, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result(D, params->ldd);
return;
} else if (tgp_bn == BN) {
gemm_loop<false, true, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
return;
} else if (tgp_bm == BM) {
gemm_loop<true, false, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
return;
} else {
gemm_loop<false, false, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
return;
}
}
}
};
} // namespace steel
} // namespace mlx
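As a concrete (purely illustrative) instance of the dispatch above: with BM = BN = 32, BK = 16 and M = 100, N = 64, K = 70, the final row block gets tgp_bm = 100 - 3*32 = 4 while tgp_bn stays 32, gemm_k_iterations_aligned = 70 / 16 = 4, and the K tail of lbk = 70 - 4*16 = 6 is handled by the extra load_safe/mma pass in gemm_loop.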

View File

@@ -1,349 +0,0 @@
// Copyright © 2024 Apple Inc.
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
constant bool align_Q [[function_constant(200)]];
constant bool align_K [[function_constant(201)]];
template <typename T>
struct TransformScale {
T scale;
METAL_FUNC TransformScale(T scale_) : scale(scale_) {}
METAL_FUNC T apply(T x) const {
return scale * x;
}
};
struct MaxOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return metal::max(x, y);
}
};
struct SumOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return x + y;
}
};
struct MulOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return x * y;
}
};
struct SubOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return x - y;
}
};
struct ExpSubOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return fast::exp(x - y);
}
};
struct DivOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return x / y;
}
};
// clang-format off
template <
typename T,
int BQ,
int BK,
int BD,
int WM,
int WN,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention(
const device T* Q [[buffer(0)]],
const device T* K [[buffer(1)]],
const device T* V [[buffer(2)]],
device T* O [[buffer(3)]],
const constant AttnParams* params [[buffer(4)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
// Pacifying compiler
(void)lid;
// Move to correct block
ulong3 tidl{tid.x, tid.y, tid.z};
Q += tidl.z * params->Q_strides[0] + // Batch
tidl.y * params->Q_strides[1] + // Head
tidl.x * BQ * params->Q_strides[2]; // Sequence
ulong kv_head_idx = int(tid.y) / params->gqa_factor;
K += tidl.z * params->K_strides[0] + // Batch
kv_head_idx * params->K_strides[1]; // Head
V += tidl.z * params->V_strides[0] + // Batch
kv_head_idx * params->V_strides[1]; // Head
O += tidl.z * params->O_strides[0] + // Batch
tidl.y * params->O_strides[1] + // Head
tidl.x * BQ * params->O_strides[2]; // Sequence
// Prepare threadgroup memory
constexpr short padQ = 0; // 16 / sizeof(T);
constexpr short padK = 0; // 16 / sizeof(T);
constexpr short padV = 0; // 16 / sizeof(T);
constexpr short LDQ_tgp = BD + padQ;
constexpr short LDK_tgp = BK + padK;
constexpr short LDV_tgp = BD + padV;
threadgroup T Qs[BQ * (BD + padQ)];
threadgroup T Ks[(BK + padK) * BD];
threadgroup T Vs[BK * (BD + padV)];
// Prepare block loaders
using QBlockLoader = BlockLoaderT<
/* typename T = */ T,
/* short BROWS = */ BQ,
/* short BCOLS = */ BD,
/* short kDstStrRow = */ LDQ_tgp,
/* short kDstStrCol = */ 1,
/* short reduction_dim = */ 1,
/* short tgp_size = */ WM * WN * 32>;
// K is loaded in transposed form
using KBlockLoader = BlockLoaderT<
/* typename T = */ T,
/* short BROWS = */ BK,
/* short BCOLS = */ BD,
/* short kDstStrRow = */ 1,
/* short kDstStrCol = */ LDK_tgp,
/* short reduction_dim = */ 0,
/* short tgp_size = */ WM * WN * 32>;
using VBlockLoader = BlockLoaderT<
/* typename T = */ T,
/* short BROWS = */ BK,
/* short BCOLS = */ BD,
/* short kDstStrRow = */ LDV_tgp,
/* short kDstStrCol = */ 1,
/* short reduction_dim = */ 0,
/* short tgp_size = */ WM * WN * 32>;
QBlockLoader loader_q(
Q, params->Q_strides[2], Qs, simd_group_id, simd_lane_id);
KBlockLoader loader_k(
K, params->K_strides[2], Ks, simd_group_id, simd_lane_id);
VBlockLoader loader_v(
V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
TransformScale<T> ts(static_cast<T>(params->scale));
// Prepare MMA tiles
constexpr short kFragSize = 8; // MMAFrag size
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
constexpr int kNWarps = WM * WN;
static_assert(
BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0,
"Each simdgroup must host atleast 1 simdgroup matrix along Q sequence.");
// Q seq frags per warp
constexpr int TQ = BQ / (kNWarps * kFragSize);
// KV sequence frags (all warps load the same frags)
constexpr int TK = BK / kFragSize;
// HeadDim frags (all warps load the same frags)
constexpr int TD = BD / kFragSize;
static_assert(TQ == 1, "Check TQ");
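// Worked example of the tiling above (illustrative sizes, not a fixed
// configuration): with BQ = 32, BK = 16, BD = 128, WM = WN = 2 and
// kFragSize = 8, kNWarps = 4, so TQ = 32 / (4 * 8) = 1, TK = 16 / 8 = 2,
// and TD = 128 / 8 = 16 head-dim fragments per simdgroup.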
MMATile<AccumType, TQ, 1, MMAFrag_acc_t> Qtile;
MMATile<AccumType, 1, TK, MMAFrag_acc_t> Ktile;
MMATile<AccumType, TQ, TK, MMAFrag_acc_t> Stile;
MMATile<AccumType, TK, TD, MMAFrag_acc_t> Vtile;
MMATile<AccumType, TQ, TD, MMAFrag_acc_t> Otile;
Otile.clear();
// Prepare mma tile offsets
const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
const short sm = simd_coord.y;
const short sn = simd_coord.x;
const short tm = kFragSize * TQ * simd_group_id;
const short Qs_offset = (tm + sm) * LDQ_tgp + sn;
const short Ks_offset = sm * LDK_tgp + sn;
const short Vs_offset = sm * LDV_tgp + sn;
constexpr short Qs_tile_stride = kFragSize;
constexpr short Ks_tile_stride = kFragSize * LDK_tgp;
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load Q blocks and apply scale
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
loader_q.load_safe(short2(BD, params->qL - params->NQ_aligned * BQ));
} else {
loader_q.load_unsafe();
}
loader_q.apply_inplace_op(ts);
// Init row reduction variables
constexpr short kRowsPT = decltype(Stile)::kRowsPerThread;
AccumType max_score[kRowsPT];
AccumType sum_score[kRowsPT] = {0};
// Init to -Inf
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = Limits<AccumType>::min;
}
// Loop over KV seq length
for (int kb = 0; kb < params->NK; kb++) {
// Load K block
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!align_K && kb == (params->NK_aligned)) {
loader_k.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
} else {
loader_k.load_unsafe();
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do S = Q @ K.T
Stile.clear();
for (short dd = 0; dd < TD; dd++) {
simdgroup_barrier(mem_flags::mem_none);
Qtile.template load<T, 1, 1, LDQ_tgp, 1>(
&Qs[Qs_offset + dd * Qs_tile_stride]);
Ktile.template load<T, 1, 1, LDK_tgp, 1>(
&Ks[Ks_offset + dd * Ks_tile_stride]);
simdgroup_barrier(mem_flags::mem_none);
tile_matmad(Stile, Qtile, Ktile, Stile);
}
// Mask out positions beyond the KV sequence length
if (!align_K && kb == (params->NK_aligned)) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
const short lim = params->kL - params->NK_aligned * BK;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < stile_t::kTileCols; j++) {
short col_pos = sn + (j * stile_t::kFragCols);
STEEL_PRAGMA_UNROLL
for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
if ((col_pos + jj) >= lim) {
Stile.frag_at(i, j)[jj] = neg_inf;
}
}
}
}
}
simdgroup_barrier(mem_flags::mem_none);
// Load V blocks
if (!align_K && kb == (params->NK_aligned)) {
loader_v.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
} else {
loader_v.load_unsafe();
}
// Do softmax
// Temp variables
AccumType new_max[kRowsPT];
AccumType factor[kRowsPT];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
new_max[i] = max_score[i];
}
// Row max
Stile.template row_reduce<MaxOp>(new_max);
// exp(Si - rowmax(Si))
Stile.template row_bin_op<ExpSubOp>(new_max);
// Factor exp(rowmax(Si) - rowmax(Si-1))
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
factor[i] = fast::exp(max_score[i] - new_max[i]);
}
// Save max for next iteration
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = new_max[i];
}
// Row Sum
AccumType sum_score_tmp[kRowsPT] = {0};
Stile.template row_reduce<SumOp>(sum_score_tmp);
// Update norm
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i];
}
// Update O
Otile.template row_bin_op<MulOp>(factor);
// Load V into registers
threadgroup_barrier(mem_flags::mem_threadgroup);
Vtile.template load<T, 1, 1, LDV_tgp, 1>(&Vs[Vs_offset]);
simdgroup_barrier(mem_flags::mem_none);
// Do O = S @ V
tile_matmad(Otile, Stile, Vtile, Otile);
// Prepare for next iteration
loader_k.next();
loader_v.next();
}
// Normalize output
Otile.template row_bin_op<DivOp>(sum_score);
threadgroup_barrier(mem_flags::mem_none);
// Store results
O += (tm + sm) * params->O_strides[2] + sn;
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
auto dst_tile_dims =
short2(BD - sn, params->qL - BQ * params->NQ_aligned - (tm + sm));
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
Otile.template store_safe<T, 1, 1>(O, params->O_strides[2], dst_tile_dims);
} else {
Otile.template store<T, 1, 1>(O, params->O_strides[2]);
}
}
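The KV loop above is the standard streaming ("flash") softmax: each KV block updates a running row max (max_score) and row sum (sum_score) and rescales the partial output before the single final division. Below is a minimal host-side C++ sketch of that update for one query row, written for illustration only; it assumes precomputed scores and is not part of the MLX sources.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Streaming softmax-attention for a single query row, processed in blocks of BK.
// scores: one value per KV position; values: [kv_len][D].
std::vector<float> streaming_attention_row(
    const std::vector<float>& scores,
    const std::vector<std::vector<float>>& values,
    int BK) {
  const int D = values.empty() ? 0 : int(values[0].size());
  std::vector<float> O(D, 0.0f);           // Otile accumulator
  float m = -INFINITY;                     // running row max  (max_score)
  float l = 0.0f;                          // running row sum  (sum_score)
  for (std::size_t start = 0; start < scores.size(); start += std::size_t(BK)) {
    std::size_t end = std::min(scores.size(), start + std::size_t(BK));
    float m_new = m;
    for (std::size_t j = start; j < end; ++j) {
      m_new = std::max(m_new, scores[j]);  // row_reduce<MaxOp>
    }
    float factor = std::exp(m - m_new);    // exp(rowmax(S_{i-1}) - rowmax(S_i))
    for (float& o : O) o *= factor;        // Otile.row_bin_op<MulOp>(factor)
    float l_blk = 0.0f;
    for (std::size_t j = start; j < end; ++j) {
      float p = std::exp(scores[j] - m_new);                 // row_bin_op<ExpSubOp>
      l_blk += p;                                            // row_reduce<SumOp>
      for (int d = 0; d < D; ++d) O[d] += p * values[j][d];  // O += S @ V
    }
    l = l * factor + l_blk;                // sum_score update
    m = m_new;                             // save max for next block
  }
  for (float& o : O) o /= l;               // final row_bin_op<DivOp>
  return O;
}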

View File

@@ -1,31 +0,0 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/attn/attn.h"
#include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h"
#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn) \
template [[host_name("steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd "_wm" #wm "_wn" #wn)]] \
[[kernel]] void attention<dtype, bq, bk, bd, wm, wn, float>( \
const device dtype* Q [[buffer(0)]], \
const device dtype* K [[buffer(1)]], \
const device dtype* V [[buffer(2)]], \
device dtype* O [[buffer(3)]],\
const constant AttnParams* params [[buffer(4)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_attn_shapes_helper(iname, itype) \
instantiate_attn(iname, itype, 32, 16, 128, 4, 1) \
instantiate_attn(iname, itype, 32, 32, 80, 4, 1) \
instantiate_attn(iname, itype, 32, 32, 64, 4, 1)
instantiate_attn_shapes_helper(float16, half);
instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
instantiate_attn_shapes_helper(float32, float);
// clang-format on
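The instantiation macro above registers host-visible kernel names of the form steel_attention_<tname>_bq<bq>_bk<bk>_bd<bd>_wm<wm>_wn<wn>. As a rough sketch (illustrative only; the actual MLX host-side dispatch is not shown in this diff), a caller could assemble such a name like this:

#include <string>

// Builds the [[host_name]] string produced by instantiate_attn for one configuration.
// Purely illustrative; this helper is an assumption, not the MLX host API.
std::string steel_attention_kernel_name(
    const std::string& tname, int bq, int bk, int bd, int wm, int wn) {
  return "steel_attention_" + tname + "_bq" + std::to_string(bq) +
      "_bk" + std::to_string(bk) + "_bd" + std::to_string(bd) +
      "_wm" + std::to_string(wm) + "_wn" + std::to_string(wn);
}

// steel_attention_kernel_name("float16", 32, 16, 128, 4, 1)
//   == "steel_attention_float16_bq32_bk16_bd128_wm4_wn1"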

View File

@@ -1,264 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/defines.h"
///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <
typename T,
short BROWS,
short BCOLS,
short dst_ld,
short reduction_dim,
short tgp_size,
short alignment = 1,
short n_reads = (BCOLS * BROWS) / (tgp_size),
short TCOLS = BCOLS / n_reads,
short TROWS = tgp_size / TCOLS>
struct BlockLoader {
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
STEEL_CONST short vec_size = n_reads;
// Leading dimension for src
const int src_ld;
const int tile_stride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
struct alignas(alignment * sizeof(T)) ReadVector {
uint8_t v[sizeof(T) * vec_size];
};
/* Constructor */
METAL_FUNC BlockLoader(
const device T* src_,
const int src_ld_,
threadgroup T* dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj) {}
/* Apply operation to threadgroup without bound checking */
template <typename UnaryOp>
METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]);
}
}
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
*((threadgroup ReadVector*)(&dst[i * dst_ld])) =
*((const device ReadVector*)(&src[i * src_ld]));
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(short2 src_tile_dim) const {
src_tile_dim = src_tile_dim - short2(bj, bi);
// Skip loading if thread has no valid reads
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
return;
}
// Use fast thread memory for bound checks
bool tmp_idx[vec_size];
T tmp_val[vec_size];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
// Make sure tmp_idx only contains valid indices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
}
// Read valid indices into tmp_val
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out unneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
}
// Copy values to threadgroup memory
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = tmp_val[j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
src += tile_stride;
}
};
template <int R, int C>
struct CShape {
STEEL_CONST int kRows = R;
STEEL_CONST int kCols = C;
};
template <
typename T,
short BROWS,
short BCOLS,
short kDstStrRow,
short kDstStrCol,
short reduction_dim,
short tgp_size,
short n_reads = (BCOLS * BROWS) / (tgp_size),
short TCOLS = BCOLS / n_reads,
short TROWS = tgp_size / TCOLS>
struct BlockLoaderT {
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
STEEL_CONST short vec_size = n_reads;
// Leading dimension for src
const int src_ld;
const int tile_stride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
/* Constructor */
METAL_FUNC BlockLoaderT(
const device T* src_,
const int src_ld_,
threadgroup T* dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * kDstStrRow + bj * kDstStrCol),
src(src_ + bi * src_ld + bj) {}
/* Apply operation to threadgroup without bound checking */
template <typename UnaryOp>
METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] =
op.apply(dst[i * kDstStrRow + j * kDstStrCol]);
}
}
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] = src[i * src_ld + j];
}
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(short2 src_tile_dim) const {
src_tile_dim = src_tile_dim - short2(bj, bi);
// Skip loading if thread has no valid reads
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] = T(0);
}
}
return;
}
// Use fast thread memory for bound checks
bool tmp_idx[vec_size];
T tmp_val[vec_size];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
// Make sure tmp_idx only contains valid indices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
}
// Read valid indices into tmp_val
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out unneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
}
// Copy values to threadgroup memory
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] = tmp_val[j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
src += tile_stride;
}
};
} // namespace steel
} // namespace mlx
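Both loaders spread the BROWS x BCOLS tile over the threadgroup the same way: each thread copies vec_size contiguous elements per row, landing at row bi = thread_idx / TCOLS and column bj = vec_size * (thread_idx % TCOLS). The small host-side demo below reproduces that mapping for the K/V loader shape used by the bd128 attention kernel; the concrete numbers are illustrative assumptions.

#include <cstdio>

int main() {
  const int BROWS = 16, BCOLS = 128, tgp_size = 4 * 1 * 32;  // WM * WN * 32
  const int n_reads = (BCOLS * BROWS) / tgp_size;            // vec_size = 16
  const int TCOLS = BCOLS / n_reads;                         // 8 threads cover one row
  const int TROWS = tgp_size / TCOLS;                        // 16 rows covered per pass
  const int samples[] = {0, 1, 7, 8, 127};
  for (int t : samples) {
    int bi = t / TCOLS;                // destination row in the tile
    int bj = n_reads * (t % TCOLS);    // first destination column
    std::printf("thread %3d -> row %2d, cols [%3d, %3d)\n", t, bi, bj, bj + n_reads);
  }
  std::printf("rows per thread: %d\n", (BROWS + TROWS - 1) / TROWS);  // n_rows == 1
  return 0;
}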

View File

@@ -1,726 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/attn/transforms.h"
#include "mlx/backend/metal/kernels/steel/defines.h"
#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <typename RInt, typename CInt>
struct Shape2D {
RInt r;
CInt c;
Shape2D(RInt r_, CInt c_) : r(r_), c(c_) {}
};
template <typename Shape, typename Layout>
struct Layout2D {
Shape shape;
Layout layout;
};
template <typename T, int kFragRows_, int kFragCols_>
struct BaseMMAFrag {
static_assert(
kFragRows_ == 8,
"Only 8 x 8 fragment matrices are currently supported");
static_assert(
kFragCols_ == 8,
"Only 8 x 8 fragment matrices are currently supported");
};
template <typename T>
struct BaseMMAFrag<T, 8, 8> {
STEEL_CONST int kFragRows = 8;
STEEL_CONST int kFragCols = 8;
STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32;
STEEL_CONST int kElemRows = 1;
STEEL_CONST int kElemCols = 2;
static_assert(
kElemRows * kElemCols == kElemsPerFrag,
"MMAFrag shape is not consistent with MMAFrag size");
typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
typedef metal::vec<T, kElemsPerFrag> frag_type;
typedef metal::vec<T, kElemRows> row_frag_type;
typedef metal::vec<T, kElemCols> col_frag_type;
METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
[[thread_index_in_simdgroup]]) {
const short qid = simd_lane_id / 4;
const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
return short2{fn, fm};
}
template <typename SrcPtrType, typename StrX, typename StrY>
METAL_FUNC static constexpr void
load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
dst[i * kElemCols + j] = static_cast<T>(src[i * str_x + j * str_y]);
}
}
}
template <
typename SrcPtrType,
typename StrX,
typename StrY,
typename LimX,
typename LimY,
typename OffX,
typename OffY>
METAL_FUNC static constexpr void load_safe(
thread frag_type& dst,
SrcPtrType src,
StrX str_x,
StrY str_y,
LimX lim_x,
LimY lim_y,
OffX off_x = Int<0>{},
OffY off_y = Int<0>{}) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
dst[i * kElemCols + j] =
static_cast<T>(src[(off_x + i) * str_x + (off_y + j) * str_y]);
} else {
dst[i * kElemCols + j] = T(0);
}
}
}
}
template <typename DstPtrType, typename StrX, typename StrY>
METAL_FUNC static constexpr void
store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) {
using U = pointer_element_t<DstPtrType>;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
dst[i * str_x + j * str_y] = static_cast<U>(src[i * kElemCols + j]);
}
}
}
template <
typename DstPtrType,
typename StrX,
typename StrY,
typename LimX,
typename LimY,
typename OffX,
typename OffY>
METAL_FUNC static constexpr void store_safe(
const thread frag_type& src,
DstPtrType dst,
StrX str_x,
StrY str_y,
LimX lim_x,
LimY lim_y,
OffX off_x = Int<0>{},
OffY off_y = Int<0>{}) {
using U = pointer_element_t<DstPtrType>;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
dst[(off_x + i) * str_x + (off_y + j) * str_y] =
static_cast<U>(src[i * kElemCols + j]);
}
}
}
}
METAL_FUNC static constexpr void mma(
thread frag_type& D,
thread frag_type& A,
thread frag_type& B,
thread frag_type& C) {
mat_type D_mat;
mat_type A_mat;
mat_type B_mat;
mat_type C_mat;
reinterpret_cast<thread frag_type&>(A_mat.thread_elements()) = A;
reinterpret_cast<thread frag_type&>(B_mat.thread_elements()) = B;
reinterpret_cast<thread frag_type&>(C_mat.thread_elements()) = C;
mma(D_mat, A_mat, B_mat, C_mat);
D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());
}
METAL_FUNC static constexpr void mma(
thread mat_type& D,
thread mat_type& A,
thread mat_type& B,
thread mat_type& C) {
simdgroup_multiply_accumulate(D, A, B, C);
}
template <typename Op>
METAL_FUNC static constexpr void row_reduce(
thread const frag_type& inp_vals,
thread T* reduced_vals) {
T thr_reduce = Op::apply(inp_vals.x, inp_vals.y);
T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1));
qgr_reduce = Op::apply(thr_reduce, qgr_reduce);
T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8));
sgr_reduce = Op::apply(qgr_reduce, sgr_reduce);
reduced_vals[0] = Op::apply(reduced_vals[0], sgr_reduce);
}
template <typename Op>
METAL_FUNC static constexpr void row_bin_op(
thread frag_type& inp_vals,
thread T* row_vals) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
inp_vals[i * kElemCols + j] =
Op::apply(inp_vals[i * kElemCols + j], row_vals[i]);
}
}
}
};
template <
typename T,
int kTileRows_,
int kTileCols_,
class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
struct MMATile {
using MMAFrag_t = MMAFrag_;
using elem_type = T;
STEEL_CONST int kFragRows = MMAFrag_t::kFragRows;
STEEL_CONST int kFragCols = MMAFrag_t::kFragCols;
STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag;
STEEL_CONST int kTileRows = kTileRows_;
STEEL_CONST int kTileCols = kTileCols_;
STEEL_CONST int kRows = kTileRows * kFragRows;
STEEL_CONST int kCols = kTileCols * kFragCols;
STEEL_CONST int kNumFrags = kTileRows * kTileCols;
STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;
STEEL_CONST int kRowsPerThread = kTileRows * MMAFrag_t::kElemRows;
STEEL_CONST int kColsPerThread = kTileCols * MMAFrag_t::kElemCols;
typedef typename MMAFrag_t::mat_type mat_type;
typedef typename MMAFrag_t::frag_type frag_type;
frag_type val_frags[kNumFrags] = {frag_type(0)};
METAL_FUNC MMATile() thread {}
METAL_FUNC constexpr void clear() {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kNumFrags; ++i) {
val_frags[i] = frag_type(0);
}
}
METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {
return val_frags[i * kTileCols + j];
}
METAL_FUNC constexpr const thread frag_type& frag_at(
const short i,
const short j) const {
return val_frags[i * kTileCols + j];
}
METAL_FUNC mat_type mat_at(const short i, const short j) {
mat_type val_mat;
STEEL_PRAGMA_UNROLL
for (short ii = 0; ii < kElemsPerFrag; ++ii) {
val_mat.thread_elements()[ii] = frag_at(i, j)[ii];
}
return val_mat;
}
METAL_FUNC thread elem_type* elems() {
return reinterpret_cast<thread elem_type*>(val_frags);
}
METAL_FUNC const thread elem_type* elems() const {
return reinterpret_cast<const thread elem_type*>(val_frags);
}
template <typename Op>
METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::template row_reduce<Op>(
frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
}
}
}
template <typename Op>
METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread]) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::template row_bin_op<Op>(
frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
}
}
}
template <typename U, int w_x, int w_y, int str_x, int str_y>
METAL_FUNC void load(const threadgroup U* src) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::load(
frag_at(i, j),
&(
src[(i * kFragRows) * w_x * str_x +
(j * kFragCols) * w_y * str_y]),
Int<str_x>{},
Int<str_y>{});
}
}
}
template <typename U, int w_x, int w_y, int str_x, int str_y>
METAL_FUNC void store(threadgroup U* dst) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::store(
frag_at(i, j),
&(
dst[(i * kFragRows) * w_x * str_x +
(j * kFragCols) * w_y * str_y]),
Int<str_x>{},
Int<str_y>{});
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void load(const device U* src, const int ld) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::load(
frag_at(i, j),
&(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
ld,
Int<1>{});
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void store(device U* dst, const int ld) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::store(
frag_at(i, j),
&(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
ld,
Int<1>{});
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void
load_safe(const device U* src, const int ld, const short2 src_tile_dims) {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < kTileCols; ++j) {
MMAFrag_t::load_safe(
frag_at(i, j),
src,
ld,
Int<1>{},
src_tile_dims.y,
src_tile_dims.x,
(i * kFragRows) * w_x,
(j * kFragCols) * w_y);
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void
store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < kTileCols; ++j) {
MMAFrag_t::store_safe(
frag_at(i, j),
dst,
ld,
Int<1>{},
dst_tile_dims.y,
dst_tile_dims.x,
(i * kFragRows) * w_x,
(j * kFragCols) * w_y);
}
}
}
};
template <typename T, typename U, int M, int N, int K>
METAL_FUNC void tile_matmad(
thread MMATile<T, M, N>& D,
thread MMATile<U, M, K>& A,
thread MMATile<U, K, N>& B,
thread MMATile<T, M, N>& C) {
STEEL_PRAGMA_UNROLL
for (short k = 0; k < K; ++k) {
STEEL_PRAGMA_UNROLL
for (short m = 0; m < M; ++m) {
STEEL_PRAGMA_UNROLL
for (short n = 0; n < N; ++n) {
short n_serp = (m % 2) ? (N - 1 - n) : n;
MMATile<T, M, N>::MMAFrag_t::mma(
D.frag_at(m, n_serp),
A.frag_at(m, k),
B.frag_at(k, n_serp),
C.frag_at(m, n_serp));
}
}
}
}
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
short lda_tgp,
short ldb_tgp,
typename AccumType = float,
typename Epilogue = TransformNone<U, AccumType>>
struct BlockMMA {
// MMAFrag size
STEEL_CONST short kFragSize = 8;
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TM_stride = kFragSize * WM;
// Warp tile simdgroup matrix strides along N
STEEL_CONST short TN_stride = kFragSize * WN;
// Warp tile size along M
STEEL_CONST short TM = BM / TM_stride;
// Warp tile size along N
STEEL_CONST short TN = BN / TN_stride;
// Threadgroup A strides
STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K
// Threadgroup B strides
STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N
// Threadgroup strides along K
STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
STEEL_CONST short tile_stride_b = kFragSize * B_str_k;
// Simdgroup matrices
MMATile<AccumType, TM, 1, MMAFrag_acc_t> Atile;
MMATile<AccumType, 1, TN, MMAFrag_acc_t> Btile;
MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile;
// Offsets within threadgroup
short sm;
short sn;
short As_offset;
short Bs_offset;
/* Constructor */
METAL_FUNC BlockMMA(
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]]) {
// Determine thread position in simdgroup matrix
short tm = kFragSize * (simd_group_id / WN);
short tn = kFragSize * (simd_group_id % WN);
short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
sm = simd_coord.y;
sn = simd_coord.x;
// Determine thread and simdgroup offset
As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K
Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N
sm += tm;
sn += tn;
}
/* (BM, BK) X (BK, BN) multiply accumulate function */
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
// Adjust for simdgroup and thread location
As += As_offset;
Bs += Bs_offset;
// Iterate over BK in blocks of kFragSize
STEEL_PRAGMA_UNROLL
for (short kk = 0; kk < BK; kk += kFragSize) {
simdgroup_barrier(mem_flags::mem_none);
Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);
simdgroup_barrier(mem_flags::mem_none);
Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);
simdgroup_barrier(mem_flags::mem_none);
tile_matmad(Ctile, Atile, Btile, Ctile);
// Progress to next simdgroup tile
As += tile_stride_a;
Bs += tile_stride_b;
}
}
/* Store results from simdgroup_matrix into device memory */
METAL_FUNC void store_result(device U* D, const int ldd) {
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
}
// Adjust for simdgroup and thread location
D += sm * ldd + sn;
Ctile.template store<U, WM, WN>(D, ldd);
}
METAL_FUNC void
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
}
// Adjust for simdgroup and thread location
D += sm * ldd + sn;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims);
}
/* Apply epilogue */
template <typename UnaryEpilogue>
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]);
}
}
/* Apply epilogue */
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue(
const device U* C,
const int ldc,
const int fdc,
thread const BinaryEpilogue& epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) {
accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
}
}
}
}
/* Apply epilogue */
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue_safe(
const device U* C,
const int ldc,
const int fdc,
short2 dst_tile_dims,
thread const BinaryEpilogue& epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
// Read C
U c_elems[kelems] = {0};
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x) {
c_elems[k] = C[offset_c + k * fdc];
}
}
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
accum[k] = epilogue_op.apply(accum[k], c_elems[k]);
}
}
}
}
/* Store results from simdgroup_matrix into device memory */
METAL_FUNC void store_result(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
D += (sm)*ldd + sn;
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
}
}
}
}
METAL_FUNC void store_result_safe(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
short2 dst_tile_dims,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
D += (sm)*ldd + sn;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x) {
D[offset_d + k] =
epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
}
}
}
}
}
}
};
} // namespace steel
} // namespace mlx
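To see how these helpers combine in the attention kernel above, the short host-side check below recomputes the tile shapes for the bq32_bk16_bd128_wm4_wn1 instantiation (illustrative only):

#include <cstdio>

int main() {
  const int kFragSize = 8, WM = 4, WN = 1;
  const int BQ = 32, BK = 16, BD = 128;
  const int kNWarps = WM * WN;                               // 4 simdgroups
  const int TQ = BQ / (kNWarps * kFragSize);                 // 1 query fragment per simdgroup
  const int TK = BK / kFragSize;                             // 2 KV fragments
  const int TD = BD / kFragSize;                             // 16 head-dim fragments
  const int elems_per_frag = (kFragSize * kFragSize) / 32;   // 2 values per thread per fragment
  std::printf("TQ=%d TK=%d TD=%d\n", TQ, TK, TD);
  std::printf("Stile: %d elems/thread, Otile: %d elems/thread\n",
              TQ * TK * elems_per_frag, TQ * TD * elems_per_frag);
  return 0;
}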

View File

@@ -1,36 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
///////////////////////////////////////////////////////////////////////////////
// Attn param classes
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
struct AttnParams {
int B; ///< Batch Size
int H; ///< Heads
int D; ///< Head Dim
int qL; ///< Query Sequence Length
int kL; ///< Key Sequence Length
int gqa_factor; ///< Group Query factor
float scale; ///< Attention scale
int NQ; ///< Number of query blocks
int NK; ///< Number of key/value blocks
int NQ_aligned; ///< Number of full query blocks
int NK_aligned; ///< Number of full key/value blocks
size_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
size_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
size_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
size_t O_strides[3]; ///< Output strides (B, H, L, D = 1)
};
} // namespace steel
} // namespace mlx
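The block counts are consumed by the kernel as "index of the first partial block" (NQ_aligned, NK_aligned) and "total block count" (NQ, NK). Below is a hedged reconstruction of how they would be derived host-side from the sequence lengths and block sizes; the actual setup code is not part of this diff and the example lengths are arbitrary.

#include <cstdio>

int main() {
  const int BQ = 32, BK = 16;          // query / key block sizes
  const int qL = 70, kL = 45;          // example sequence lengths (arbitrary)
  const int NQ = (qL + BQ - 1) / BQ;   // query blocks launched
  const int NK = (kL + BK - 1) / BK;   // KV blocks iterated per query block
  const int NQ_aligned = qL / BQ;      // full query blocks; block NQ_aligned is partial
  const int NK_aligned = kL / BK;      // full KV blocks; block NK_aligned is masked
  std::printf("NQ=%d NK=%d NQ_aligned=%d NK_aligned=%d\n", NQ, NK, NQ_aligned, NK_aligned);
  std::printf("last query block: %d rows, last KV block: %d columns\n",
              qL - NQ_aligned * BQ, kL - NK_aligned * BK);
  return 0;
}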

View File

@@ -1,71 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/utils.h"
///////////////////////////////////////////////////////////////////////////////
// Transforms and Epilogues
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <typename OutT, typename InT>
struct TransformNone {
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
static METAL_FUNC OutT apply(InT x, OutT) {
return static_cast<OutT>(x);
}
};
template <typename OutT, typename InT>
struct TransformAdd {
TransformAdd(const float, const float) {}
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
static METAL_FUNC OutT apply(InT x, OutT c) {
return static_cast<OutT>(x) + c;
}
};
template <typename OutT, typename InT>
struct TransformAxpby {
const float alpha;
const float beta;
TransformAxpby(const float alpha_, const float beta_)
: alpha(alpha_), beta(beta_) {}
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
METAL_FUNC OutT apply(InT x, OutT c) const {
return static_cast<OutT>(x * alpha + (beta * c));
}
};
template <typename T>
struct AccumHelper {
typedef float accum_type;
};
struct BlockSwizzle {
static METAL_FUNC int2
swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
const int tid_x = (tid.x) >> swizzle_log;
const int tid_y =
((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));
return int2(tid_x, tid_y);
}
};
} // namespace steel
} // namespace mlx
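BlockSwizzle remaps threadgroup ids so that consecutive x-tiles are folded into the y direction, a layout commonly chosen to improve cache locality across neighboring tiles. The small host-side demo below evaluates the mapping above for swizzle_log = 1 (illustrative only):

#include <cstdio>

int main() {
  const int swizzle_log = 1;
  for (int ty = 0; ty < 2; ++ty) {
    for (int tx = 0; tx < 4; ++tx) {
      int tid_x = tx >> swizzle_log;                                     // halved x range
      int tid_y = (ty << swizzle_log) + (tx & ((1 << swizzle_log) - 1)); // folded into y
      std::printf("grid (%d,%d) -> tile (%d,%d)\n", tx, ty, tid_x, tid_y);
    }
  }
  return 0;
}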

View File

@@ -5,7 +5,7 @@
// clang-format off // clang-format off
#include "mlx/backend/metal/kernels/steel/gemm/mma.h" #include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h" #include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h" #include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h" #include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h"

View File

@@ -5,7 +5,7 @@
// clang-format off // clang-format off
#include "mlx/backend/metal/kernels/steel/gemm/mma.h" #include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h" #include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h" #include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/steel/utils.h" #include "mlx/backend/metal/kernels/steel/utils.h"

View File

@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
// clang-format off // clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm.h"

View File

@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
// clang-format off // clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h" #include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h"

View File

@@ -1,7 +1,7 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
// clang-format off // clang-format off
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h" #include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h"

View File

@@ -385,9 +385,9 @@ struct BlockMMA {
STEEL_CONST short TN_stride = kFragSize * WN; STEEL_CONST short TN_stride = kFragSize * WN;
// Warp tile size along M // Warp tile size along M
STEEL_CONST short TM = BM / (kFragSize * WM); STEEL_CONST short TM = BM / TM_stride;
// Warp tile size along N // Warp tile size along N
STEEL_CONST short TN = BN / (kFragSize * WN); STEEL_CONST short TN = BN / TN_stride;
// Threadgroup A strides // Threadgroup A strides
STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M

View File

@@ -32,13 +32,13 @@ template <typename T, typename Op>
constant const size_t& b_strides, constant const size_t& b_strides,
constant const size_t& c_strides, constant const size_t& c_strides,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_strides); auto a_idx = elem_to_loc_1(index, a_strides);
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_strides); auto b_idx = elem_to_loc_1(index, b_strides);
auto c_idx = elem_to_loc_1<size_t, uint>(index, c_strides); auto c_idx = elem_to_loc_1(index, c_strides);
d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]); d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]);
} }
template <typename T, typename Op, typename IdxT = size_t> template <typename T, typename Op>
[[kernel]] void ternary_g_nd2( [[kernel]] void ternary_g_nd2(
device const bool* a, device const bool* a,
device const T* b, device const T* b,
@@ -49,14 +49,14 @@ template <typename T, typename Op, typename IdxT = size_t>
constant const size_t c_strides[2], constant const size_t c_strides[2],
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides); auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides); auto b_idx = elem_to_loc_2(index, b_strides);
auto c_idx = elem_to_loc_2<size_t, IdxT>(index, c_strides); auto c_idx = elem_to_loc_2(index, c_strides);
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y; size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]); d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
} }
template <typename T, typename Op, typename IdxT = size_t> template <typename T, typename Op>
[[kernel]] void ternary_g_nd3( [[kernel]] void ternary_g_nd3(
device const bool* a, device const bool* a,
device const T* b, device const T* b,
@@ -67,14 +67,15 @@ template <typename T, typename Op, typename IdxT = size_t>
constant const size_t c_strides[3], constant const size_t c_strides[3],
uint3 index [[thread_position_in_grid]], uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides); auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides); auto b_idx = elem_to_loc_3(index, b_strides);
auto c_idx = elem_to_loc_3<size_t, IdxT>(index, c_strides); auto c_idx = elem_to_loc_3(index, c_strides);
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); size_t out_idx =
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]); d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
} }
template <typename T, typename Op, int N = 1, typename IdxT = size_t> template <typename T, typename Op, int N = 1>
[[kernel]] void ternary_g( [[kernel]] void ternary_g(
device const bool* a, device const bool* a,
device const T* b, device const T* b,
@@ -87,7 +88,7 @@ template <typename T, typename Op, int N = 1, typename IdxT = size_t>
constant const int& ndim, constant const int& ndim,
uint3 index [[thread_position_in_grid]], uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_3_nd<IdxT>( auto idx = elem_to_loc_3_nd(
{N * index.x, index.y, index.z}, {N * index.x, index.y, index.z},
shape, shape,
a_strides, a_strides,
@@ -95,10 +96,11 @@ template <typename T, typename Op, int N = 1, typename IdxT = size_t>
c_strides, c_strides,
ndim); ndim);
auto xshape = shape[ndim - 1]; auto xshape = shape[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); size_t out_idx =
IdxT a_xstride = a_strides[ndim - 1]; N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
IdxT b_xstride = b_strides[ndim - 1]; auto a_xstride = a_strides[ndim - 1];
IdxT c_xstride = c_strides[ndim - 1]; auto b_xstride = b_strides[ndim - 1];
auto c_xstride = c_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
d[out_idx++] = Op()(a[idx.x], b[idx.y], c[idx.z]); d[out_idx++] = Op()(a[idx.x], b[idx.y], c[idx.z]);
idx.x += a_xstride; idx.x += a_xstride;

View File

@@ -4,20 +4,18 @@
#include <metal_math> #include <metal_math>
// clang-format off // clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/ternary_ops.h" #include "mlx/backend/metal/kernels/ternary_ops.h"
#include "mlx/backend/metal/kernels/ternary.h" #include "mlx/backend/metal/kernels/ternary.h"
#define instantiate_ternary_all(op, tname, type) \ #define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \ instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \ instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 1, uint) \ instantiate_kernel("gn4_" #op #tname, ternary_g, type, op, 4) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \ instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, uint) \ instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, uint) \ instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op)
instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \
#define instantiate_ternary_types(op) \ #define instantiate_ternary_types(op) \
instantiate_ternary_all(op, bool_, bool) \ instantiate_ternary_all(op, bool_, bool) \

View File

@@ -18,12 +18,7 @@ template <typename T, typename U, typename Op>
out[offset] = Op()(in[offset]); out[offset] = Op()(in[offset]);
} }
template < template <typename T, typename U, typename Op, int N = 1>
typename T,
typename U,
typename Op,
int N = 1,
typename IdxT = size_t>
[[kernel]] void unary_g( [[kernel]] void unary_g(
device const T* in, device const T* in,
device U* out, device U* out,
@@ -32,11 +27,12 @@ template <
device const int& ndim, device const int& ndim,
uint3 index [[thread_position_in_grid]], uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc<size_t, IdxT>( auto idx =
{N * index.x, index.y, index.z}, in_shape, in_strides, ndim); elem_to_loc({N * index.x, index.y, index.z}, in_shape, in_strides, ndim);
auto xshape = in_shape[ndim - 1]; auto xshape = in_shape[ndim - 1];
IdxT xstride = in_strides[ndim - 1]; auto xstride = in_strides[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); size_t out_idx =
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
out[out_idx++] = Op()(in[idx]); out[out_idx++] = Op()(in[idx]);
idx += xstride; idx += xstride;

View File

@@ -5,13 +5,11 @@
#include "mlx/backend/metal/kernels/unary_ops.h" #include "mlx/backend/metal/kernels/unary_ops.h"
#include "mlx/backend/metal/kernels/unary.h" #include "mlx/backend/metal/kernels/unary.h"
#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \ #define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \ instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \ instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
instantiate_kernel( \ instantiate_kernel("gn4_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, uint) \
instantiate_kernel( \
"gn4large" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
#define instantiate_unary_all_same(op, tname, type) \ #define instantiate_unary_all_same(op, tname, type) \
instantiate_unary_all(op, tname, tname, type, type) instantiate_unary_all(op, tname, tname, type, type)

View File

@@ -3,13 +3,7 @@
#pragma once #pragma once
#include <metal_math> #include <metal_math>
#include "mlx/backend/metal/kernels/bf16.h"
// The correct bf16.h is included based on the metal version
// by giving the correct path to -I during compilation
// e.g. mlx/backend/metal/kernels/metal_3_0/ for Metal 3.0
#include "bf16.h"
#include "mlx/backend/metal/kernels/bf16_math.h"
#include "mlx/backend/metal/kernels/complex.h" #include "mlx/backend/metal/kernels/complex.h"
#include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/defines.h"
@@ -89,45 +83,44 @@ struct Limits<complex64_t> {
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Single Array with generic dims // Single Array with generic dims
template <typename StrideT, typename IdxT = StrideT> template <typename stride_t>
METAL_FUNC IdxT elem_to_loc( METAL_FUNC stride_t elem_to_loc(
uint elem, uint elem,
constant const int* shape, constant const int* shape,
constant const StrideT* strides, constant const stride_t* strides,
int ndim) { int ndim) {
IdxT loc = 0; stride_t loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) { for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * IdxT(strides[i]); loc += (elem % shape[i]) * strides[i];
elem /= shape[i]; elem /= shape[i];
} }
return loc; return loc;
} }
template <typename StrideT, typename IdxT = StrideT> template <typename stride_t>
METAL_FUNC IdxT elem_to_loc( METAL_FUNC stride_t elem_to_loc(
StrideT elem, stride_t elem,
constant const int* shape, constant const int* shape,
constant const StrideT* strides, constant const stride_t* strides,
int ndim) { int ndim) {
IdxT loc = 0; stride_t loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) { for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * IdxT(strides[i]); loc += (elem % shape[i]) * strides[i];
elem /= shape[i]; elem /= shape[i];
} }
return loc; return loc;
} }
// Non templated version to handle arbitrary dims // Non templated version to handle arbitrary dims
template <typename StrideT, typename IdxT = StrideT> template <typename stride_t>
METAL_FUNC IdxT elem_to_loc( METAL_FUNC stride_t elem_to_loc(
uint3 elem, uint3 elem,
constant const int* shape, constant const int* shape,
constant const StrideT* strides, constant const stride_t* strides,
int ndim) { int ndim) {
IdxT loc = stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]);
for (int d = ndim - 3; d >= 0; --d) { for (int d = ndim - 3; d >= 0; --d) {
loc += (elem.z % shape[d]) * IdxT(strides[d]); loc += (elem.z % shape[d]) * strides[d];
elem.z /= shape[d]; elem.z /= shape[d];
} }
return loc; return loc;
@@ -136,65 +129,61 @@ METAL_FUNC IdxT elem_to_loc(
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Single Array with fixed N dims // Single Array with fixed N dims
template <typename StrideT, typename IdxT = StrideT> template <typename stride_t>
METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const StrideT& stride) { METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t& stride) {
return elem * IdxT(stride); return elem * stride;
} }
template <typename StrideT, typename IdxT = StrideT> template <typename stride_t>
METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const StrideT strides[2]) { METAL_FUNC stride_t
return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]); elem_to_loc_2(uint2 elem, constant const stride_t strides[2]) {
return elem.x * strides[1] + elem.y * strides[0];
} }
template <typename StrideT, typename IdxT = StrideT> template <typename stride_t>
METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const StrideT strides[3]) { METAL_FUNC stride_t
return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) + elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) {
elem.z * IdxT(strides[0]); return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
} }
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Multiple Arrays with generic dims // Multiple Arrays with generic dims
template <typename StrideT, typename IdxT = StrideT> template <typename stride_t>
METAL_FUNC vec<IdxT, 2> elem_to_loc_2_nd( METAL_FUNC ulong2 elem_to_loc_2_nd(
uint3 elem, uint3 elem,
constant const int* shape, constant const int* shape,
constant const StrideT* a_strides, constant const stride_t* a_strides,
constant const StrideT* b_strides, constant const stride_t* b_strides,
int ndim) { int ndim) {
vec<IdxT, 2> loc = { ulong2 loc = {
IdxT( ulong(elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
elem.x * IdxT(a_strides[ndim - 1]) + ulong(elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
IdxT(elem.y) * IdxT(a_strides[ndim - 2])),
IdxT(
elem.x * IdxT(b_strides[ndim - 1]) +
elem.y * IdxT(b_strides[ndim - 2]))};
for (int d = ndim - 3; d >= 0; --d) { for (int d = ndim - 3; d >= 0; --d) {
uint l = elem.z % shape[d]; uint l = elem.z % shape[d];
loc.x += l * IdxT(a_strides[d]); loc.x += l * a_strides[d];
loc.y += l * IdxT(b_strides[d]); loc.y += l * b_strides[d];
elem.z /= shape[d]; elem.z /= shape[d];
} }
return loc; return loc;
} }
template <typename IdxT = size_t> METAL_FUNC ulong3 elem_to_loc_3_nd(
METAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(
uint3 elem, uint3 elem,
constant const int* shape, constant const int* shape,
constant const size_t* a_strides, constant const size_t* a_strides,
constant const size_t* b_strides, constant const size_t* b_strides,
constant const size_t* c_strides, constant const size_t* c_strides,
int ndim) { int ndim) {
vec<IdxT, 3> loc = { ulong3 loc = {
elem.x * IdxT(a_strides[ndim - 1]) + elem.y * IdxT(a_strides[ndim - 2]), elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2],
elem.x * IdxT(b_strides[ndim - 1]) + elem.y * IdxT(b_strides[ndim - 2]), elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2],
elem.x * IdxT(c_strides[ndim - 1]) + elem.y * IdxT(c_strides[ndim - 2])}; elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2]};
for (int d = ndim - 3; d >= 0; --d) { for (int d = ndim - 3; d >= 0; --d) {
uint l = elem.z % shape[d]; uint l = elem.z % shape[d];
loc.x += l * IdxT(a_strides[d]); loc.x += l * a_strides[d];
loc.y += l * IdxT(b_strides[d]); loc.y += l * b_strides[d];
loc.z += l * IdxT(c_strides[d]); loc.z += l * c_strides[d];
elem.z /= shape[d]; elem.z /= shape[d];
} }
return loc; return loc;
@@ -204,21 +193,16 @@ METAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(
// Elem to loc in a loop utils // Elem to loc in a loop utils
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
template <int DIM, typename OffsetT = size_t, bool General = true> template <int dim, typename offset_t = size_t>
struct LoopedElemToLoc { struct looped_elem_to_loc {
int dim; looped_elem_to_loc<dim - 1, offset_t> inner_looper;
LoopedElemToLoc<DIM - 1, OffsetT, General> inner_looper; offset_t offset{0};
OffsetT offset{0};
int index{0}; int index{0};
LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}
void next(const constant int* shape, const constant size_t* strides) { void next(const constant int* shape, const constant size_t* strides) {
if (dim == 0) {
return;
}
index++; index++;
offset += OffsetT(strides[dim - 1]); offset += strides[dim - 1];
if (index >= shape[dim - 1]) { if (index >= shape[dim - 1]) {
index = 0; index = 0;
inner_looper.next(shape, strides); inner_looper.next(shape, strides);
@@ -227,21 +211,13 @@ struct LoopedElemToLoc {
} }
void next(int n, const constant int* shape, const constant size_t* strides) { void next(int n, const constant int* shape, const constant size_t* strides) {
if (dim == 0) {
return;
}
index += n; index += n;
offset += n * OffsetT(strides[dim - 1]); offset += n * strides[dim - 1];
if (index >= shape[dim - 1]) { if (index >= shape[dim - 1]) {
int extra = index - shape[dim - 1]; int extra = index - shape[dim - 1];
if (extra >= shape[dim - 1]) {
inner_looper.next(1 + extra / shape[dim - 1], shape, strides);
extra = extra % shape[dim - 1];
} else {
inner_looper.next(shape, strides);
}
index = 0; index = 0;
inner_looper.next(shape, strides);
offset = inner_looper.offset; offset = inner_looper.offset;
if (extra > 0) { if (extra > 0) {
next(extra, shape, strides); next(extra, shape, strides);
@@ -249,61 +225,44 @@ struct LoopedElemToLoc {
} }
} }
OffsetT location() { offset_t
location(offset_t, const constant int*, const constant size_t*, int) {
return offset; return offset;
} }
}; };
template <typename OffsetT> template <typename offset_t>
struct LoopedElemToLoc<1, OffsetT, true> { struct looped_elem_to_loc<1, offset_t> {
int dim; offset_t offset{0};
OffsetT offset{0};
uint index{0};
LoopedElemToLoc(int dim) : dim(dim) {}
void next(const constant int* shape, const constant size_t* strides) {
index++;
if (dim > 1) {
offset = elem_to_loc<size_t, OffsetT>(index, shape, strides, dim);
} else {
offset += OffsetT(strides[0]);
}
}
void next(int n, const constant int* shape, const constant size_t* strides) {
index += n;
if (dim > 1) {
offset = elem_to_loc<size_t, OffsetT>(index, shape, strides, dim);
} else {
offset = index * OffsetT(strides[0]);
}
}
OffsetT location() {
return offset;
}
};
template <typename OffsetT>
struct LoopedElemToLoc<1, OffsetT, false> {
OffsetT offset{0};
LoopedElemToLoc(int) {}
void next(const constant int*, const constant size_t* strides) { void next(const constant int*, const constant size_t* strides) {
offset += OffsetT(strides[0]); offset += strides[0];
} }
void next(int n, const constant int*, const constant size_t* strides) { void next(int n, const constant int*, const constant size_t* strides) {
offset += n * OffsetT(strides[0]); offset += n * strides[0];
} }
OffsetT location() { offset_t
location(offset_t, const constant int*, const constant size_t*, int) {
return offset; return offset;
} }
}; };
template <typename offset_t>
struct looped_elem_to_loc<0, offset_t> {
void next(const constant int*, const constant size_t*) {}
void next(int, const constant int*, const constant size_t*) {}
offset_t location(
offset_t idx,
const constant int* shape,
const constant size_t* strides,
int ndim) {
return elem_to_loc(idx, shape, strides, ndim);
}
};
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Calculation utils // Calculation utils
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////

View File

@@ -11,12 +11,12 @@ SRC_DIR=$3
SRC_FILE=$4 SRC_FILE=$4
CFLAGS=$5 CFLAGS=$5
SRC_NAME=$(basename -- "${SRC_FILE}") SRC_NAME=$(basename -- "${SRC_FILE}")
JIT_INCLUDES=${SRC_DIR}/mlx/backend/metal/kernels/jit
INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h
OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
mkdir -p "$OUTPUT_DIR" mkdir -p "$OUTPUT_DIR"
CONTENT=$($CC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
CONTENT=$($CC -I "$SRC_DIR" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
cat << EOF > "$OUTPUT_FILE" cat << EOF > "$OUTPUT_FILE"
namespace mlx::core::metal { namespace mlx::core::metal {

View File

@@ -249,7 +249,7 @@ void steel_matmul_regular(
wm, wm,
wn); wn);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
// Use problem size to determine threadblock swizzle // Use problem size to determine threadblock swizzle
int tn = (N + bn - 1) / bn; int tn = (N + bn - 1) / bn;
@@ -288,12 +288,12 @@ void steel_matmul_regular(
compute_encoder.set_input_array(b, 1); compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3); compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(params, 4); compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
compute_encoder.set_vector_bytes(batch_shape, 6); set_vector_bytes(compute_encoder, batch_shape, 6);
compute_encoder.set_vector_bytes(batch_strides, 7); set_vector_bytes(compute_encoder, batch_strides, 7);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Record copies // Record copies
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
@@ -390,7 +390,7 @@ void steel_matmul(
wn, wn,
mn_aligned, mn_aligned,
k_aligned); k_aligned);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
int tn = (N + bn - 1) / bn; int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm; int tm = (M + bm - 1) / bm;
@@ -416,29 +416,34 @@ void steel_matmul(
compute_encoder.set_input_array(b, 1); compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(C_split, 2); compute_encoder.set_output_array(C_split, 2);
compute_encoder.set_bytes(params, 3); compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Do accum kernel // Do accum kernel
{ {
auto c_split_buf =
static_cast<const MTL::Resource*>(C_split.buffer().ptr());
const class MTL::Resource* const resources[1] = {c_split_buf};
compute_encoder->memoryBarrier(resources, 1);
auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
type_to_name(C_split); type_to_name(C_split);
auto kernel = get_steel_gemm_splitk_accum_kernel( auto kernel = get_steel_gemm_splitk_accum_kernel(
d, kernel_name, C_split, out, false); d, kernel_name, C_split, out, false);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
// Set the arguments for the kernel // Set the arguments for the kernel
compute_encoder.set_input_array(C_split, 0); compute_encoder.set_input_array(C_split, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(split_k_partitions, 2); compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2);
compute_encoder.set_bytes(split_k_partition_stride, 3); compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3);
compute_encoder.set_bytes(N, 4); compute_encoder->setBytes(&N, sizeof(int), 4);
// Launch enough thread groups for each output // Launch enough thread groups for each output
MTL::Size grid_dims = MTL::Size(N, M, 1); MTL::Size grid_dims = MTL::Size(N, M, 1);
auto group_dims = get_block_dims(N, M, 1); MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
@@ -620,7 +625,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
// Encode and dispatch kernel // Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm); MTL::Size group_dims = MTL::Size(32, bn, bm);
@@ -630,16 +635,16 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(vec, 1); compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3); compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(in_vector_len, 4); compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
compute_encoder.set_bytes(out_vector_len, 5); compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
compute_encoder.set_bytes(mat_ld, 6); compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
compute_encoder.set_bytes(batch_ndim, 9); compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
compute_encoder.set_vector_bytes(batch_shape, 10); set_vector_bytes(compute_encoder, batch_shape, 10);
compute_encoder.set_vector_bytes(batch_strides_vec, 11); set_vector_bytes(compute_encoder, batch_strides_vec, 11);
compute_encoder.set_vector_bytes(batch_strides_mat, 12); set_vector_bytes(compute_encoder, batch_strides_mat, 12);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
return; return;
@@ -817,7 +822,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Encode and dispatch kernel // Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm); MTL::Size group_dims = MTL::Size(32, bn, bm);
@@ -828,23 +833,23 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(c, 2); compute_encoder.set_input_array(c, 2);
compute_encoder.set_output_array(out, 3); compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(in_vector_len, 4); compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
compute_encoder.set_bytes(out_vector_len, 5); compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
compute_encoder.set_bytes(mat_ld, 6); compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
compute_encoder.set_bytes(alpha_, 7); compute_encoder->setBytes(&alpha_, sizeof(float), 7);
compute_encoder.set_bytes(beta_, 8); compute_encoder->setBytes(&beta_, sizeof(float), 8);
compute_encoder.set_bytes(batch_ndim, 9); compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
compute_encoder.set_vector_bytes(batch_shape, 10); set_vector_bytes(compute_encoder, batch_shape, 10);
compute_encoder.set_vector_bytes(batch_strides_vec, 11); set_vector_bytes(compute_encoder, batch_strides_vec, 11);
compute_encoder.set_vector_bytes(batch_strides_mat, 12); set_vector_bytes(compute_encoder, batch_strides_mat, 12);
compute_encoder.set_vector_bytes(C_batch_stride, 13); set_vector_bytes(compute_encoder, C_batch_stride, 13);
int bias_stride = c.strides()[c.ndim() - 1]; int bias_stride = c.strides()[c.ndim() - 1];
compute_encoder.set_bytes(bias_stride, 14); compute_encoder->setBytes(&bias_stride, sizeof(int), 14);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
return; return;
@@ -902,7 +907,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
mn_aligned, mn_aligned,
k_aligned); k_aligned);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
int tn = (N + bn - 1) / bn; int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm; int tm = (M + bm - 1) / bm;
@@ -928,8 +933,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(b, 1); compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(C_split, 2); compute_encoder.set_output_array(C_split, 2);
compute_encoder.set_bytes(params, 3); compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Do accum kernel // Do accum kernel
{ {
@@ -938,24 +943,25 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = get_steel_gemm_splitk_accum_kernel( auto kernel = get_steel_gemm_splitk_accum_kernel(
d, kernel_name, C_split, out, true); d, kernel_name, C_split, out, true);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
// Set the arguments for the kernel // Set the arguments for the kernel
compute_encoder.set_input_array(C_split, 0); compute_encoder.set_input_array(C_split, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(split_k_partitions, 2); compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2);
compute_encoder.set_bytes(split_k_partition_stride, 3); compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3);
compute_encoder.set_bytes(N, 4); compute_encoder->setBytes(&N, sizeof(int), 4);
compute_encoder.set_input_array(c, 5); compute_encoder.set_input_array(c, 5);
compute_encoder.set_bytes(ldc, 6); compute_encoder->setBytes(&ldc, sizeof(int), 6);
compute_encoder.set_bytes(fdc, 7); compute_encoder->setBytes(&fdc, sizeof(int), 7);
compute_encoder.set_bytes(alpha_, 8); compute_encoder->setBytes(&alpha_, sizeof(float), 8);
compute_encoder.set_bytes(beta_, 9); compute_encoder->setBytes(&beta_, sizeof(float), 9);
// Launch enough thread groups for each output // Launch enough thread groups for each output
MTL::Size grid_dims = MTL::Size(N, M, 1); MTL::Size grid_dims = MTL::Size(N, M, 1);
auto group_dims = get_block_dims(N, M, 1); MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
@@ -1026,7 +1032,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
wm, wm,
wn); wn);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
int tn = (N + bn - 1) / bn; int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm; int tm = (M + bm - 1) / bm;
@@ -1077,13 +1083,13 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(c, 2); compute_encoder.set_input_array(c, 2);
compute_encoder.set_output_array(out, 3); compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(gemm_params, 4); compute_encoder->setBytes(&gemm_params, sizeof(GEMMParams), 4);
compute_encoder.set_bytes(params, 5); compute_encoder->setBytes(&params, sizeof(GEMMAddMMParams), 5);
compute_encoder.set_vector_bytes(batch_shape, 6); set_vector_bytes(compute_encoder, batch_shape, 6);
compute_encoder.set_vector_bytes(batch_strides, 7); set_vector_bytes(compute_encoder, batch_strides, 7);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
} }
@@ -1298,7 +1304,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
contiguous_kernel); contiguous_kernel);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm); MTL::Size group_dims = MTL::Size(32, bn, bm);
@@ -1366,18 +1372,18 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(vec, 1); compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3); compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(in_vector_len, 4); compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
compute_encoder.set_bytes(out_vector_len, 5); compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
compute_encoder.set_bytes(mat_ld, 6); compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
compute_encoder.set_bytes(batch_ndim, 9); compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
compute_encoder.set_vector_bytes(batch_shape, 10); set_vector_bytes(compute_encoder, batch_shape, 10);
compute_encoder.set_vector_bytes(batch_strides_vec, 11); set_vector_bytes(compute_encoder, batch_strides_vec, 11);
compute_encoder.set_vector_bytes(batch_strides_mat, 12); set_vector_bytes(compute_encoder, batch_strides_mat, 12);
compute_encoder.set_vector_bytes(mask_strides, 23); set_vector_bytes(compute_encoder, mask_strides, 23);
compute_encoder.set_vector_bytes(mask_batch_strides, 24); set_vector_bytes(compute_encoder, mask_batch_strides, 24);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
return; return;
@@ -1417,7 +1423,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
wn, wn,
mn_aligned, mn_aligned,
k_aligned); k_aligned);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
// Use problem size to determine threadblock swizzle // Use problem size to determine threadblock swizzle
int tn = (N + bn - 1) / bn; int tn = (N + bn - 1) / bn;
@@ -1480,14 +1486,14 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(b, 1); compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3); compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(params, 4); compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
compute_encoder.set_vector_bytes(batch_shape, 6); set_vector_bytes(compute_encoder, batch_shape, 6);
compute_encoder.set_vector_bytes(batch_strides, 7); set_vector_bytes(compute_encoder, batch_strides, 7);
compute_encoder.set_vector_bytes(mask_strides, 13); set_vector_bytes(compute_encoder, mask_strides, 13);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
} }
@@ -1681,7 +1687,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Encode and dispatch kernel // Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm); MTL::Size group_dims = MTL::Size(32, bn, bm);
@@ -1691,28 +1697,28 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(vec, 1); compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3); compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(in_vector_len, 4); compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
compute_encoder.set_bytes(out_vector_len, 5); compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
compute_encoder.set_bytes(mat_ld, 6); compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
compute_encoder.set_bytes(batch_ndim, 9); compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
compute_encoder.set_vector_bytes(batch_shape, 10); set_vector_bytes(compute_encoder, batch_shape, 10);
compute_encoder.set_vector_bytes(batch_strides, 11); set_vector_bytes(compute_encoder, batch_strides, 11);
int batch_ndim_vec = batch_shape_vec.size(); int batch_ndim_vec = batch_shape_vec.size();
compute_encoder.set_bytes(batch_ndim_vec, 12); compute_encoder->setBytes(&batch_ndim_vec, sizeof(int), 12);
compute_encoder.set_vector_bytes(batch_shape_vec, 13); set_vector_bytes(compute_encoder, batch_shape_vec, 13);
compute_encoder.set_vector_bytes(batch_strides_vec, 14); set_vector_bytes(compute_encoder, batch_strides_vec, 14);
int batch_ndim_mat = batch_shape_mat.size(); int batch_ndim_mat = batch_shape_mat.size();
compute_encoder.set_bytes(batch_ndim_mat, 15); compute_encoder->setBytes(&batch_ndim_mat, sizeof(int), 15);
compute_encoder.set_vector_bytes(batch_shape_mat, 16); set_vector_bytes(compute_encoder, batch_shape_mat, 16);
compute_encoder.set_vector_bytes(batch_strides_mat, 17); set_vector_bytes(compute_encoder, batch_strides_mat, 17);
compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix)); compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix));
compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix)); compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix));
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
return; return;
@@ -1782,7 +1788,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
wm, wm,
wn); wn);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
// Use problem size to determine threadblock swizzle // Use problem size to determine threadblock swizzle
int tn = (N + bn - 1) / bn; int tn = (N + bn - 1) / bn;
@@ -1821,10 +1827,10 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(b, 1); compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3); compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(params, 4); compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
compute_encoder.set_vector_bytes(batch_shape, 6); set_vector_bytes(compute_encoder, batch_shape, 6);
compute_encoder.set_vector_bytes(batch_strides, 7); set_vector_bytes(compute_encoder, batch_strides, 7);
compute_encoder.set_input_array(lhs_indices, 10); compute_encoder.set_input_array(lhs_indices, 10);
compute_encoder.set_input_array(rhs_indices, 11); compute_encoder.set_input_array(rhs_indices, 11);
@@ -1839,11 +1845,11 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
operand_batch_ndim.push_back(0); operand_batch_ndim.push_back(0);
compute_encoder.set_vector_bytes(operand_shape, 13); set_vector_bytes(compute_encoder, operand_shape, 13);
compute_encoder.set_vector_bytes(operand_strides, 14); set_vector_bytes(compute_encoder, operand_strides, 14);
compute_encoder.set_vector_bytes(operand_batch_ndim, 15); set_vector_bytes(compute_encoder, operand_batch_ndim, 15);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
} }
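Most of the churn in this file is the same mechanical swap: wrapper methods such as compute_encoder.set_bytes, set_vector_bytes, and dispatch_threadgroups on one side versus raw metal-cpp calls with explicit byte sizes (setBytes, dispatchThreadgroups) on the other. A hedged mock of that forwarding pattern follows; RawEncoder below stands in for MTL::ComputeCommandEncoder so the sketch compiles on its own, and none of these names are MLX's actual implementation:

#include <cstddef>
#include <cstdio>
#include <vector>

struct RawEncoder {  // stand-in for MTL::ComputeCommandEncoder
  void setBytes(const void* /*ptr*/, std::size_t size, int index) {
    std::printf("setBytes: %zu bytes at buffer index %d\n", size, index);
  }
};

struct CommandEncoder {
  RawEncoder enc;
  // Convenience form: the byte count is inferred from the argument type.
  template <typename T>
  void set_bytes(const T& v, int index) {
    enc.setBytes(&v, sizeof(T), index);
  }
  template <typename T>
  void set_vector_bytes(const std::vector<T>& v, int index) {
    enc.setBytes(v.data(), v.size() * sizeof(T), index);
  }
};

int main() {
  CommandEncoder compute_encoder;
  int batch_ndim = 2;
  std::vector<int> batch_shape{4, 8};
  compute_encoder.set_bytes(batch_ndim, 9);           // vs setBytes(&batch_ndim, sizeof(int), 9)
  compute_encoder.set_vector_bytes(batch_shape, 10);  // vs setBytes(data, n * sizeof(int), 10)
}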

View File

@@ -1,11 +1,11 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <cstdlib>
#include <memory> #include <memory>
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/scheduler.h" #include "mlx/scheduler.h"
#include "mlx/utils.h"
namespace mlx::core::metal { namespace mlx::core::metal {
@@ -13,6 +13,20 @@ bool is_available() {
return true; return true;
} }
int max_ops_per_buffer() {
auto get_val = []() {
if (const char* buff_str = std::getenv("MLX_MAX_OPS_PER_BUFFER")) {
return atoi(buff_str);
} else {
return 10;
}
};
static int max_ops_per_buffer_ = get_val();
return max_ops_per_buffer_;
}
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
inline void check_error(MTL::CommandBuffer* cbuf) { inline void check_error(MTL::CommandBuffer* cbuf) {
if (cbuf->status() == MTL::CommandBufferStatusError) { if (cbuf->status() == MTL::CommandBufferStatusError) {
std::ostringstream msg; std::ostringstream msg;
@@ -63,8 +77,7 @@ std::function<void()> make_task(array arr, bool signal) {
out.set_status(array::Status::evaluated); out.set_status(array::Status::evaluated);
} }
if (signal || if (signal || d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
d.get_command_buffer_ops(s.index) >= env::max_ops_per_buffer()) {
d.end_encoding(s.index); d.end_encoding(s.index);
if (signal) { if (signal) {
command_buffer->encodeSignalEvent( command_buffer->encodeSignalEvent(
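The hunk above controls how many ops are batched into one command buffer: one side reads MLX_MAX_OPS_PER_BUFFER directly, caches it in a function-local static, and defaults to 10, while the other routes the same lookup through env::max_ops_per_buffer(). A minimal sketch of that cached environment-variable pattern; the variable name and the default of 10 come from the diff, everything else is illustrative:

#include <cstdlib>

// Cached environment-variable lookup, mirroring the pattern in the hunk above.
int max_ops_per_buffer_sketch() {
  static int value = [] {
    if (const char* s = std::getenv("MLX_MAX_OPS_PER_BUFFER")) {
      return std::atoi(s);
    }
    return 10;  // default when the variable is unset
  }();
  return value;  // the getenv/atoi work happens only on the first call
}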

View File

@@ -1,4 +1,3 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"
@@ -100,7 +99,7 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
const std::string& kernel_name, const std::string& kernel_name,
const std::string&, const std::string&,
const std::string&, const std::string&,
const Dtype&) { const array&) {
return d.get_kernel(kernel_name); return d.get_kernel(kernel_name);
} }
@@ -109,9 +108,8 @@ MTL::ComputePipelineState* get_reduce_kernel(
const std::string& kernel_name, const std::string& kernel_name,
const std::string&, const std::string&,
const std::string&, const std::string&,
const Dtype&, const array&,
const Dtype&, const array&,
const std::string&,
int, int,
int, int,
int) { int) {

View File

@@ -78,15 +78,18 @@ void RMSNorm::eval_gpu(
} }
uint32_t w_stride = w.strides()[0]; uint32_t w_stride = w.strides()[0];
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array( compute_encoder.set_input_array(
x.data_shared_ptr() == nullptr ? out : x, 0); x.data_shared_ptr() == nullptr ? out : x, 0);
compute_encoder.set_input_array(w, 1); compute_encoder.set_input_array(w, 1);
compute_encoder.set_output_array(out, 2); compute_encoder.set_output_array(out, 2);
compute_encoder.set_bytes(eps_, 3); compute_encoder->setBytes(&eps_, sizeof(float), 3);
compute_encoder.set_bytes(axis_size, 4); compute_encoder->setBytes(&axis_size, sizeof(int), 4);
compute_encoder.set_bytes(w_stride, 5); compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 5);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder->setThreadgroupMemoryLength(16 * 8, 0); // minimum of 16 bytes
compute_encoder->setThreadgroupMemoryLength(simd_size * sizeof(float), 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
@@ -180,16 +183,16 @@ void RMSNormVJP::eval_gpu(
} }
uint32_t w_stride = w.strides()[0]; uint32_t w_stride = w.strides()[0];
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(x_in_gx ? gx : x, 0); compute_encoder.set_input_array(x_in_gx ? gx : x, 0);
compute_encoder.set_input_array(w, 1); compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2); compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
compute_encoder.set_output_array(gx, 3); compute_encoder.set_output_array(gx, 3);
compute_encoder.set_output_array(gw_temp, 4); compute_encoder.set_output_array(gw_temp, 4);
compute_encoder.set_bytes(eps_, 5); compute_encoder->setBytes(&eps_, sizeof(float), 5);
compute_encoder.set_bytes(axis_size, 6); compute_encoder->setBytes(&axis_size, sizeof(int), 6);
compute_encoder.set_bytes(w_stride, 7); compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
ReductionPlan plan( ReductionPlan plan(
@@ -270,17 +273,17 @@ void LayerNorm::eval_gpu(
uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
uint32_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0; uint32_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array( compute_encoder.set_input_array(
x.data_shared_ptr() == nullptr ? out : x, 0); x.data_shared_ptr() == nullptr ? out : x, 0);
compute_encoder.set_input_array(w, 1); compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(b, 2); compute_encoder.set_input_array(b, 2);
compute_encoder.set_output_array(out, 3); compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(eps_, 4); compute_encoder->setBytes(&eps_, sizeof(float), 4);
compute_encoder.set_bytes(axis_size, 5); compute_encoder->setBytes(&axis_size, sizeof(int), 5);
compute_encoder.set_bytes(w_stride, 6); compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 6);
compute_encoder.set_bytes(b_stride, 7); compute_encoder->setBytes(&b_stride, sizeof(uint32_t), 7);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
@@ -392,16 +395,16 @@ void LayerNormVJP::eval_gpu(
} }
uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(x_in_gx ? gx : x, 0); compute_encoder.set_input_array(x_in_gx ? gx : x, 0);
compute_encoder.set_input_array(w, 1); compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2); compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
compute_encoder.set_output_array(gx, 3); compute_encoder.set_output_array(gx, 3);
compute_encoder.set_output_array(gw_temp, 4); compute_encoder.set_output_array(gw_temp, 4);
compute_encoder.set_bytes(eps_, 5); compute_encoder->setBytes(&eps_, sizeof(float), 5);
compute_encoder.set_bytes(axis_size, 6); compute_encoder->setBytes(&axis_size, sizeof(int), 6);
compute_encoder.set_bytes(w_stride, 7); compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
if (gw.ndim() == 1 && gw.size() == axis_size) { if (gw.ndim() == 1 && gw.size() == axis_size) {

View File

@@ -5,7 +5,6 @@
#include <sstream> #include <sstream>
#include "mlx/backend/common/load.h" #include "mlx/backend/common/load.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels.h"
@@ -18,10 +17,10 @@
namespace mlx::core { namespace mlx::core {
template <typename T> template <typename T>
void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) { void arange_set_scalars(T start, T next, CommandEncoder& enc) {
enc.set_bytes(start, 0); enc->setBytes(&start, sizeof(T), 0);
T step = next - start; T step = next - start;
enc.set_bytes(step, 1); enc->setBytes(&step, sizeof(T), 1);
} }
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) { void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -38,7 +37,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size group_dims = MTL::Size( MTL::Size group_dims = MTL::Size(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1); std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
switch (out.dtype()) { switch (out.dtype()) {
case bool_: // unsupported case bool_: // unsupported
@@ -81,7 +80,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
compute_encoder.set_output_array(out, 2); compute_encoder.set_output_array(out, 2);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) { void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -130,25 +129,25 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
size_t n_threads = out.size() * thread_group_size; size_t n_threads = out.size() * thread_group_size;
MTL::Size grid_dims = MTL::Size(n_threads, 1, 1); MTL::Size grid_dims = MTL::Size(n_threads, 1, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
if (ndim == 0) { if (ndim == 0) {
// Pass place holders so metal doesn't complain // Pass place holders so metal doesn't complain
int shape_ = 0; int shape_ = 0;
size_t stride_ = 0; size_t stride_ = 0;
compute_encoder.set_bytes(shape_, 2); compute_encoder->setBytes(&shape_, sizeof(int), 2);
compute_encoder.set_bytes(stride_, 3); compute_encoder->setBytes(&stride_, sizeof(size_t), 3);
compute_encoder.set_bytes(stride_, 4); compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
} else { } else {
compute_encoder.set_vector_bytes(shape, 2); compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
compute_encoder.set_vector_bytes(in_strides, 3); compute_encoder->setBytes(in_strides.data(), ndim * sizeof(size_t), 3);
compute_encoder.set_vector_bytes(out_strides, 4); compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4);
} }
compute_encoder.set_bytes(ndim, 5); compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
compute_encoder.set_bytes(axis_stride, 6); compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6);
compute_encoder.set_bytes(axis_size, 7); compute_encoder->setBytes(&axis_size, sizeof(size_t), 7);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
} }
@@ -170,17 +169,6 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
concatenate_gpu(inputs, out, axis_, stream()); concatenate_gpu(inputs, out, axis_, stream());
} }
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous)) {
move_or_copy(in, out);
} else {
copy_gpu(in, out, CopyType::General);
}
}
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) { void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out); eval(inputs, out);
} }
@@ -285,22 +273,24 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
// organize into grid nkeys x elem_per_key // organize into grid nkeys x elem_per_key
MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1); MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
auto group_dims = get_block_dims(num_keys, half_size + odd, 1); MTL::Size group_dims = MTL::Size(1, thread_group_size, 1);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(keys, 0); compute_encoder.set_input_array(keys, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(odd, 2); compute_encoder->setBytes(&odd, sizeof(bool), 2);
compute_encoder.set_bytes(bytes_per_key, 3); compute_encoder->setBytes(&bytes_per_key, sizeof(size_t), 3);
if (!keys.flags().row_contiguous) { if (!keys.flags().row_contiguous) {
int ndim = keys.ndim(); int ndim = keys.ndim();
compute_encoder.set_bytes(ndim, 4); compute_encoder->setBytes(&ndim, sizeof(int), 4);
compute_encoder.set_vector_bytes(keys.shape(), 5); compute_encoder->setBytes(keys.shape().data(), keys.ndim() * sizeof(int), 5);
compute_encoder.set_vector_bytes(keys.strides(), 6); compute_encoder->setBytes(keys.strides().data(), keys.ndim() * sizeof(size_t), 6);
} }
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) { void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -355,7 +345,7 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& upd = inputs[1]; auto& upd = inputs[1];
if (upd.size() == 0) { if (upd.size() == 0) {
move_or_copy(in, out); out.copy_shared_buffer(in);
return; return;
} }
@@ -428,12 +418,12 @@ void View::eval_gpu(const std::vector<array>& inputs, array& out) {
if (ibytes == obytes || obytes < ibytes && in.strides().back() == 1 || if (ibytes == obytes || obytes < ibytes && in.strides().back() == 1 ||
in.flags().row_contiguous) { in.flags().row_contiguous) {
auto strides = in.strides(); auto strides = in.strides();
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) { for (int i = 0; i < strides.size() - 1; ++i) {
strides[i] *= ibytes; strides[i] *= ibytes;
strides[i] /= obytes; strides[i] /= obytes;
} }
move_or_copy( out.copy_shared_buffer(
in, out, strides, in.flags(), in.data_size() * ibytes / obytes); in, strides, in.flags(), in.data_size() * ibytes / obytes);
} else { } else {
auto tmp = array(in.shape(), in.dtype(), nullptr, {}); auto tmp = array(in.shape(), in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes())); tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
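For the View change above, the stride arithmetic is the interesting part: every stride except the innermost is rescaled by the ratio of old to new item sizes before the buffer is shared. A small worked example, assuming a float32 array reinterpreted as uint8; the concrete numbers are illustrative:

#include <cstdio>
#include <vector>

int main() {
  int ibytes = 4, obytes = 1;             // e.g. viewing float32 data as uint8
  std::vector<size_t> strides{16, 4, 1};  // element strides of a 4x4x4 input
  for (size_t i = 0; i + 1 < strides.size(); ++i) {
    strides[i] = strides[i] * ibytes / obytes;  // 16 -> 64, 4 -> 16
  }
  // The innermost stride stays 1; the view simply has ibytes/obytes times as
  // many inner elements.
  for (auto s : strides) std::printf("%zu ", s);
  std::printf("\n");
}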

View File

@@ -10,7 +10,6 @@
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h" #include "mlx/fast_primitives.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core { namespace mlx::core {
@@ -102,31 +101,31 @@ void launch_qmm(
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
auto kernel = get_quantized_kernel(d, kname.str(), template_def); auto kernel = get_quantized_kernel(d, kname.str(), template_def);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(w, 0); compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1); compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2); compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3); compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(out, 4); compute_encoder.set_output_array(out, 4);
compute_encoder.set_bytes(D, 5); compute_encoder->setBytes(&D, sizeof(int), 5);
compute_encoder.set_bytes(O, 6); compute_encoder->setBytes(&O, sizeof(int), 6);
int offset = 7; int offset = 7;
if (matrix) { if (matrix) {
compute_encoder.set_bytes(B, 7); compute_encoder->setBytes(&B, sizeof(int), 7);
offset += 1; offset += 1;
} }
if (batched || gather) { if (batched || gather) {
compute_encoder.set_bytes(x_batch_ndims, offset); compute_encoder->setBytes(&x_batch_ndims, sizeof(int), offset);
compute_encoder.set_vector_bytes(x_shape, offset + 1); set_vector_bytes(compute_encoder, x_shape, offset + 1);
compute_encoder.set_vector_bytes(x_strides, offset + 2); set_vector_bytes(compute_encoder, x_strides, offset + 2);
compute_encoder.set_bytes(w_batch_ndims, offset + 3); compute_encoder->setBytes(&w_batch_ndims, sizeof(int), offset + 3);
compute_encoder.set_vector_bytes(w_shape, offset + 4); set_vector_bytes(compute_encoder, w_shape, offset + 4);
compute_encoder.set_vector_bytes(w_strides, offset + 5); set_vector_bytes(compute_encoder, w_strides, offset + 5);
compute_encoder.set_vector_bytes(s_strides, offset + 6); set_vector_bytes(compute_encoder, s_strides, offset + 6);
compute_encoder.set_vector_bytes(b_strides, offset + 7); set_vector_bytes(compute_encoder, b_strides, offset + 7);
} }
if (gather) { if (gather) {
auto& lhs_indices = inputs[4]; auto& lhs_indices = inputs[4];
@@ -138,15 +137,15 @@ void launch_qmm(
auto& lhs_strides = lhs_indices.strides(); auto& lhs_strides = lhs_indices.strides();
auto& rhs_strides = rhs_indices.strides(); auto& rhs_strides = rhs_indices.strides();
compute_encoder.set_bytes(batch_ndims, offset + 8); compute_encoder->setBytes(&batch_ndims, sizeof(int), offset + 8);
compute_encoder.set_vector_bytes(batch_shape, offset + 9); set_vector_bytes(compute_encoder, batch_shape, offset + 9);
compute_encoder.set_input_array(lhs_indices, offset + 10); compute_encoder.set_input_array(lhs_indices, offset + 10);
compute_encoder.set_input_array(rhs_indices, offset + 11); compute_encoder.set_input_array(rhs_indices, offset + 11);
compute_encoder.set_vector_bytes(lhs_strides, offset + 12); set_vector_bytes(compute_encoder, lhs_strides, offset + 12);
compute_encoder.set_vector_bytes(rhs_strides, offset + 13); set_vector_bytes(compute_encoder, rhs_strides, offset + 13);
} }
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
} }
@@ -237,27 +236,27 @@ void qvm_split_k(
// Encode and dispatch kernel // Encode and dispatch kernel
auto kernel = get_quantized_kernel(d, kname.str(), template_def); auto kernel = get_quantized_kernel(d, kname.str(), template_def);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(w, 0); compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1); compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2); compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3); compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(intermediate, 4); compute_encoder.set_output_array(intermediate, 4);
compute_encoder.set_bytes(split_D, 5); compute_encoder->setBytes(&split_D, sizeof(int), 5);
compute_encoder.set_bytes(O, 6); compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder.set_bytes(x_batch_ndims, 7); compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 7);
compute_encoder.set_vector_bytes(x_shape, 8); set_vector_bytes(compute_encoder, x_shape, 8);
compute_encoder.set_vector_bytes(x_strides, 9); set_vector_bytes(compute_encoder, x_strides, 9);
compute_encoder.set_bytes(w_batch_ndims, 10); compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 10);
compute_encoder.set_vector_bytes(w_shape, 11); set_vector_bytes(compute_encoder, w_shape, 11);
compute_encoder.set_vector_bytes(w_strides, 12); set_vector_bytes(compute_encoder, w_strides, 12);
compute_encoder.set_vector_bytes(s_strides, 13); set_vector_bytes(compute_encoder, s_strides, 13);
compute_encoder.set_vector_bytes(b_strides, 14); set_vector_bytes(compute_encoder, b_strides, 14);
compute_encoder.set_bytes(final_block_size, 15); compute_encoder->setBytes(&final_block_size, sizeof(int), 15);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
int axis = intermediate.ndim() - 3; int axis = intermediate.ndim() - 3;
@@ -299,7 +298,7 @@ void qmm_op(
bool quad = false; bool quad = false;
if (transpose) { if (transpose) {
if (B < 6 && (D == 128 || D == 64) && is_power_of_2(bits)) { if (B < 6 && (D == 128 || D == 64)) {
name += "qmv_quad"; name += "qmv_quad";
constexpr int quads_per_simd = 8; constexpr int quads_per_simd = 8;
constexpr int results_per_quadgroup = 8; constexpr int results_per_quadgroup = 8;
@@ -392,6 +391,8 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
void fast::AffineQuantize::eval_gpu( void fast::AffineQuantize::eval_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs) { std::vector<array>& outputs) {
bool compute_scale_bias = inputs.size() == 1;
auto& w_pre = inputs[0]; auto& w_pre = inputs[0];
auto& out = outputs[0]; auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
@@ -414,7 +415,7 @@ void fast::AffineQuantize::eval_gpu(
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_input_array(w, 0); compute_encoder.set_input_array(w, 0);
if (dequantize_) { if (!compute_scale_bias) {
auto& scales_pre = inputs[1]; auto& scales_pre = inputs[1];
auto& biases_pre = inputs[2]; auto& biases_pre = inputs[2];
auto scales = ensure_row_contiguous(scales_pre); auto scales = ensure_row_contiguous(scales_pre);
@@ -435,21 +436,26 @@ void fast::AffineQuantize::eval_gpu(
std::ostringstream kname; std::ostringstream kname;
auto type_string = dequantize_ ? get_type_string(out.dtype()) auto type_string = dequantize_ ? get_type_string(out.dtype())
: get_type_string(w_pre.dtype()); : get_type_string(w_pre.dtype());
auto kernel_func = dequantize_ ? "affine_dequantize" : "affine_quantize"; auto kernel_func = "affine_quantize_scales_biases";
if (dequantize_) {
kernel_func = "affine_dequantize";
} else if (compute_scale_bias) {
kernel_func = "affine_quantize";
}
kname << kernel_func << "_" << type_string << "_gs_" << group_size_ << "_b_" kname << kernel_func << "_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_; << bits_;
auto template_def = get_template_definition( auto template_def = get_template_definition(
kname.str(), kernel_func, type_string, group_size_, bits_); kname.str(), kernel_func, type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def); auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
// Treat uint32 as uint8 in kernel // Treat uint32 as uint8 in kernel
constexpr int uint8_per_uint32 = 4; constexpr int uint8_per_uint32 = 4;
constexpr int simd_size = 32; constexpr int simd_size = 32;
int packs_per_int = bits_ == 3 ? 8 : bits_ == 6 ? 4 : 8 / bits_; int packs_per_int = 8 / bits_;
int per_thread = dequantize_ ? packs_per_int : group_size_ / simd_size; int per_thread = compute_scale_bias ? group_size_ / simd_size : packs_per_int;
size_t nthreads = size_t nthreads =
dequantize_ ? out.size() / packs_per_int : w.size() / per_thread; dequantize_ ? w.size() * uint8_per_uint32 : w.size() / per_thread;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) { if (thread_group_size > nthreads) {
@@ -465,7 +471,7 @@ void fast::AffineQuantize::eval_gpu(
} }
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(grid_shape, w.strides()) MTL::Size grid_dims = use_2d ? get_2d_grid_dims(grid_shape, w.strides())
: MTL::Size(nthreads, 1, 1); : MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
} }
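One side of the AffineQuantize hunk generalizes packs_per_int beyond power-of-two widths (bits == 3 and bits == 6 are special-cased), while the other assumes 8 / bits. A plausible reading of those constants, shown as a small sketch; the byte-count interpretation in the comments is this sketch's assumption, not taken from the source:

#include <cstdio>

// Values per packed group for a given bit width, following the expression in
// the hunk above.
int packs_per_int(int bits) {
  if (bits == 3) return 8;  // assumed: 8 x 3 bits = 24 bits = 3 bytes
  if (bits == 6) return 4;  // assumed: 4 x 6 bits = 24 bits = 3 bytes
  return 8 / bits;          // 2-, 4-, and 8-bit values divide a byte evenly
}

int main() {
  for (int bits : {2, 3, 4, 6, 8}) {
    std::printf("bits=%d -> %d values per pack\n", bits, packs_per_int(bits));
  }
}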

View File

@@ -2,6 +2,7 @@
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
#include <sstream>
#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
@@ -66,14 +67,17 @@ struct RowReduceArgs {
strides.push_back(0); strides.push_back(0);
} }
compute_encoder.set_bytes(row_size, 2); compute_encoder->setBytes(&row_size, sizeof(size_t), 2);
compute_encoder.set_bytes(non_row_reductions, 3); compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 3);
compute_encoder.set_vector_bytes(shape, 4); compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
compute_encoder.set_vector_bytes(strides, 5); compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 5);
compute_encoder.set_bytes(ndim, 6); compute_encoder->setBytes(&ndim, sizeof(int), 6);
compute_encoder.set_vector_bytes(reduce_shape, 7); compute_encoder->setBytes(reduce_shape.data(), reduce_shape.size() * sizeof(int), 7);
compute_encoder.set_vector_bytes(reduce_strides, 8); compute_encoder->setBytes(reduce_strides.data(), reduce_strides.size() * sizeof(size_t), 8);
compute_encoder.set_bytes(reduce_ndim, 9); compute_encoder->setBytes(&reduce_ndim, sizeof(int), 9);
if (reduce_ndim == 0) { if (reduce_ndim == 0) {
reduce_shape.pop_back(); reduce_shape.pop_back();
@@ -162,15 +166,18 @@ struct ColReduceArgs {
strides.push_back(0); strides.push_back(0);
} }
compute_encoder.set_bytes(reduction_size, 2); compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder.set_bytes(reduction_stride, 3); compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
compute_encoder.set_vector_bytes(shape, 4); compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
compute_encoder.set_vector_bytes(strides, 5); compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 5);
compute_encoder.set_bytes(ndim, 6); compute_encoder->setBytes(&ndim, sizeof(int), 6);
compute_encoder.set_vector_bytes(reduce_shape, 7); compute_encoder->setBytes(reduce_shape.data(), reduce_shape.size() * sizeof(int), 7);
compute_encoder.set_vector_bytes(reduce_strides, 8); compute_encoder->setBytes(reduce_strides.data(), reduce_strides.size() * sizeof(size_t), 8);
compute_encoder.set_bytes(reduce_ndim, 9); compute_encoder->setBytes(&reduce_ndim, sizeof(int), 9);
compute_encoder.set_bytes(non_col_reductions, 10); compute_encoder->setBytes(&non_col_reductions, sizeof(size_t), 10);
if (reduce_ndim == 0) { if (reduce_ndim == 0) {
reduce_shape.pop_back(); reduce_shape.pop_back();
@@ -201,16 +208,6 @@ inline bool is_64b_dtype(Dtype dtype) {
return dtype == int64 || dtype == uint64 || dtype == complex64; return dtype == int64 || dtype == uint64 || dtype == complex64;
} }
inline int get_kernel_reduce_ndim(int reduce_ndim) {
if (reduce_ndim <= 1) {
return 1;
} else if (reduce_ndim == 2) {
return 2;
} else {
return 5;
}
}
inline int threadgroup_size_from_row_size(int row_size) { inline int threadgroup_size_from_row_size(int row_size) {
// 1 simdgroup per row smallish rows // 1 simdgroup per row smallish rows
if (row_size <= 512) { if (row_size <= 512) {
@@ -242,51 +239,16 @@ inline auto output_grid_for_col_reduce(
return get_2d_grid_dims(out_shape, out_strides); return get_2d_grid_dims(out_shape, out_strides);
} }
std::pair<Dtype, Dtype> remap_reduce_types(
const array& in,
const std::string& op_name) {
if (op_name == "sum" || op_name == "prod") {
if (issubdtype(in.dtype(), integer)) {
switch (in.dtype().size()) {
case 1:
return {int8, int32};
case 2:
return {int16, int32};
case 4:
return {int32, int32};
case 8:
return {int64, int64};
}
}
if (in.dtype() == bool_) {
return {int8, int32};
}
return {in.dtype(), in.dtype()};
} else if (op_name == "and" || op_name == "or") {
if (in.dtype().size() == 1) {
return {bool_, bool_};
} else if (in.dtype().size() == 2) {
return {int16, bool_};
} else if (in.dtype().size() == 4) {
return {int32, bool_};
} else {
return {int64, bool_};
}
}
return {in.dtype(), in.dtype()};
}
void init_reduce( void init_reduce(
array& out, array& out,
const std::string& op_name, const std::string& op_name,
CommandEncoder& compute_encoder, CommandEncoder& compute_encoder,
metal::Device& d, metal::Device& d,
const Stream& s) { const Stream& s) {
auto [_, out_type] = remap_reduce_types(out, op_name); std::ostringstream kname;
const std::string func_name = "init_reduce"; const std::string func_name = "init_reduce";
std::string kname = func_name; kname << func_name << "_" << op_name << type_to_name(out);
concatenate(kname, "_", op_name, type_to_name(out_type)); auto kernel = get_reduce_init_kernel(d, kname.str(), func_name, op_name, out);
auto kernel = get_reduce_init_kernel(d, kname, func_name, op_name, out_type);
size_t nthreads = out.size(); size_t nthreads = out.size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
@@ -294,9 +256,9 @@ void init_reduce(
thread_group_size = nthreads; thread_group_size = nthreads;
} }
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_output_array(out, 0); compute_encoder.set_output_array(out, 0);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
void all_reduce_dispatch( void all_reduce_dispatch(
@@ -307,13 +269,11 @@ void all_reduce_dispatch(
metal::Device& d, metal::Device& d,
const Stream& s) { const Stream& s) {
// Set the kernel // Set the kernel
auto [in_type, out_type] = remap_reduce_types(in, op_name); std::ostringstream kname;
const std::string func_name = "all_reduce"; const std::string func_name = "all_reduce";
std::string kname = func_name; kname << func_name << "_" << op_name << type_to_name(in);
concatenate(kname, "_", op_name, type_to_name(in_type)); auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out);
auto kernel = get_reduce_kernel( compute_encoder->setComputePipelineState(kernel);
d, kname, func_name, op_name, in_type, out_type, "int64_t");
compute_encoder.set_compute_pipeline_state(kernel);
size_t in_size = in.size(); size_t in_size = in.size();
@@ -325,9 +285,9 @@ void all_reduce_dispatch(
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(in_size, 2); compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
compute_encoder.set_bytes(in_size, 3); compute_encoder->setBytes(&in_size, sizeof(size_t), 3);
compute_encoder.dispatch_threads(grid_dims, grid_dims); compute_encoder.dispatchThreads(grid_dims, grid_dims);
} }
// We need multiple threadgroups so we 'll do it in 2 passes. // We need multiple threadgroups so we 'll do it in 2 passes.
@@ -346,7 +306,7 @@ void all_reduce_dispatch(
} }
// Allocate an intermediate tensor to hold results if needed // Allocate an intermediate tensor to hold results if needed
array intermediate({n_rows}, out_type, nullptr, {}); array intermediate({n_rows}, out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes())); intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
d.add_temporary(intermediate, s.index); d.add_temporary(intermediate, s.index);
@@ -359,24 +319,24 @@ void all_reduce_dispatch(
MTL::Size group_dims(threadgroup_size, 1, 1); MTL::Size group_dims(threadgroup_size, 1, 1);
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(intermediate, 1); compute_encoder.set_output_array(intermediate, 1);
compute_encoder.set_bytes(in_size, 2); compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
compute_encoder.set_bytes(row_size, 3); compute_encoder->setBytes(&row_size, sizeof(size_t), 3);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
// 2nd pass // 2nd pass
std::string kname_2nd_pass = func_name; std::ostringstream kname_2nd_pass;
concatenate(kname_2nd_pass, "_", op_name, type_to_name(intermediate)); kname_2nd_pass << func_name << "_" << op_name << type_to_name(intermediate);
auto kernel_2nd_pass = get_reduce_kernel( auto kernel_2nd_pass = get_reduce_kernel(
d, kname_2nd_pass, func_name, op_name, out_type, out_type, "int64_t"); d, kname_2nd_pass.str(), func_name, op_name, intermediate, out);
compute_encoder.set_compute_pipeline_state(kernel_2nd_pass); compute_encoder->setComputePipelineState(kernel_2nd_pass);
size_t intermediate_size = n_rows; size_t intermediate_size = n_rows;
grid_dims = MTL::Size(threadgroup_2nd_pass, 1, 1); grid_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);
group_dims = MTL::Size(threadgroup_2nd_pass, 1, 1); group_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);
compute_encoder.set_input_array(intermediate, 0); compute_encoder.set_input_array(intermediate, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(intermediate_size, 2); compute_encoder->setBytes(&intermediate_size, sizeof(size_t), 2);
compute_encoder.set_bytes(intermediate_size, 3); compute_encoder->setBytes(&intermediate_size, sizeof(size_t), 3);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
} }
@@ -389,31 +349,13 @@ void row_reduce_small(
metal::Device& d, metal::Device& d,
const Stream& s) { const Stream& s) {
// Set the kernel // Set the kernel
int n = get_kernel_reduce_ndim(args.reduce_ndim); std::ostringstream kname;
auto [in_type, out_type] = remap_reduce_types(in, op_name); int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
const std::string func_name = "row_reduce_small"; const std::string func_name = "row_reduce_small";
std::string kname = func_name; kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
bool large = in.size() > UINT32_MAX; auto kernel =
if (large) { get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
kname += "_large"; compute_encoder->setComputePipelineState(kernel);
}
concatenate(
kname,
"_",
std::to_string(n),
"_reduce_",
op_name,
type_to_name(in_type));
auto kernel = get_reduce_kernel(
d,
kname,
func_name,
op_name,
in_type,
out_type,
large ? "size_t" : "uint",
n);
compute_encoder.set_compute_pipeline_state(kernel);
// Figure out the grid dims // Figure out the grid dims
MTL::Size grid_dims; MTL::Size grid_dims;
@@ -433,7 +375,7 @@ void row_reduce_small(
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
args.encode(compute_encoder); args.encode(compute_encoder);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
void row_reduce_simple( void row_reduce_simple(
@@ -445,14 +387,11 @@ void row_reduce_simple(
metal::Device& d, metal::Device& d,
const Stream& s) { const Stream& s) {
// Set the kernel // Set the kernel
auto [in_type, out_type] = remap_reduce_types(in, op_name); std::ostringstream kname;
const std::string func_name = "row_reduce_simple"; const std::string func_name = "row_reduce_simple";
std::string kname = func_name; kname << func_name << "_" << op_name << type_to_name(in);
concatenate(kname, "_", op_name, type_to_name(in_type)); auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out);
compute_encoder->setComputePipelineState(kernel);
auto kernel = get_reduce_kernel(
d, kname, func_name, op_name, in_type, out_type, "size_t");
compute_encoder.set_compute_pipeline_state(kernel);
// Figure out the grid dims // Figure out the grid dims
size_t row_size = args.row_size; size_t row_size = args.row_size;
@@ -471,9 +410,9 @@ void row_reduce_simple(
// Launch // Launch
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(row_size, 2); compute_encoder->setBytes(&row_size, sizeof(size_t), 2);
compute_encoder.set_bytes(out_size, 3); compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
void row_reduce_looped( void row_reduce_looped(
@@ -484,33 +423,14 @@ void row_reduce_looped(
CommandEncoder& compute_encoder, CommandEncoder& compute_encoder,
metal::Device& d, metal::Device& d,
const Stream& s) { const Stream& s) {
auto [in_type, out_type] = remap_reduce_types(in, op_name);
// Set the kernel // Set the kernel
int n = get_kernel_reduce_ndim(args.reduce_ndim); std::ostringstream kname;
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
const std::string func_name = "row_reduce_looped"; const std::string func_name = "row_reduce_looped";
std::string kname = func_name; kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
bool large = in.size() > UINT32_MAX; auto kernel =
if (large) { get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
kname += "_large"; compute_encoder->setComputePipelineState(kernel);
}
concatenate(
kname,
"_",
std::to_string(n),
"_reduce_",
op_name,
type_to_name(in_type));
auto kernel = get_reduce_kernel(
d,
kname,
func_name,
op_name,
in_type,
out_type,
large ? "size_t" : "uint",
n);
compute_encoder.set_compute_pipeline_state(kernel);
// Figure out the grid // Figure out the grid
auto out_grid_size = get_2d_grid_dims(out.shape(), out.strides()); auto out_grid_size = get_2d_grid_dims(out.shape(), out.strides());
@@ -523,7 +443,7 @@ void row_reduce_looped(
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
args.encode(compute_encoder); args.encode(compute_encoder);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
void row_reduce_general_dispatch( void row_reduce_general_dispatch(
@@ -561,8 +481,6 @@ void strided_reduce_small(
CommandEncoder& compute_encoder, CommandEncoder& compute_encoder,
metal::Device& d, metal::Device& d,
const Stream& s) { const Stream& s) {
auto [in_type, out_type] = remap_reduce_types(in, op_name);
// Figure out the grid dims // Figure out the grid dims
MTL::Size grid_dims, group_dims; MTL::Size grid_dims, group_dims;
@@ -571,30 +489,13 @@ void strided_reduce_small(
args.reduce_strides.push_back(args.reduction_stride); args.reduce_strides.push_back(args.reduction_stride);
args.reduce_ndim++; args.reduce_ndim++;
int n = get_kernel_reduce_ndim(args.reduce_ndim); int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
std::ostringstream kname;
const std::string func_name = "col_reduce_small"; const std::string func_name = "col_reduce_small";
std::string kname = func_name; kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
bool large = in.size() > UINT32_MAX; auto kernel =
if (large) { get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
kname += "_large"; compute_encoder->setComputePipelineState(kernel);
}
concatenate(
kname,
"_",
std::to_string(n),
"_reduce_",
op_name,
type_to_name(in_type));
auto kernel = get_reduce_kernel(
d,
kname,
func_name,
op_name,
in_type,
out_type,
large ? "size_t" : "uint",
n);
compute_encoder.set_compute_pipeline_state(kernel);
const int n_reads = 4; const int n_reads = 4;
size_t reduction_stride_blocks = size_t reduction_stride_blocks =
@@ -616,7 +517,7 @@ void strided_reduce_small(
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
args.encode(compute_encoder); args.encode(compute_encoder);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
} }
void strided_reduce_longcolumn( void strided_reduce_longcolumn(
@@ -627,7 +528,6 @@ void strided_reduce_longcolumn(
CommandEncoder& compute_encoder, CommandEncoder& compute_encoder,
metal::Device& d, metal::Device& d,
const Stream& s) { const Stream& s) {
auto [in_type, out_type] = remap_reduce_types(in, op_name);
size_t total_reduction_size = args.reduction_size * args.non_col_reductions; size_t total_reduction_size = args.reduction_size * args.non_col_reductions;
size_t outer_blocks = 32; size_t outer_blocks = 32;
if (total_reduction_size >= 32768) { if (total_reduction_size >= 32768) {
@@ -640,7 +540,7 @@ void strided_reduce_longcolumn(
intermediate_shape.push_back(outer_blocks); intermediate_shape.push_back(outer_blocks);
intermediate_shape.insert( intermediate_shape.insert(
intermediate_shape.end(), out.shape().begin(), out.shape().end()); intermediate_shape.end(), out.shape().begin(), out.shape().end());
array intermediate(std::move(intermediate_shape), out_type, nullptr, {}); array intermediate(std::move(intermediate_shape), out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes())); intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
d.add_temporary(intermediate, s.index); d.add_temporary(intermediate, s.index);
@@ -662,37 +562,20 @@ void strided_reduce_longcolumn(
MTL::Size group_dims(threadgroup_x, threadgroup_y, 1); MTL::Size group_dims(threadgroup_x, threadgroup_y, 1);
// Set the kernel // Set the kernel
int n = get_kernel_reduce_ndim(args.reduce_ndim); int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
std::string func_name = "col_reduce_longcolumn"; std::ostringstream kname;
std::string kname = func_name; const std::string func_name = "col_reduce_longcolumn";
bool large = in.size() > UINT32_MAX; kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
if (large) { auto kernel =
kname += "_large"; get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
} compute_encoder->setComputePipelineState(kernel);
concatenate(
kname,
"_",
std::to_string(n),
"_reduce_",
op_name,
type_to_name(in_type));
auto kernel = get_reduce_kernel(
d,
kname,
func_name,
op_name,
in_type,
out_type,
large ? "size_t" : "uint",
n);
compute_encoder.set_compute_pipeline_state(kernel);
// Launch // Launch
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(intermediate, 1); compute_encoder.set_output_array(intermediate, 1);
args.encode(compute_encoder); args.encode(compute_encoder);
compute_encoder.set_bytes(out_size, 11); compute_encoder->setBytes(&out_size, sizeof(size_t), 11);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Make the 2nd pass arguments and grid_dims // Make the 2nd pass arguments and grid_dims
ColReduceArgs second_args(intermediate); ColReduceArgs second_args(intermediate);
@@ -704,30 +587,24 @@ void strided_reduce_longcolumn(
group_dims = MTL::Size(256, 1, 1); group_dims = MTL::Size(256, 1, 1);
// Set the 2nd kernel // Set the 2nd kernel
func_name = "col_reduce_looped"; const std::string second_kernel = "col_reduce_looped_1_32_32_reduce_" +
kname = func_name; op_name + type_to_name(intermediate);
large = intermediate.size() > UINT32_MAX;
if (large) {
kname += "_large";
}
concatenate(kname, "_1_32_32_reduce_", op_name, type_to_name(intermediate));
kernel = get_reduce_kernel( kernel = get_reduce_kernel(
d, d,
kname, second_kernel,
func_name, "col_reduce_looped",
op_name, op_name,
intermediate.dtype(), intermediate,
out_type, out,
large ? "size_t" : "uint",
1, 1,
32, 32,
32); 32);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(intermediate, 0); compute_encoder.set_input_array(intermediate, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
second_args.encode(compute_encoder); second_args.encode(compute_encoder);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
void strided_reduce_looped( void strided_reduce_looped(
@@ -738,8 +615,6 @@ void strided_reduce_looped(
CommandEncoder& compute_encoder, CommandEncoder& compute_encoder,
metal::Device& d, metal::Device& d,
const Stream& s) { const Stream& s) {
auto [in_type, out_type] = remap_reduce_types(in, op_name);
// Prepare the arguments for the kernel // Prepare the arguments for the kernel
args.reduce_shape.push_back(args.reduction_size); args.reduce_shape.push_back(args.reduction_size);
args.reduce_strides.push_back(args.reduction_stride); args.reduce_strides.push_back(args.reduction_stride);
@@ -757,42 +632,20 @@ void strided_reduce_looped(
MTL::Size group_dims(threadgroup_size, 1, 1); MTL::Size group_dims(threadgroup_size, 1, 1);
// Set the kernel // Set the kernel
int n = get_kernel_reduce_ndim(args.reduce_ndim); int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
std::string func_name = "col_reduce_looped"; std::ostringstream kname;
std::string kname = func_name; const std::string func_name = "col_reduce_looped";
bool large = in.size() > UINT32_MAX; kname << func_name << "_" << n << "_" << BM << "_" << BN << "_reduce_"
if (large) { << op_name << type_to_name(in);
kname += "_large"; auto kernel =
} get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN);
concatenate( compute_encoder->setComputePipelineState(kernel);
kname,
"_",
std::to_string(n),
"_",
std::to_string(BM),
"_",
std::to_string(BN),
"_reduce_",
op_name,
type_to_name(in_type));
auto kernel = get_reduce_kernel(
d,
kname,
func_name,
op_name,
in_type,
out_type,
large ? "size_t" : "uint",
n,
BM,
BN);
compute_encoder.set_compute_pipeline_state(kernel);
// Launch // Launch
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
args.encode(compute_encoder); args.encode(compute_encoder);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
void strided_reduce_2pass( void strided_reduce_2pass(
@@ -803,15 +656,13 @@ void strided_reduce_2pass(
CommandEncoder& compute_encoder, CommandEncoder& compute_encoder,
metal::Device& d, metal::Device& d,
const Stream& s) { const Stream& s) {
auto [in_type, out_type] = remap_reduce_types(in, op_name);
// Prepare the temporary accumulator // Prepare the temporary accumulator
std::vector<int> intermediate_shape; std::vector<int> intermediate_shape;
intermediate_shape.reserve(out.ndim() + 1); intermediate_shape.reserve(out.ndim() + 1);
intermediate_shape.push_back(32); intermediate_shape.push_back(32);
intermediate_shape.insert( intermediate_shape.insert(
intermediate_shape.end(), out.shape().begin(), out.shape().end()); intermediate_shape.end(), out.shape().begin(), out.shape().end());
array intermediate(std::move(intermediate_shape), out_type, nullptr, {}); array intermediate(std::move(intermediate_shape), out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes())); intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
d.add_temporary(intermediate, s.index); d.add_temporary(intermediate, s.index);
@@ -834,43 +685,21 @@ void strided_reduce_2pass(
MTL::Size group_dims(threadgroup_size, 1, 1); MTL::Size group_dims(threadgroup_size, 1, 1);
// Set the kernel // Set the kernel
int n = get_kernel_reduce_ndim(args.reduce_ndim); int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
std::string func_name = "col_reduce_2pass"; std::ostringstream kname;
std::string kname = func_name; const std::string func_name = "col_reduce_2pass";
bool large = in.size() > UINT32_MAX; kname << func_name << "_" << n << "_" << BM << "_" << BN << "_reduce_"
if (large) { << op_name << type_to_name(in);
kname += "_large"; auto kernel =
} get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN);
concatenate( compute_encoder->setComputePipelineState(kernel);
kname,
"_",
std::to_string(n),
"_",
std::to_string(BM),
"_",
std::to_string(BN),
"_reduce_",
op_name,
type_to_name(in_type));
auto kernel = get_reduce_kernel(
d,
kname,
func_name,
op_name,
in_type,
out_type,
large ? "size_t" : "uint",
n,
BM,
BN);
compute_encoder.set_compute_pipeline_state(kernel);
// Launch // Launch
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(intermediate, 1); compute_encoder.set_output_array(intermediate, 1);
args.encode(compute_encoder); args.encode(compute_encoder);
compute_encoder.set_bytes(out_size, 11); compute_encoder->setBytes(&out_size, sizeof(size_t), 11);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
// Make the 2nd pass arguments and grid_dims // Make the 2nd pass arguments and grid_dims
ColReduceArgs second_args(intermediate); ColReduceArgs second_args(intermediate);
@@ -880,30 +709,24 @@ void strided_reduce_2pass(
grid_dims = MTL::Size(threadgroup_size * ((out.size() + BN - 1) / BN), 1, 1); grid_dims = MTL::Size(threadgroup_size * ((out.size() + BN - 1) / BN), 1, 1);
// Set the 2nd kernel // Set the 2nd kernel
func_name = "col_reduce_looped"; const std::string second_kernel = "col_reduce_looped_1_32_32_reduce_" +
kname = func_name; op_name + type_to_name(intermediate);
large = intermediate.size() > UINT32_MAX;
if (large) {
kname += "_large";
}
concatenate(kname, "_1_32_32_reduce_", op_name, type_to_name(intermediate));
kernel = get_reduce_kernel( kernel = get_reduce_kernel(
d, d,
kname, second_kernel,
func_name, "col_reduce_looped",
op_name, op_name,
intermediate.dtype(), intermediate,
out_type, out,
large ? "size_t" : "uint",
1, 1,
32, 32,
32); 32);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(intermediate, 0); compute_encoder.set_input_array(intermediate, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
second_args.encode(compute_encoder); second_args.encode(compute_encoder);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
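
To make the two-pass structure above concrete, here is a small CPU analogue, purely illustrative and not the Metal kernels: the first pass folds disjoint slices of the reduction axis into a 32-entry intermediate, and the second pass reduces that intermediate to the final value.

#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

// CPU analogue of a two-pass sum over one column (illustrative only).
float two_pass_sum(const std::vector<float>& column, std::size_t blocks = 32) {
  std::vector<float> partial(blocks, 0.0f); // plays the role of the intermediate array
  const std::size_t chunk = (column.size() + blocks - 1) / blocks;
  for (std::size_t b = 0; b < blocks; ++b) {
    const std::size_t begin = b * chunk;
    const std::size_t end = std::min(column.size(), begin + chunk);
    for (std::size_t i = begin; i < end; ++i) {
      partial[b] += column[i]; // pass 1: the col_reduce_2pass kernel writes per-block partials
    }
  }
  // pass 2: the col_reduce_looped kernel reduces the 32 partials to one output
  return std::accumulate(partial.begin(), partial.end(), 0.0f);
}
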
void strided_reduce_general_dispatch( void strided_reduce_general_dispatch(
@@ -963,7 +786,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
op_name = "sum"; op_name = "sum";
break; break;
case Reduce::Prod: case Reduce::Prod:
op_name = "prod"; op_name = out.dtype() == bool_ ? "and" : "prod";
break; break;
case Reduce::Min: case Reduce::Min:
op_name = out.dtype() == bool_ ? "and" : "min"; op_name = out.dtype() == bool_ ? "and" : "min";

@@ -75,24 +75,24 @@ void RoPE::eval_gpu(
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
float base = std::log2(base_); float base = std::log2(base_);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(donated ? out : in, 0); compute_encoder.set_input_array(donated ? out : in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(offset_, 2); compute_encoder->setBytes(&offset_, sizeof(int), 2);
compute_encoder.set_bytes(scale_, 3); compute_encoder->setBytes(&scale_, sizeof(float), 3);
size_t n_batch = in.size() / mat_size; size_t n_batch = in.size() / mat_size;
MTL::Size group_dims; MTL::Size group_dims;
MTL::Size grid_dims; MTL::Size grid_dims;
if (single) { if (single) {
compute_encoder.set_bytes(out_strides, 1, 4); compute_encoder->setBytes(out_strides, sizeof(size_t), 4);
uint32_t dim0 = dims_ / 2; uint32_t dim0 = dims_ / 2;
group_dims = get_block_dims(dim0, n_batch, 1); group_dims = get_block_dims(dim0, n_batch, 1);
grid_dims = MTL::Size(dim0, n_batch, 1); grid_dims = MTL::Size(dim0, n_batch, 1);
} else { } else {
compute_encoder.set_bytes(strides, 3, 4); compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 4);
compute_encoder.set_bytes(out_strides, 3, 5); compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 5);
compute_encoder.set_bytes(n_batch, 6); compute_encoder->setBytes(&n_batch, sizeof(size_t), 6);
uint32_t dim0 = dims_ / 2; uint32_t dim0 = dims_ / 2;
uint32_t dim1 = in.shape(-2); uint32_t dim1 = in.shape(-2);
uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread; uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
@@ -104,11 +104,11 @@ void RoPE::eval_gpu(
auto& freqs = inputs[1]; auto& freqs = inputs[1];
compute_encoder.set_input_array(freqs, 10); compute_encoder.set_input_array(freqs, 10);
auto freq_stride = freqs.strides()[0]; auto freq_stride = freqs.strides()[0];
compute_encoder.set_bytes(freq_stride, 11); compute_encoder->setBytes(&freq_stride, sizeof(size_t), 11);
} else { } else {
compute_encoder.set_bytes(base, 10); compute_encoder->setBytes(&base, sizeof(float), 10);
} }
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
} // namespace mlx::core::fast } // namespace mlx::core::fast
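
The encoder calls on one side of this hunk go through a typed set_bytes interface, while the other side calls Metal's raw setBytes directly. A minimal sketch of such a typed wrapper, illustrative only and not MLX's CommandEncoder, could look like:

#include <Metal/Metal.hpp>
#include <cstddef>

// Illustrative typed wrapper over MTL::ComputeCommandEncoder::setBytes.
struct EncoderSketch {
  MTL::ComputeCommandEncoder* enc;

  // Single value: the byte count is inferred from the type.
  template <typename T>
  void set_bytes(const T& value, int index) {
    enc->setBytes(&value, sizeof(T), index);
  }

  // Contiguous run of values, e.g. a small strides array.
  template <typename T>
  void set_bytes(const T* values, std::size_t n, int index) {
    enc->setBytes(values, n * sizeof(T), index);
  }
};
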

@@ -6,11 +6,8 @@
#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
#include "mlx/backend/metal/kernels/steel/attn/params.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h" #include "mlx/fast_primitives.h"
#include "mlx/utils.h"
namespace mlx::core::fast { namespace mlx::core::fast {
@@ -21,92 +18,128 @@ void sdpa_full_self_attention_metal(
const array& q, const array& q,
const array& k, const array& k,
const array& v, const array& v,
const float scale, const float alpha,
array& o) { array& out) {
using namespace mlx::steel; std::ostringstream kname_self_attention;
kname_self_attention << "steel_gemm_attention_";
int wm = 4; constexpr const int bm = 16;
int wn = 1; constexpr const int bn = 16;
const int bk = q.shape(-1); // already forced to be 64 or 128
int bd = q.shape(-1); if (bk != 64 && bk != 128) {
int bq = 32; throw std::runtime_error(
int bk = bd < 128 ? 32 : 16; "[ScaledDotProductAttention::eval_gpu]: hidden dim: expected either 64, 128");
}
int B = q.shape(0); constexpr const int wm = 2;
int H = q.shape(1); constexpr const int wn = 2;
int D = q.shape(3);
int gqa_factor = q.shape(1) / k.shape(1);
int qL = q.shape(2); std::string delimiter = "_";
int kL = k.shape(2);
const bool align_Q = (qL % bq) == 0; kname_self_attention << "bm_" + std::to_string(bm) + delimiter;
const bool align_K = (kL % bk) == 0; kname_self_attention << "bn_" + std::to_string(bn) + delimiter;
kname_self_attention << "bk_" + std::to_string(bk) + delimiter;
metal::MTLFCList func_consts = { for (const auto& arr : {k, v, out}) {
{&align_Q, MTL::DataType::DataTypeBool, 200}, if (arr.dtype() != q.dtype()) {
{&align_K, MTL::DataType::DataTypeBool, 201}, throw std::runtime_error(
}; "[ScaledDotProductAttention::eval_gpu]: expected matching dtypes for q,k,v,o");
}
}
std::ostringstream kname; if (q.dtype() == float32) {
// clang-format off kname_self_attention << "itype" + delimiter + "float";
kname << "steel_attention_" } else if (q.dtype() == float16) {
<< type_to_name(q) kname_self_attention << "itype" + delimiter + "half";
<< "_bq" << bq } else {
<< "_bk" << bk throw std::runtime_error(
<< "_bd" << bd "[ScaledDotProductAttention::eval_gpu]: unexpected dtype found for queries: expected either float32 or float16.");
<< "_wm" << wm << "_wn" << wn; // clang-format on }
std::string base_name = kname.str();
// clang-format off
kname << "_align_Q_" << (align_Q ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts); auto kernel = d.get_kernel(kname_self_attention.str());
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
const int NQ = (qL + bq - 1) / bq; uint hidden_dim = q.shape(-1);
const int NK = (kL + bk - 1) / bk; uint qseq = q.shape(-2);
uint qheads = q.shape(-3);
const int NQ_aligned = qL / bq; const uint64_t KV_sequence_length = k.shape(-2);
const int NK_aligned = kL / bk; const uint query_sequence_length = q.shape(-2);
const uint n_q_heads = q.shape(1);
const uint n_kv_heads = k.shape(1);
AttnParams params{ const int M = q.shape(-2);
/* int B = */ B, const int N = M;
/* int H = */ H, const int K = q.shape(-1);
/* int D = */ D, const size_t batch_size_out = q.shape(0) * q.shape(1);
/* int qL = */ qL, const std::vector<int> batch_shape = {q.shape(0) * q.shape(1)};
/* int kL = */ kL, const int dk = q.shape(-1);
const int ldq = dk;
const int ldk = dk;
const int ldv = dk;
const int lds = bn;
const int ldo = dk;
/* int gqa_factor = */ gqa_factor, int tn = 1;
/* float scale = */ scale, int tm = (M + bm - 1) / bm;
/* int NQ = */ NQ, const int batch_stride_q = dk * query_sequence_length;
/* int NK = */ NK, const int batch_stride_k = dk * query_sequence_length;
const int batch_stride_v = dk * query_sequence_length;
const int batch_stride_o = dk * query_sequence_length;
const int swizzle_log = 0;
const int gemm_n_iterations_aligned = (N + bn - 1) / bn;
const int gemm_k_iterations_aligned = (K + bk - 1) / bk;
const int gemm_sv_m_block_iterations = (M + bm - 1) / bm;
const int batch_ndim = int(batch_shape.size());
/* int NQ_aligned = */ NQ_aligned, MLXFastAttentionParams params{
/* int NK_aligned = */ NK_aligned, (int)M,
(int)N,
(int)K,
ldq,
ldk,
ldv,
lds,
ldo,
tn,
tm,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_o,
swizzle_log,
gemm_n_iterations_aligned,
gemm_k_iterations_aligned,
gemm_sv_m_block_iterations,
batch_ndim,
alpha};
/* size_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)}, const std::vector<size_t> batch_strides = {
/* size_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)}, (size_t)batch_stride_q,
/* size_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)}, (size_t)batch_stride_k,
/* size_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}}; (size_t)batch_stride_v,
(size_t)batch_stride_o};
compute_encoder.set_input_array(q, 0); compute_encoder.set_input_array(q, 0);
compute_encoder.set_input_array(k, 1); compute_encoder.set_input_array(k, 1);
compute_encoder.set_input_array(v, 2); compute_encoder.set_input_array(v, 2);
compute_encoder.set_output_array(o, 3); compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(params, 4);
MTL::Size grid_dims = MTL::Size(NQ, H, B); compute_encoder->setBytes(&params, sizeof(MLXFastAttentionParams), 4);
compute_encoder->setBytes(
batch_shape.data(), sizeof(int) * batch_shape.size(), 6);
compute_encoder->setBytes(
batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
MTL::Size grid_dims = MTL::Size(1, tm, batch_size_out);
MTL::Size group_dims = MTL::Size(32, wm, wn); MTL::Size group_dims = MTL::Size(32, wm, wn);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
} }
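
For orientation, the tile bookkeeping in the steel attention path above boils down to a few integer computations. A standalone sketch with made-up shapes follows; the constants mirror the hunk, nothing else here is MLX code.

#include <cstdio>

int main() {
  const int qL = 1000, kL = 4096, bd = 128; // example sequence lengths and head dim
  const int bq = 32;                        // query tile
  const int bk = bd < 128 ? 32 : 16;        // key tile shrinks for wide heads
  const int NQ = (qL + bq - 1) / bq;        // query tiles per head
  const int NK = (kL + bk - 1) / bk;        // key tiles per head
  const bool align_Q = (qL % bq) == 0;      // function constants chosen when the kernel is built
  const bool align_K = (kL % bk) == 0;
  std::printf("NQ=%d NK=%d align_Q=%d align_K=%d\n", NQ, NK, (int)align_Q, (int)align_K);
  return 0;
}
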
void sdpa_vector( void sdpa_vector(
@@ -137,109 +170,21 @@ void sdpa_vector(
// Get the kernel // Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname); auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
// Set its arguments // Set its arguments
compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0); compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0);
compute_encoder.set_input_array(k, 1); compute_encoder.set_input_array(k, 1);
compute_encoder.set_input_array(v, 2); compute_encoder.set_input_array(v, 2);
compute_encoder.set_output_array(out, 3); compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(gqa_factor, 4); compute_encoder->setBytes(&gqa_factor, sizeof(int), 4);
compute_encoder.set_bytes(N, 5); compute_encoder->setBytes(&N, sizeof(int), 5);
compute_encoder.set_bytes(k_stride, 6); compute_encoder->setBytes(&k_stride, sizeof(size_t), 6);
compute_encoder.set_bytes(v_stride, 7); compute_encoder->setBytes(&v_stride, sizeof(size_t), 7);
compute_encoder.set_bytes(scale, 8); compute_encoder->setBytes(&scale, sizeof(float), 8);
// Launch // Launch
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
void sdpa_vector_2pass(
const Stream& s,
metal::Device& d,
const array& q,
const array& k,
const array& v,
array& out,
float scale) {
// Set the kernel name
std::string kname;
kname.reserve(64);
kname += "sdpa_vector_2pass_1_";
kname += get_type_string(q.dtype());
kname += "_";
kname += std::to_string(q.shape(-1));
// Compute the necessary sizes
int gqa_factor = q.shape(1) / k.shape(1);
int N = k.shape(2);
int blocks = 32;
int B = q.shape(0) * q.shape(1);
size_t k_stride = k.strides()[1];
size_t v_stride = v.strides()[1];
MTL::Size group_dims(8 * 32, 1, 1);
MTL::Size grid_dims(1, B, blocks);
// Allocate the intermediates
std::vector<int> intermediate_shape;
intermediate_shape.reserve(out.ndim() + 1);
intermediate_shape.insert(
intermediate_shape.end(), out.shape().begin(), out.shape().end() - 1);
intermediate_shape.push_back(blocks);
intermediate_shape.push_back(out.shape().back());
array intermediate(intermediate_shape, float32, nullptr, {});
intermediate_shape.pop_back();
array sums(intermediate_shape, float32, nullptr, {});
array maxs(std::move(intermediate_shape), float32, nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
sums.set_data(allocator::malloc_or_wait(sums.nbytes()));
maxs.set_data(allocator::malloc_or_wait(maxs.nbytes()));
d.add_temporary(intermediate, s.index);
d.add_temporary(sums, s.index);
d.add_temporary(maxs, s.index);
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);
// Set its arguments
compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0);
compute_encoder.set_input_array(k, 1);
compute_encoder.set_input_array(v, 2);
compute_encoder.set_output_array(intermediate, 3);
compute_encoder.set_output_array(sums, 4);
compute_encoder.set_output_array(maxs, 5);
compute_encoder.set_bytes(gqa_factor, 6);
compute_encoder.set_bytes(N, 7);
compute_encoder.set_bytes(k_stride, 8);
compute_encoder.set_bytes(v_stride, 9);
compute_encoder.set_bytes(scale, 10);
// Launch
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
// Final pass
kname.clear();
kname += "sdpa_vector_2pass_2_";
kname += get_type_string(q.dtype());
kname += "_";
kname += std::to_string(q.shape(-1));
// Get the kernel
kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);
// Set its arguments
compute_encoder.set_input_array(intermediate, 0);
compute_encoder.set_input_array(sums, 1);
compute_encoder.set_input_array(maxs, 2);
compute_encoder.set_output_array(out, 3);
// Launch
group_dims = MTL::Size(1024, 1, 1);
grid_dims = MTL::Size(1, B, 1);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
} }
} // namespace } // namespace
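
Under the assumption that each block of the 2-pass vector kernel writes an unnormalized partial output together with its softmax denominator and running max (the intermediate, sums, and maxs buffers above), the second pass can combine them as in this CPU sketch; the actual kernel may differ in its details.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative merge of per-block partial attention results.
std::vector<float> combine_blocks(
    const std::vector<std::vector<float>>& o, // blocks x head_dim, unnormalized outputs
    const std::vector<float>& s,              // per-block softmax denominators
    const std::vector<float>& m) {            // per-block running maxima
  const std::size_t dim = o.front().size();
  float global_max = m.front();
  for (float v : m) {
    global_max = std::max(global_max, v);
  }
  float denom = 0.0f;
  std::vector<float> out(dim, 0.0f);
  for (std::size_t b = 0; b < o.size(); ++b) {
    const float scale = std::exp(m[b] - global_max); // rescale block b to the global max
    denom += s[b] * scale;
    for (std::size_t d = 0; d < dim; ++d) {
      out[d] += o[b][d] * scale;
    }
  }
  for (float& v : out) {
    v /= denom; // final softmax normalization
  }
  return out;
}
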
@@ -261,14 +206,12 @@ void ScaledDotProductAttention::eval_gpu(
// Define some copy functions to ensure the layout of the inputs is as // Define some copy functions to ensure the layout of the inputs is as
// expected. // expected.
copies.reserve(3); auto copy_unless = [&copies, &s](auto predicate, const array& arr) {
auto copy_unless = [&copies, &s](
auto predicate, const array& arr) -> const array& {
if (!predicate(arr)) { if (!predicate(arr)) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s); copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy); copies.push_back(arr_copy);
return copies.back(); return arr_copy;
} else { } else {
return arr; return arr;
} }
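
One detail worth noting in this hunk: when copy_unless returns a const array& into the copies vector, the copies.reserve(3) above is what keeps that reference valid, since push_back can no longer reallocate. A generic illustration, not MLX code:

#include <cassert>
#include <vector>

int main() {
  std::vector<int> copies;
  copies.reserve(3);                // capacity fixed up front; at most three copies are made
  copies.push_back(1);
  const int& first = copies.back(); // reference into the vector, like copies.back() above
  copies.push_back(2);              // no reallocation, so `first` is not invalidated
  copies.push_back(3);
  assert(first == 1);
  return 0;
}
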
@@ -297,9 +240,9 @@ void ScaledDotProductAttention::eval_gpu(
// We are in vector mode ie single query // We are in vector mode ie single query
if (q_pre.shape(2) == 1) { if (q_pre.shape(2) == 1) {
const auto& q = copy_unless(is_contiguous, q_pre); auto q = copy_unless(is_contiguous, q_pre);
const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre); auto k = copy_unless(is_contiguous_except_seq_len, k_pre);
const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre); auto v = copy_unless(is_contiguous_except_seq_len, v_pre);
// Donate the query if possible // Donate the query if possible
if (q.is_donatable()) { if (q.is_donatable()) {
@@ -308,41 +251,15 @@ void ScaledDotProductAttention::eval_gpu(
o.set_data(allocator::malloc_or_wait(o.nbytes())); o.set_data(allocator::malloc_or_wait(o.nbytes()));
} }
// We route to the 2 pass fused attention if sdpa_vector(s, d, q, k, v, o, scale_);
// - The device is large and the sequence length long
// - The sequence length is even longer and we have gqa
char devc = d.get_architecture().back();
if ((devc == 'd' && k.shape(2) >= 1024) ||
(k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
sdpa_vector_2pass(s, d, q, k, v, o, scale_);
} else {
sdpa_vector(s, d, q, k, v, o, scale_);
}
} }
// Full attention mode // Full attention mode
else { else {
const auto& q = copy_unless(is_matrix_contiguous, q_pre); auto q = copy_unless(is_matrix_contiguous, q_pre);
const auto& k = copy_unless(is_matrix_contiguous, k_pre); auto k = copy_unless(is_matrix_contiguous, k_pre);
const auto& v = copy_unless(is_matrix_contiguous, v_pre); auto v = copy_unless(is_matrix_contiguous, v_pre);
o.set_data(allocator::malloc_or_wait(o.nbytes()));
size_t str_oD = 1;
size_t str_oH = o.shape(3);
size_t str_oL = o.shape(1) * str_oH;
size_t str_oB = o.shape(2) * str_oL;
size_t data_size = o.shape(0) * str_oB;
array::Flags flags{
/* bool contiguous = */ 1,
/* bool row_contiguous = */ 0,
/* bool col_contiguous = */ 0,
};
o.set_data(
allocator::malloc_or_wait(o.nbytes()),
data_size,
{str_oB, str_oH, str_oL, str_oD},
flags);
sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o); sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
} }

@@ -68,12 +68,12 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
if (contiguous) { if (contiguous) {
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array( compute_encoder.set_input_array(
in.data_shared_ptr() == nullptr ? out : in, 0); in.data_shared_ptr() == nullptr ? out : in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
size_t size = in.shape(axis_); size_t size = in.shape(axis_);
compute_encoder.set_bytes(size, 2); compute_encoder->setBytes(&size, sizeof(size_t), 2);
// Compute the thread grid // Compute the thread grid
int n_reads = (in.itemsize() <= 4) ? 4 : 2; int n_reads = (in.itemsize() <= 4) ? 4 : 2;
@@ -95,10 +95,10 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size grid_dims( MTL::Size grid_dims(
thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height); thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height);
MTL::Size group_dims(thread_group_size, 1, 1); MTL::Size group_dims(thread_group_size, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} else { } else {
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array( compute_encoder.set_input_array(
in.data_shared_ptr() == nullptr ? out : in, 0); in.data_shared_ptr() == nullptr ? out : in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
@@ -107,9 +107,9 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
int bm = 32; int bm = 32;
int bn = 32; int bn = 32;
size_t stride_blocks = (stride + bn - 1) / bn; size_t stride_blocks = (stride + bn - 1) / bn;
compute_encoder.set_bytes(size, 2); compute_encoder->setBytes(&size, sizeof(size_t), 2);
compute_encoder.set_bytes(stride, 3); compute_encoder->setBytes(&stride, sizeof(size_t), 3);
compute_encoder.set_bytes(stride_blocks, 4); compute_encoder->setBytes(&stride_blocks, sizeof(size_t), 4);
// Compute the thread grid // Compute the thread grid
int n_reads = (in.itemsize() <= 4) ? 4 : 2; int n_reads = (in.itemsize() <= 4) ? 4 : 2;
@@ -125,7 +125,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size grid_dims( MTL::Size grid_dims(
thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height); thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height);
MTL::Size group_dims(thread_group_size, 1, 1); MTL::Size group_dims(thread_group_size, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);

@@ -81,12 +81,12 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
group_dims = MTL::Size(threadgroup_size, 1, 1); group_dims = MTL::Size(threadgroup_size, 1, 1);
} }
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array( compute_encoder.set_input_array(
in.data_shared_ptr() == nullptr ? out : in, 0); in.data_shared_ptr() == nullptr ? out : in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(axis_size, 2); compute_encoder->setBytes(&axis_size, sizeof(int), 2);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);

@@ -68,29 +68,29 @@ void single_block_sort(
// Prepare command encoder // Prepare command encoder
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
// Set inputs // Set inputs
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(size_sorted_axis, 2); compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2);
compute_encoder.set_bytes(in_stride_sorted_axis, 3); compute_encoder->setBytes(&in_stride_sorted_axis, sizeof(int), 3);
compute_encoder.set_bytes(out_stride_sorted_axis, 4); compute_encoder->setBytes(&out_stride_sorted_axis, sizeof(int), 4);
if (contiguous) { if (contiguous) {
compute_encoder.set_bytes(in_stride_segment_axis, 5); compute_encoder->setBytes(&in_stride_segment_axis, sizeof(int), 5);
compute_encoder.set_bytes(out_stride_segment_axis, 6); compute_encoder->setBytes(&out_stride_segment_axis, sizeof(int), 6);
} else { } else {
compute_encoder.set_bytes(nc_dim, 5); compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
compute_encoder.set_vector_bytes(nc_shape, 6); compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 6);
compute_encoder.set_vector_bytes(in_nc_str, 7); compute_encoder->setBytes(in_nc_str.data(), nc_dim * sizeof(size_t), 7);
compute_encoder.set_vector_bytes(out_nc_str, 8); compute_encoder->setBytes(out_nc_str.data(), nc_dim * sizeof(size_t), 8);
} }
MTL::Size group_dims = MTL::Size(bn, 1, 1); MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(1, n_rows, 1); MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
} }
void multi_block_sort( void multi_block_sort(
@@ -152,21 +152,22 @@ void multi_block_sort(
<< type_to_name(dev_idxs_0) << "_bn" << bn << "_tn" << tn; << type_to_name(dev_idxs_0) << "_bn" << bn << "_tn" << tn;
auto kernel = auto kernel =
get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn); get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(dev_vals_0, 1); compute_encoder.set_output_array(dev_vals_0, 1);
compute_encoder.set_output_array(dev_idxs_0, 2); compute_encoder.set_output_array(dev_idxs_0, 2);
compute_encoder.set_bytes(size_sorted_axis, 3); compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
compute_encoder.set_bytes(stride_sorted_axis, 4); compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 4);
compute_encoder.set_bytes(nc_dim, 5); compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
compute_encoder.set_vector_bytes(nc_shape, 6); compute_encoder->setBytes(
compute_encoder.set_vector_bytes(nc_str, 7); nc_shape.data(), nc_shape.size() * sizeof(int), 6);
compute_encoder->setBytes(nc_str.data(), nc_str.size() * sizeof(size_t), 7);
MTL::Size group_dims = MTL::Size(bn, 1, 1); MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1); MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
} }
// Do merges // Do merges
@@ -193,19 +194,19 @@ void multi_block_sort(
auto kernel = auto kernel =
get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn); get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_output_array(block_partitions, 0); compute_encoder.set_output_array(block_partitions, 0);
compute_encoder.set_input_array(dev_vals_in, 1); compute_encoder.set_input_array(dev_vals_in, 1);
compute_encoder.set_input_array(dev_idxs_in, 2); compute_encoder.set_input_array(dev_idxs_in, 2);
compute_encoder.set_bytes(size_sorted_axis, 3); compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
compute_encoder.set_bytes(merge_tiles, 4); compute_encoder->setBytes(&merge_tiles, sizeof(int), 4);
compute_encoder.set_bytes(n_blocks, 5); compute_encoder->setBytes(&n_blocks, sizeof(int), 5);
MTL::Size group_dims = MTL::Size(n_thr_per_group, 1, 1); MTL::Size group_dims = MTL::Size(n_thr_per_group, 1, 1);
MTL::Size grid_dims = MTL::Size(1, n_rows, 1); MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
} }
// Do merge // Do merge
@@ -216,21 +217,21 @@ void multi_block_sort(
auto kernel = auto kernel =
get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn); get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(block_partitions, 0); compute_encoder.set_input_array(block_partitions, 0);
compute_encoder.set_input_array(dev_vals_in, 1); compute_encoder.set_input_array(dev_vals_in, 1);
compute_encoder.set_input_array(dev_idxs_in, 2); compute_encoder.set_input_array(dev_idxs_in, 2);
compute_encoder.set_output_array(dev_vals_out, 3); compute_encoder.set_output_array(dev_vals_out, 3);
compute_encoder.set_output_array(dev_idxs_out, 4); compute_encoder.set_output_array(dev_idxs_out, 4);
compute_encoder.set_bytes(size_sorted_axis, 5); compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 5);
compute_encoder.set_bytes(merge_tiles, 6); compute_encoder->setBytes(&merge_tiles, sizeof(int), 6);
compute_encoder.set_bytes(n_blocks, 7); compute_encoder->setBytes(&n_blocks, sizeof(int), 7);
MTL::Size group_dims = MTL::Size(bn, 1, 1); MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1); MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
} }
} }

Some files were not shown because too many files have changed in this diff.