Compare commits

...

5 Commits

Author SHA1 Message Date
Awni Hannun
bb303c45a5 version (#1617) 2024-11-22 12:00:03 -08:00
Alex Barron
6f7986d592 Cleaner qmv/qvm (#1616) 2024-11-22 11:14:08 -08:00
Awni Hannun
7cbb4aef17 Doc fix (#1615) 2024-11-22 11:12:25 -08:00
Jagrit Digani
02bec0bb6d Matrix Attention kernel (#1610)
* Rough INIT

* [WIP]: Loading and Matmuls added

* [WIP]: Reductions and min working aligned kernel at headdim = 64

* [WIP] Added headdim 80 for testing

* [WIP] Update dispatch params for testing

* [WIP] Add support for unaligned seq lengths - still looks messy

* Update sdpa_benchmarks

* Update sdpa_benchmarks

* Update sdpa_benchmarks

* Enable gqa support

* Update benchmark and switch off 128 headdim

* Update headdim 128 tuning

* Remove older fast attention code. Write out O strided

* Disable hd=128 until further optimizations

* Enable bf16

* Fix data size bug

* Enable attn build outside of jit
2024-11-22 10:34:05 -08:00
Alex Barron
c79f6a4a8c 3 and 6 bit quantization (#1613)
* Support 3 and 6 bit quantization
2024-11-22 10:22:13 -08:00
29 changed files with 2700 additions and 1631 deletions

View File

@@ -24,7 +24,7 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.20.0)
set(MLX_VERSION 0.21.0)
endif()
# --------------------- Processor tests -------------------------

View File

@@ -1,62 +1,189 @@
# Copyright © 2024 Apple Inc.
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
from time_utils import time_fn
import numpy as np
MAX_SEQ = 300
START_SEQ = 100
SEQ_INCREMENT = 50
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
N_warmup = 5
N_iter_bench = 40
N_iter_func = 8
def time_self_attention_primitives():
mx.random.seed(3)
B = 2
H = 38
D = 64
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
q = mx.random.uniform(shape=(B, H, R, D))
k = mx.random.uniform(shape=(B, H, R, D))
v = mx.random.uniform(shape=(B, H, R, D))
scale = 1.0 / math.sqrt(float(D))
mx.eval(q, k, v)
def bench(f, *args):
for i in range(N_warmup):
f(*args)
def sdpa_primitives(qs, ks, vs, alpha):
s = (alpha * qs) @ ks.transpose(0, 1, 3, 2)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ vs
return o
time_fn(sdpa_primitives, q, k, v, scale)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(*args)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def time_self_attention_sdpa():
mx.random.seed(3)
B = 2
H = 38
D = 64
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
q = mx.random.uniform(shape=(B, H, R, D))
k = mx.random.uniform(shape=(B, H, R, D))
v = mx.random.uniform(shape=(B, H, R, D))
scale = 1.0 / math.sqrt(float(D))
mx.eval(q, k, v)
def mlx_sdpa_fused_inner(q, k, v, scale):
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
def sdpa_fused(qs, ks, vs, alpha):
o = mx.fast.scaled_dot_product_attention(qs, ks, vs, scale=alpha)
return o
time_fn(sdpa_fused, q, k, v, scale)
def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
q_dtype = q.dtype
q = q * mx.array(scale, q_dtype)
n_q_heads = q.shape[-3]
n_kv_heads = k.shape[-3]
n_repeats = n_q_heads // n_kv_heads
B = q.shape[0]
L = q.shape[2]
if n_repeats > 1:
q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
k = mx.expand_dims(k, 2)
v = mx.expand_dims(v, 2)
scores = q @ mx.swapaxes(k, -1, -2)
if f32softmax:
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype)
else:
scores = mx.softmax(scores, axis=-1)
out = scores @ v
if n_repeats > 1:
out = mx.reshape(out, [B, n_q_heads, L, -1])
return out
def mlx_spda_unfused(q, k, v, scale, transpose):
q_out = q
if transpose:
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
for i in range(N_iter_func):
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
mx.eval(q_out)
return q_out
def mlx_spda_fused(q, k, v, scale, transpose):
q_out = q
if transpose:
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
for i in range(N_iter_func):
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
mx.eval(q_out)
return q_out
def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
shape_q = (
(B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
)
shape_kv = (
(B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
)
q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype)
k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
scale = math.sqrt(1.0 / head_dim)
q_mx = mx.array(q_np)
k_mx = mx.array(k_np)
v_mx = mx.array(v_np)
time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose)
time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose)
if transpose:
q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
v_mx = mx.transpose(v_mx, (0, 2, 1, 3))
o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)
atol = 1e-5 if np_dtype == np.float32 else 1e-4
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
print(
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
)
return time_mlx_fused, time_mlx_unfused
def get_gflop_count(B, M, N, K):
return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)
if __name__ == "__main__":
parser = argparse.ArgumentParser("MLX benchmarks.")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
args = parser.parse_args()
if args.gpu:
mx.set_default_device(mx.gpu)
else:
mx.set_default_device(mx.cpu)
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
time_self_attention_sdpa()
time_self_attention_primitives()
dtypes = ("float16", "float32")[:1]
transposes = (False,)
# fmt: off
shapes_64 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 32, 32, 64, 32, 32),
( 1, 64, 64, 64, 32, 32),
( 1, 128, 128, 64, 32, 32),
( 1, 256, 256, 64, 32, 32),
( 1, 512, 512, 64, 32, 32),
( 1, 1024, 1024, 64, 32, 32),
( 1, 2048, 2048, 64, 32, 32),
( 1, 4096, 4096, 64, 32, 32),
)
shapes_80 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 80, 32, 32),
( 1, 2048, 2048, 80, 32, 32),
( 1, 4096, 4096, 80, 32, 32),
)
shapes_128 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 128, 32, 32),
( 1, 2048, 2048, 128, 32, 32),
( 1, 4096, 4096, 128, 32, 32),
)
# fmt: on
shapes = shapes_64 + shapes_80 + shapes_128
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")
for dtype in dtypes:
for transpose in transposes:
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
np_dtype = getattr(np, dtype)
time_mlx_fused, time_mlx_unfused = bench_shape(
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose
)
diff = time_mlx_unfused / time_mlx_fused - 1.0
t_str = 1 if transpose else 0
print(
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
)

View File

@@ -12,5 +12,4 @@ Fast
layer_norm
rope
scaled_dot_product_attention
affine_quantize
metal_kernel

View File

@@ -12,6 +12,7 @@ Layers
ALiBi
AvgPool1d
AvgPool2d
AvgPool3d
BatchNorm
CELU
Conv1d
@@ -41,6 +42,7 @@ Layers
LSTM
MaxPool1d
MaxPool2d
MaxPool3d
Mish
MultiHeadAttention
PReLU

View File

@@ -1,4 +1,4 @@
setuptools>=42
cmake>=3.24
mlx>=0.18.1
mlx>=0.21.0
nanobind==2.2.0

View File

@@ -6,11 +6,34 @@
#include "mlx/backend/common/ops.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
template <typename T, int bits>
void extract_bits(const uint8_t* w_in, T* w_out) {
assert(bits == 3 || bits == 6);
if (bits == 3) {
w_out[0] = static_cast<T>(w_in[0] & 0x7);
w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
w_out[2] = static_cast<T>(((w_in[0] & 0xc0) >> 6) + ((w_in[1] & 0x1) << 2));
w_out[3] = static_cast<T>((w_in[1] & 0xe) >> 1);
w_out[4] = static_cast<T>((w_in[1] & 0x70) >> 4);
w_out[5] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1));
w_out[6] = static_cast<T>((w_in[2] & 0x1c) >> 2);
w_out[7] = static_cast<T>((w_in[2] & 0xe0) >> 5);
} else if (bits == 6) {
w_out[0] = static_cast<T>(w_in[0] & 0x3f);
w_out[1] =
static_cast<T>(((w_in[0] >> 6) & 0x03) + ((w_in[1] & 0x0f) << 2));
w_out[2] =
static_cast<T>(((w_in[1] >> 4) & 0x0f) + ((w_in[2] & 0x03) << 4));
w_out[3] = static_cast<T>((w_in[2] >> 2) & 0x3f);
}
}
template <typename T, int bits, int group_size>
void _qmm(
T* result,
@@ -22,13 +45,12 @@ void _qmm(
int N,
int K) {
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = 32 / bits;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
constexpr int packs_in_group = group_size / pack_factor;
const int Ng = N / group_size;
const int Nw = N / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const uint8_t* w_local = (const uint8_t*)w;
const T* scales_local = scales;
const T* biases_local = biases;
@@ -42,13 +64,25 @@ void _qmm(
T scale = *scales_local++;
T bias = *biases_local++;
for (int ng = 0; ng < packs_in_group; ng++) {
uint32_t wi = *w_local++;
if (bits == 3 || bits == 6) {
T wl[pack_factor];
extract_bits<T, bits>(w_local, wl);
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
(*result_local++) +=
xi * (scale * static_cast<T>(wi & bitmask) + bias);
wi >>= bits;
for (int p = 0; p < pack_factor; p++) {
(*result_local++) += xi * (scale * wl[p] + bias);
}
w_local += bytes_per_pack;
} else {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
(*result_local++) +=
xi * (scale * static_cast<T>(wi & bitmask) + bias);
if (bits != 8) {
wi >>= bits;
}
}
}
}
}
@@ -69,13 +103,12 @@ void _qmm_t(
int N,
int K) {
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = 32 / bits;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
constexpr int packs_in_group = group_size / pack_factor;
const int Kg = K / group_size;
const int Kw = K / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const uint8_t* w_local = (const uint8_t*)w;
const T* scales_local = scales;
const T* biases_local = biases;
@@ -87,12 +120,26 @@ void _qmm_t(
T bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw++) {
uint32_t wi = *w_local++;
if (bits == 3 || bits == 6) {
T wl[pack_factor];
extract_bits<T, bits>(w_local, wl);
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
sum += (*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
wi >>= bits;
for (int p = 0; p < pack_factor; p++) {
sum += x_local[p] * (scale * wl[p] + bias);
}
w_local += bytes_per_pack;
x_local += pack_factor;
} else {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
sum +=
(*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
if (bits != 8) {
wi >>= bits;
}
}
}
}
}
@@ -104,6 +151,55 @@ void _qmm_t(
}
}
template <typename T, int bits, int group_size>
void _qmm_dispatch_transpose(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K,
bool transposed_w) {
if (transposed_w) {
return _qmm_t<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
}
}
template <typename T, int bits>
void _qmm_dispatch_group(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K,
int group_size,
bool transposed_w) {
switch (group_size) {
case 32:
_qmm_dispatch_transpose<T, bits, 32>(
result, x, w, scales, biases, M, N, K, transposed_w);
break;
case 64:
_qmm_dispatch_transpose<T, bits, 64>(
result, x, w, scales, biases, M, N, K, transposed_w);
break;
case 128:
_qmm_dispatch_transpose<T, bits, 128>(
result, x, w, scales, biases, M, N, K, transposed_w);
break;
default:
throw std::invalid_argument(
"Quantization group size must be 32, 64 or 128.");
}
}
template <typename T>
void _qmm_dispatch_typed(
T* result,
@@ -118,79 +214,29 @@ void _qmm_dispatch_typed(
int bits,
bool transposed_w) {
switch (bits) {
case 2: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 2, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 2, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 2, 64>(result, x, w, scales, biases, M, N, K);
}
case 128:
if (transposed_w) {
return _qmm_t<T, 2, 128>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 2, 128>(result, x, w, scales, biases, M, N, K);
}
}
}
case 4: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 4, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 64>(result, x, w, scales, biases, M, N, K);
}
case 128:
if (transposed_w) {
return _qmm_t<T, 4, 128>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 128>(result, x, w, scales, biases, M, N, K);
}
}
}
case 8: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 8, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 64>(result, x, w, scales, biases, M, N, K);
}
case 128:
if (transposed_w) {
return _qmm_t<T, 8, 128>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 128>(result, x, w, scales, biases, M, N, K);
}
}
}
case 2:
_qmm_dispatch_group<T, 2>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
case 3:
_qmm_dispatch_group<T, 3>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
case 4:
_qmm_dispatch_group<T, 4>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
case 6:
_qmm_dispatch_group<T, 6>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
case 8:
_qmm_dispatch_group<T, 8>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
default:
throw std::invalid_argument("Quantization bits must be 2, 3, 4, 6 or 8.");
}
std::ostringstream msg;
msg << "Quantization type not supported. Provided bits=" << bits
<< " and group_size=" << group_size
<< ". The supported options are bits in "
<< "{2, 4, 8} and group_size in {64, 128}.";
throw std::invalid_argument(msg.str());
}
void _qmm_dispatch(
@@ -406,51 +452,52 @@ void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
transpose_);
}
template <typename T>
template <typename T, typename U>
void quantize(
const array& w_,
array& out_,
array& scales_,
array& biases_,
int bits,
int group_size,
bool compute_scale_bias) {
int group_size) {
const T* w = w_.data<T>();
auto out = out_.data<U>();
T* scales = scales_.data<T>();
T* biases = biases_.data<T>();
auto out = out_.data<uint32_t>();
T n_bins = (1 << bits) - 1;
T eps = 1e-7;
int el_per_int = 32 / bits;
int int_per_group = group_size / el_per_int;
bool power_of_2_bits = is_power_of_2(bits);
int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
// For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
int bytes_per_pack = power_of_2_bits ? 1 : 3;
int int_per_group = group_size * bytes_per_pack / el_per_int;
size_t n_groups = w_.size() / group_size;
for (size_t i = 0; i < n_groups; ++i) {
size_t w_idx = i * group_size;
if (compute_scale_bias) {
T w_min = std::numeric_limits<float>::infinity();
T w_max = -w_min;
for (int j = 0; j < group_size; ++j) {
w_max = std::max(w_max, w[w_idx + j]);
w_min = std::min(w_min, w[w_idx + j]);
}
bool mask = std::abs(w_min) > std::abs(w_max);
T scale = std::max(T((w_max - w_min) / n_bins), eps);
scale = mask ? scale : -scale;
T w_min = std::numeric_limits<float>::infinity();
T w_max = -w_min;
for (int j = 0; j < group_size; ++j) {
w_max = std::max(w_max, w[w_idx + j]);
w_min = std::min(w_min, w[w_idx + j]);
}
bool mask = std::abs(w_min) > std::abs(w_max);
T scale = std::max(T((w_max - w_min) / n_bins), eps);
scale = mask ? scale : -scale;
auto edge = mask ? w_min : w_max;
auto q0 = std::rint(edge / scale);
if (q0 == 0) {
scales[i] = scale;
biases[i] = 0;
} else {
scales[i] = edge / q0;
biases[i] = edge;
}
auto edge = mask ? w_min : w_max;
auto q0 = std::rint(edge / scale);
if (q0 == 0) {
scales[i] = scale;
biases[i] = 0;
} else {
scales[i] = edge / q0;
biases[i] = edge;
}
size_t out_idx = i * int_per_group;
for (int j = 0; j < int_per_group; ++j) {
for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
uint32_t out_el = 0;
for (int k = 0; k < el_per_int; ++k) {
T w_el = w[w_idx + j * el_per_int + k];
@@ -458,7 +505,13 @@ void quantize(
w_el = std::min(std::max(w_el, T(0)), n_bins);
out_el |= static_cast<uint32_t>(w_el) << (k * bits);
}
out[out_idx + j] = out_el;
if (power_of_2_bits) {
out[out_idx + j] = out_el;
} else {
out[out_idx + bytes_per_pack * j] = out_el & 0xff;
out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
}
}
}
}
@@ -466,8 +519,6 @@ void quantize(
void fast::AffineQuantize::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
bool compute_scale_bias = inputs.size() == 1;
auto ensure_row_contiguous = [](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
@@ -482,23 +533,29 @@ void fast::AffineQuantize::eval_cpu(
auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& scales =
compute_scale_bias ? outputs[1] : const_cast<array&>(inputs[1]);
auto& biases =
compute_scale_bias ? outputs[2] : const_cast<array&>(inputs[2]);
if (compute_scale_bias) {
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
}
auto& scales = outputs[1];
auto& biases = outputs[2];
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
if (w.dtype() == float16) {
quantize<float16_t>(
w, out, scales, biases, bits_, group_size_, compute_scale_bias);
if (is_power_of_2(bits_)) {
quantize<float16_t, uint32_t>(w, out, scales, biases, bits_, group_size_);
} else {
quantize<float16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == bfloat16) {
quantize<bfloat16_t>(
w, out, scales, biases, bits_, group_size_, compute_scale_bias);
if (is_power_of_2(bits_)) {
quantize<bfloat16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_);
} else {
quantize<bfloat16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == float32) {
quantize<float>(
w, out, scales, biases, bits_, group_size_, compute_scale_bias);
if (is_power_of_2(bits_)) {
quantize<float, uint32_t>(w, out, scales, biases, bits_, group_size_);
} else {
quantize<float, uint8_t>(w, out, scales, biases, bits_, group_size_);
}
} else {
throw std::runtime_error(
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");

View File

@@ -44,9 +44,7 @@ build_kernel(layer_norm)
build_kernel(random)
build_kernel(rms_norm)
build_kernel(rope)
build_kernel(
scaled_dot_product_attention scaled_dot_product_attention_params.h
sdpa_vector.h steel/defines.h steel/gemm/transforms.h steel/utils.h)
build_kernel(scaled_dot_product_attention sdpa_vector.h)
set(STEEL_HEADERS
steel/defines.h
@@ -68,6 +66,24 @@ set(STEEL_HEADERS
steel/utils/type_traits.h
steel/utils/integral_constant.h)
set(STEEL_ATTN_HEADERS
steel/defines.h
steel/utils.h
steel/gemm/gemm.h
steel/gemm/mma.h
steel/gemm/loader.h
steel/gemm/transforms.h
steel/utils/type_traits.h
steel/utils/integral_constant.h
steel/attn/attn.h
steel/attn/loader.h
steel/attn/mma.h
steel/attn/params.h
steel/attn/transforms.h
steel/attn/kernels/steel_attention.h)
build_kernel(steel/attn/kernels/steel_attention ${STEEL_ATTN_HEADERS})
if(NOT MLX_METAL_JIT)
build_kernel(arange arange.h)
build_kernel(binary binary.h binary_ops.h)

View File

@@ -13,8 +13,8 @@ MLX_MTL_CONST int QUAD_SIZE = 4;
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector(const device T* x, thread U* x_thread) {
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
U sum = 0;
@@ -28,6 +28,21 @@ inline U load_vector(const device T* x, thread U* x_thread) {
}
}
else if (bits == 3) {
for (int i = 0; i < values_per_thread; i += 8) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
x[i + 6] + x[i + 7];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 8.0f;
x_thread[i + 2] = x[i + 2] / 64.0f;
x_thread[i + 3] = x[i + 3] / 2.0f;
x_thread[i + 4] = x[i + 4] / 16.0f;
x_thread[i + 5] = x[i + 5] / 128.0f;
x_thread[i + 6] = x[i + 6] / 4.0f;
x_thread[i + 7] = x[i + 7] / 32.0f;
}
}
else if (bits == 4) {
for (int i = 0; i < values_per_thread; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
@@ -38,6 +53,16 @@ inline U load_vector(const device T* x, thread U* x_thread) {
}
}
else if (bits == 6) {
for (int i = 0; i < values_per_thread; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 64.0f;
x_thread[i + 2] = x[i + 2] / 16.0f;
x_thread[i + 3] = x[i + 3] / 4.0f;
}
}
else if (bits == 8) {
for (int i = 0; i < values_per_thread; i++) {
sum += x[i];
@@ -51,8 +76,8 @@ inline U load_vector(const device T* x, thread U* x_thread) {
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
U sum = 0;
@@ -64,8 +89,21 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
x_thread[i + 2] = x[i + 2] / 16.0f;
x_thread[i + 3] = x[i + 3] / 64.0f;
}
for (int i = N; i < values_per_thread; i++) {
x_thread[i] = 0;
}
else if (bits == 3) {
for (int i = 0; i < N; i += 8) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
x[i + 6] + x[i + 7];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 8.0f;
x_thread[i + 2] = x[i + 2] / 64.0f;
x_thread[i + 3] = x[i + 3] / 2.0f;
x_thread[i + 4] = x[i + 4] / 16.0f;
x_thread[i + 5] = x[i + 5] / 128.0f;
x_thread[i + 6] = x[i + 6] / 4.0f;
x_thread[i + 7] = x[i + 7] / 32.0f;
}
}
@@ -77,8 +115,15 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
x_thread[i + 2] = x[i + 2] / 256.0f;
x_thread[i + 3] = x[i + 3] / 4096.0f;
}
for (int i = N; i < values_per_thread; i++) {
x_thread[i] = 0;
}
else if (bits == 6) {
for (int i = 0; i < N; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 64.0f;
x_thread[i + 2] = x[i + 2] / 16.0f;
x_thread[i + 3] = x[i + 3] / 4.0f;
}
}
@@ -87,9 +132,10 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
sum += x[i];
x_thread[i] = x[i];
}
for (int i = N; i < values_per_thread; i++) {
x_thread[i] = 0;
}
}
for (int i = N; i < values_per_thread; i++) {
x_thread[i] = 0;
}
return sum;
@@ -103,8 +149,8 @@ inline U qdot(
U bias,
U sum) {
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
U accum = 0;
@@ -118,6 +164,26 @@ inline U qdot(
}
}
else if (bits == 3) {
for (int i = 0; i < (values_per_thread / 8); i++) {
x_thread += 8 * i;
w += 3 * i;
accum += (w[0] & 0x07) * x_thread[0];
accum += (w[0] & 0x38) * x_thread[1];
accum += (w[0] & 0xc0) * x_thread[2];
accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);
accum += (w[1] & 0x0e) * x_thread[3];
accum += (w[1] & 0x70) * x_thread[4];
accum += (w[1] & 0x80) * x_thread[5];
accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);
accum += (w[2] & 0x1c) * x_thread[6];
accum += (w[2] & 0xe0) * x_thread[7];
}
}
else if (bits == 4) {
const device uint16_t* ws = (const device uint16_t*)w;
for (int i = 0; i < (values_per_thread / 4); i++) {
@@ -129,6 +195,23 @@ inline U qdot(
}
}
else if (bits == 6) {
for (int i = 0; i < (values_per_thread / 4); i++) {
x_thread += 4 * i;
w += 3 * i;
accum += (w[0] & 0x3f) * x_thread[0];
accum += (w[0] & 0xc0) * x_thread[1];
accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);
accum += (w[1] & 0xf0) * x_thread[2];
accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);
accum += (w[2] & 0xfc) * x_thread[3];
}
}
else if (bits == 8) {
for (int i = 0; i < values_per_thread; i++) {
accum += x_thread[i] * w[i];
@@ -147,8 +230,8 @@ inline U qdot_safe(
U sum,
int N) {
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
U accum = 0;
@@ -162,6 +245,26 @@ inline U qdot_safe(
}
}
else if (bits == 3) {
for (int i = 0; i < (N / 8); i++) {
x_thread += 8 * i;
w += 3 * i;
accum += (w[0] & 0x07) * x_thread[0];
accum += (w[0] & 0x38) * x_thread[1];
accum += (w[0] & 0xc0) * x_thread[2];
accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);
accum += (w[1] & 0x0e) * x_thread[3];
accum += (w[1] & 0x70) * x_thread[4];
accum += (w[1] & 0x80) * x_thread[5];
accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);
accum += (w[2] & 0x1c) * x_thread[6];
accum += (w[2] & 0xe0) * x_thread[7];
}
}
else if (bits == 4) {
const device uint16_t* ws = (const device uint16_t*)w;
for (int i = 0; i < (N / 4); i++) {
@@ -173,6 +276,23 @@ inline U qdot_safe(
}
}
else if (bits == 6) {
for (int i = 0; i < (N / 4); i++) {
x_thread += 4 * i;
w += 3 * i;
accum += (w[0] & 0x3f) * x_thread[0];
accum += (w[0] & 0xc0) * x_thread[1];
accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);
accum += (w[1] & 0xf0) * x_thread[2];
accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);
accum += (w[2] & 0xfc) * x_thread[3];
}
}
else if (bits == 8) {
for (int i = 0; i < N; i++) {
accum += x_thread[i] * w[i];
@@ -186,8 +306,8 @@ template <typename U, int values_per_thread, int bits>
inline void
qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
if (bits == 2) {
U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
@@ -199,12 +319,45 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
}
}
else if (bits == 3) {
for (int i = 0; i < (values_per_thread / 8); i++) {
uint8_t w0 = w[3 * i];
uint8_t w1 = w[3 * i + 1];
uint8_t w2 = w[3 * i + 2];
result[8 * i] += x * ((w0 & 0x7) * scale + bias);
result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias);
result[8 * i + 2] +=
x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias);
result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias);
result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias);
result[8 * i + 5] +=
x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias);
result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias);
result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias);
}
}
else if (bits == 4) {
U s[2] = {scale, scale / 16.0f};
for (int i = 0; i < (values_per_thread / 2); i++) {
result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
}
} else if (bits == 6) {
for (int i = 0; i < (values_per_thread / 4); i++) {
uint8_t w0 = w[3 * i];
uint8_t w1 = w[3 * i + 1];
uint8_t w2 = w[3 * i + 2];
result[4 * i] += x * ((w0 & 0x3f) * scale + bias);
result[4 * i + 1] +=
x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias);
result[4 * i + 2] +=
x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias);
result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias);
}
}
else if (bits == 8) {
@@ -218,8 +371,8 @@ template <typename U, int N, int bits>
inline void
dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
if (bits == 2) {
U s[4] = {
@@ -235,6 +388,22 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
}
}
else if (bits == 3) {
for (int i = 0; i < (N / 8); i++) {
w_local += 8 * i;
w += 3 * i;
w_local[0] = (w[0] & 0x7) * scale + bias;
w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias;
w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias;
w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias;
w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
}
}
else if (bits == 4) {
U s[2] = {scale, scale / static_cast<U>(16.0f)};
for (int i = 0; i < (N / 2); i++) {
@@ -243,6 +412,18 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
}
}
else if (bits == 6) {
for (int i = 0; i < (N / 4); i++) {
w_local += 4 * i;
w += 3 * i;
w_local[0] = (w[0] & 0x3f) * scale + bias;
w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
}
}
else if (bits == 8) {
for (int i = 0; i < N; i++) {
w_local[i] = scale * w[i] + bias;
@@ -267,10 +448,11 @@ struct QuantizedBlockLoader {
group_size % BCOLS == 0,
"The group size should be divisible by the columns");
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
MLX_MTL_CONST short pack_factor = 32 / bits;
MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
MLX_MTL_CONST short bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
MLX_MTL_CONST short n_reads =
(BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
@@ -286,12 +468,12 @@ struct QuantizedBlockLoader {
const short bj;
threadgroup T* dst;
const device uint32_t* src;
const device uint8_t* src;
const device T* scales;
const device T* biases;
QuantizedBlockLoader(
const device uint32_t* src_,
const device uint8_t* src_,
const device T* scales_,
const device T* biases_,
const int src_ld_,
@@ -300,14 +482,16 @@ struct QuantizedBlockLoader {
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(
reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor),
reduction_dim ? BCOLS_PACKED * bytes_per_pack
: BROWS * src_ld * bytes_per_pack / pack_factor),
group_step_cnt(0),
group_stride(BROWS * src_ld / group_size),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(n_reads * thread_idx / BCOLS_PACKED),
bj((n_reads * thread_idx) % BCOLS_PACKED),
dst(dst_ + bi * dst_ld + bj * pack_factor),
src(src_ + bi * src_ld / pack_factor + bj),
src(src_ + bi * src_ld * bytes_per_pack / pack_factor +
bj * bytes_per_pack),
scales(scales_ + bi * src_ld / group_size),
biases(biases_ + bi * src_ld / group_size) {}
@@ -320,7 +504,7 @@ struct QuantizedBlockLoader {
T bias = *biases;
for (int i = 0; i < n_reads; i++) {
dequantize<T, pack_factor, bits>(
(device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
src + i * bytes_per_pack, scale, bias, dst + i * pack_factor);
}
}
@@ -347,7 +531,10 @@ struct QuantizedBlockLoader {
T bias = *biases;
for (int i = 0; i < n_reads; i++) {
dequantize<T, pack_factor, bits>(
(device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
(device uint8_t*)(src + i * bytes_per_pack),
scale,
bias,
dst + i * pack_factor);
}
}
@@ -410,8 +597,7 @@ METAL_FUNC void qmv_quad_impl(
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; row < results_per_quadgroup; row++) {
const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
const device T* sl = scales + row * in_vec_size_g * quads_per_simd;
const device T* bl = biases + row * in_vec_size_g * quads_per_simd;
@@ -442,25 +628,30 @@ METAL_FUNC void qmv_fast_impl(
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int packs_per_thread = bits > 2 ? 2 : 1;
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int packs_per_thread = bits == 2 ? 1 : 2;
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int pack_factor = 32 / bits;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int scale_step_per_thread = group_size / values_per_thread;
const device uint8_t* ws = (const device uint8_t*)w;
typedef float U;
thread U x_thread[values_per_thread];
thread U result[results_per_simdgroup] = {0};
// Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
simd_gid * results_per_simdgroup;
w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
x += tid.y * in_vec_size + simd_lid * values_per_thread;
@@ -470,8 +661,7 @@ METAL_FUNC void qmv_fast_impl(
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; row < results_per_simdgroup; row++) {
const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w);
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g;
@@ -480,7 +670,7 @@ METAL_FUNC void qmv_fast_impl(
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
}
w += block_size / pack_factor;
ws += block_size * bytes_per_pack / pack_factor;
scales += block_size / group_size;
biases += block_size / group_size;
x += block_size;
@@ -506,21 +696,25 @@ METAL_FUNC void qmv_impl(
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int packs_per_thread = 1;
constexpr int pack_factor = 32 / bits;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int scale_step_per_thread = group_size / values_per_thread;
const device uint8_t* ws = (const device uint8_t*)w;
typedef float U;
thread U x_thread[values_per_thread];
thread U result[results_per_simdgroup] = {0};
// Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
simd_gid * results_per_simdgroup;
@@ -533,7 +727,8 @@ METAL_FUNC void qmv_impl(
// In this case we need to properly guard all our reads because there isn't
// even 1 tile in the matrix
if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
ws +=
out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
x += tid.y * in_vec_size + simd_lid * values_per_thread;
@@ -544,8 +739,7 @@ METAL_FUNC void qmv_impl(
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; out_row + row < out_vec_size; row++) {
const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w);
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g;
@@ -555,7 +749,7 @@ METAL_FUNC void qmv_impl(
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
}
w += block_size / pack_factor;
ws += block_size * bytes_per_pack / pack_factor;
scales += block_size / group_size;
biases += block_size / group_size;
x += block_size;
@@ -569,8 +763,7 @@ METAL_FUNC void qmv_impl(
x, x_thread, remaining);
for (int row = 0; out_row + row < out_vec_size; row++) {
const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w);
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g;
@@ -591,7 +784,8 @@ METAL_FUNC void qmv_impl(
// In this case the last tile is moved back to redo some output values
else {
w += used_out_row * in_vec_size_w + simd_lid * packs_per_thread;
ws += used_out_row * in_vec_size_w +
simd_lid * packs_per_thread * bytes_per_pack;
scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
x += tid.y * in_vec_size + simd_lid * values_per_thread;
@@ -602,8 +796,7 @@ METAL_FUNC void qmv_impl(
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; row < results_per_simdgroup; row++) {
const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w);
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g;
@@ -613,7 +806,7 @@ METAL_FUNC void qmv_impl(
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
}
w += block_size / pack_factor;
ws += block_size * bytes_per_pack / pack_factor;
scales += block_size / group_size;
biases += block_size / group_size;
x += block_size;
@@ -627,8 +820,7 @@ METAL_FUNC void qmv_impl(
x, x_thread, remaining);
for (int row = 0; row < results_per_simdgroup; row++) {
const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w);
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g;
@@ -659,14 +851,18 @@ METAL_FUNC void qvm_impl(
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int num_simdgroups = 2;
constexpr int pack_factor = 32 / bits;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
constexpr int tn = 32 / pack_factor;
constexpr int blocksize = SIMD_SIZE;
constexpr int block_size = SIMD_SIZE;
const device uint8_t* ws = (const device uint8_t*)w;
typedef float U;
typedef struct {
uint32_t wi[tn];
uint8_t wi[tn * bytes_per_pack];
} vec_w;
thread vec_w w_local;
@@ -676,11 +872,10 @@ METAL_FUNC void qvm_impl(
thread U x_local = 0;
// Adjust positions
const int out_vec_size_w = out_vec_size / pack_factor;
const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
const int out_vec_size_g = out_vec_size / group_size;
int out_col =
tid.x * (num_simdgroups * pack_factor * tn) + simd_gid * pack_factor * tn;
w += out_col / pack_factor + simd_lid * out_vec_size_w;
int out_col = pack_factor * tn * (tid.x * num_simdgroups + simd_gid);
ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w;
scales += out_col / group_size + simd_lid * out_vec_size_g;
biases += out_col / group_size + simd_lid * out_vec_size_g;
x += tid.y * in_vec_size + simd_lid;
@@ -690,43 +885,42 @@ METAL_FUNC void qvm_impl(
return;
}
// Loop over in_vec in blocks of blocksize
int remaining = in_vec_size % blocksize;
// Loop over in_vec in blocks of block_size
int remaining = in_vec_size % block_size;
if (remaining == 0) {
for (int i = 0; i < in_vec_size; i += blocksize) {
for (int i = 0; i < in_vec_size; i += block_size) {
x_local = *x;
scale = *scales;
bias = *biases;
w_local = *((device vec_w*)w);
w_local = *((device vec_w*)ws);
qouter<U, tn * pack_factor, bits>(
(thread uint8_t*)&w_local, x_local, scale, bias, result);
x += blocksize;
scales += blocksize * out_vec_size_g;
biases += blocksize * out_vec_size_g;
w += blocksize * out_vec_size_w;
x += block_size;
scales += block_size * out_vec_size_g;
biases += block_size * out_vec_size_g;
ws += block_size * out_vec_size_w;
}
} else {
for (int i = blocksize; i < in_vec_size; i += blocksize) {
for (int i = block_size; i < in_vec_size; i += block_size) {
x_local = *x;
scale = *scales;
bias = *biases;
w_local = *((device vec_w*)w);
w_local = *((device vec_w*)ws);
qouter<U, tn * pack_factor, bits>(
(thread uint8_t*)&w_local, x_local, scale, bias, result);
x += blocksize;
scales += blocksize * out_vec_size_g;
biases += blocksize * out_vec_size_g;
w += blocksize * out_vec_size_w;
x += block_size;
scales += block_size * out_vec_size_g;
biases += block_size * out_vec_size_g;
ws += block_size * out_vec_size_w;
}
if (static_cast<int>(simd_lid) < remaining) {
x_local = *x;
scale = *scales;
bias = *biases;
w_local = *((device vec_w*)w);
w_local = *((device vec_w*)ws);
} else {
x_local = 0;
scale = 0;
@@ -781,8 +975,9 @@ METAL_FUNC void qmm_t_impl(
constexpr int WM = 2;
constexpr int WN = 2;
constexpr int pack_factor = 32 / bits;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
// Instantiate the appropriate BlockMMA and Loader
using mma_t = mlx::steel::
@@ -800,13 +995,15 @@ METAL_FUNC void qmm_t_impl(
bits>;
// Set the block
const int K_w = K / pack_factor;
const int K_w = K * bytes_per_pack / pack_factor;
const int K_g = K / group_size;
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
auto wl = (const device uint8_t*)w;
x += y_row * K;
w += y_col * K_w;
wl += y_col * K_w;
scales += y_col * K_g;
biases += y_col * K_g;
y += y_row * N + y_col;
@@ -815,7 +1012,7 @@ METAL_FUNC void qmm_t_impl(
const short num_els = min(BM, M - y_row);
const short num_outs = min(BN, N - y_col);
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
loader_w_t loader_w(w, scales, biases, K, Ws, simd_gid, simd_lid);
loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid);
mma_t mma_op(simd_gid, simd_lid);
if (num_els < BM) {
@@ -857,6 +1054,7 @@ METAL_FUNC void qmm_t_impl(
loader_x.load_unsafe();
loader_w.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
@@ -902,9 +1100,11 @@ METAL_FUNC void qmm_n_impl(
constexpr int WM = 2;
constexpr int WN = 2;
constexpr int pack_factor = 32 / bits;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int BN_padded = (BN + 16 / sizeof(T));
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
// Instantiate the appropriate BlockMMA and Loader
using mma_t = mlx::steel::
@@ -921,11 +1121,13 @@ METAL_FUNC void qmm_n_impl(
group_size,
bits>;
auto wl = (const device uint8_t*)w;
// Set the block
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
x += y_row * K;
w += y_col / pack_factor;
wl += y_col * bytes_per_pack / pack_factor;
scales += y_col / group_size;
biases += y_col / group_size;
y += y_row * N + y_col;
@@ -933,7 +1135,7 @@ METAL_FUNC void qmm_n_impl(
// Make the x loader and mma operation
const short num_els = min(BM, M - y_row);
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
loader_w_t loader_w(w, scales, biases, N, Ws, simd_gid, simd_lid);
loader_w_t loader_w(wl, scales, biases, N, Ws, simd_gid, simd_lid);
mma_t mma_op(simd_gid, simd_lid);
if (num_els < BM) {
@@ -1805,13 +2007,14 @@ template <typename T, const int group_size, const int bits>
uint2 grid_dim [[threads_per_grid]]) {
constexpr T eps = T(1e-7);
constexpr int simd_size = 32;
constexpr int uint8_bits = 8;
constexpr T n_bins = (1 << bits) - 1;
constexpr int packs_per_int = uint8_bits / bits;
constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int values_per_reduce = group_size / simd_size;
constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
constexpr int writes_per_pack =
writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
static_assert(
group_size % simd_size == 0,
@@ -1819,7 +2022,9 @@ template <typename T, const int group_size, const int bits>
size_t offset = index.x + grid_dim.x * size_t(index.y);
size_t in_index = offset * values_per_reduce;
size_t out_index = offset * writes_per_pack;
size_t out_index = power_of_2_bits
? offset * writes_per_pack
: offset * bytes_per_pack / writes_per_reduce;
T w_thread[values_per_reduce];
T w_min = Limits<T>::max;
@@ -1852,7 +2057,9 @@ template <typename T, const int group_size, const int bits>
biases[gindex] = bias;
}
uint8_t output = 0;
// We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t
uint32_t output = 0;
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_reduce; i++) {
uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
@@ -1868,47 +2075,23 @@ template <typename T, const int group_size, const int bits>
output = 0;
} else {
#pragma clang loop unroll(full)
for (int j = 0; j < writes_per_reduce - 1; j++) {
uint8_t sval = simd_shuffle_down(val, j + 1);
output += sval << (bits * (values_per_reduce + j + i));
for (int j = 1; j < writes_per_reduce; j++) {
uint8_t sval = simd_shuffle_down(val, j);
output += sval << (bits * (j * values_per_reduce + i));
}
}
}
if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
out[out_index / writes_per_reduce] = output;
}
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize_scales_biases(
const device T* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
device uint8_t* out [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
constexpr int uint8_bits = 8;
constexpr int packs_per_int = uint8_bits / bits;
constexpr T n_bins = (1 << bits) - 1;
size_t offset = index.x + grid_dim.x * size_t(index.y);
size_t in_index = offset * packs_per_int;
size_t gindex = in_index / group_size;
T scale = scales[gindex];
T bias = biases[gindex];
uint8_t output = 0;
#pragma clang loop unroll(full)
for (int i = 0; i < packs_per_int; i++) {
uint8_t val = min(round((w[in_index + i] - bias) / scale), n_bins);
if (bits == 8) {
output = val;
} else {
output += val << (bits * i);
if (bits == 3 || bits == 6) {
if (in_index % packs_per_int == 0 && out_index % bytes_per_pack == 0) {
out[out_index] = output & 0xff;
out[out_index + 1] = (output & 0xff00) >> 8;
out[out_index + 2] = (output & 0xff0000) >> 16;
}
} else {
if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
out[out_index / writes_per_reduce] = output;
}
}
out[offset] = output;
}
template <typename T, const int group_size, const int bits>
@@ -1919,26 +2102,48 @@ template <typename T, const int group_size, const int bits>
device T* out [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
constexpr int uint8_bits = 8;
constexpr int packs_per_int = uint8_bits / bits;
constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
size_t offset = index.x + grid_dim.x * size_t(index.y);
size_t oindex = offset * packs_per_int;
size_t gindex = oindex / group_size;
T scale = scales[gindex];
T bias = biases[gindex];
uint val = w[offset];
out += oindex;
if (bits == 3) {
w += offset * bytes_per_pack;
out[0] = (w[0] & 0x7) * scale + bias;
out[1] = ((w[0] & 0x38) >> 3) * scale + bias;
out[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
out[3] = ((w[1] & 0xe) >> 1) * scale + bias;
out[4] = ((w[1] & 0x70) >> 4) * scale + bias;
out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
} else if (bits == 6) {
w += offset * bytes_per_pack;
out[0] = (w[0] & 0x3f) * scale + bias;
out[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
out[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
out[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
} else {
uint val = w[offset];
#pragma clang loop unroll(full)
for (int i = 0; i < packs_per_int; i++) {
uint8_t d;
if (bits == 2) {
d = (val >> (bits * i)) & 0x03;
} else if (bits == 4) {
d = (val >> (bits * i)) & 0x0f;
} else if (bits == 8) {
d = val;
for (int i = 0; i < packs_per_int; i++) {
uint8_t d;
if (bits == 2) {
d = (val >> (bits * i)) & 0x03;
} else if (bits == 4) {
d = (val >> (bits * i)) & 0x0f;
} else if (bits == 8) {
d = val;
}
out[i] = scale * d + bias;
}
out[oindex + i] = scale * d + bias;
}
}

View File

@@ -72,7 +72,6 @@
#define instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized(affine_quantize, type, group_size, bits) \
instantiate_quantized(affine_quantize_scales_biases, type, group_size, bits) \
instantiate_quantized(affine_dequantize, type, group_size, bits) \
instantiate_quantized(bs_qmv_fast, type, group_size, bits) \
instantiate_quantized(bs_qmv, type, group_size, bits) \
@@ -116,7 +115,9 @@
#define instantiate_quantized_all() \
instantiate_quantized_groups(2) \
instantiate_quantized_groups(3) \
instantiate_quantized_groups(4) \
instantiate_quantized_groups(6) \
instantiate_quantized_groups(8)
instantiate_quantized_all() // clang-format on

View File

@@ -1,930 +1,11 @@
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
#include "mlx/backend/metal/kernels/sdpa_vector.h"
#include "mlx/backend/metal/kernels/steel/defines.h"
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
using namespace mlx::steel;
template <
typename T,
short BROWS,
short BCOLS,
short dst_ld,
short reduction_dim,
short tgp_size,
short alignment = 1,
short n_reads = (BCOLS * BROWS) / (tgp_size),
short TCOLS = BCOLS / n_reads,
short TROWS = tgp_size / TCOLS>
struct BlockLoaderFA {
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
STEEL_CONST short vec_size = n_reads;
// Leading dimension for src
const int src_ld;
const int tile_stride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
struct alignas(alignment * sizeof(T)) ReadVector {
uint8_t v[sizeof(T) * vec_size];
};
/* Constructor */
METAL_FUNC BlockLoaderFA(
const device T* src_,
const int src_ld_,
threadgroup T* dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj) {}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
*((threadgroup ReadVector*)(&dst[i * dst_ld])) =
*((const device ReadVector*)(&src[i * src_ld]));
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(short2 src_tile_dim) const {
src_tile_dim = src_tile_dim - short2(bj, bi);
// Skip loading if thread has no valid reads
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
return;
}
// Use fast thread memory for bound checks
bool tmp_idx[vec_size];
T tmp_val[vec_size];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
// Make sure tmp_idx only contains valid indices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
}
// Read valid indices into tmp_val
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out uneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
}
// Copy values to threadgroup memory
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = tmp_val[j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
src += tile_stride;
}
METAL_FUNC void next(short n) {
src += n * tile_stride;
}
};
template <bool M_aligned, bool N_aligned, bool K_aligned>
struct LoopAlignment {};
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
short lda_tgp,
short ldb_tgp,
typename AccumType = float,
typename Epilogue = TransformNone<U, AccumType>>
struct BlockMMAFA {
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TM_stride = 8 * WM;
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TN_stride = 8 * WN;
// Warp tile size along M
STEEL_CONST short TM = BM / TM_stride;
// Warp tile size along N
STEEL_CONST short TN = BN / TN_stride;
// Strides of A, B along reduction axis
STEEL_CONST short simd_stride_a = {
transpose_a ? TM_stride : TM_stride * lda_tgp};
STEEL_CONST short simd_stride_b = {
transpose_b ? TN_stride * ldb_tgp : TN_stride};
// Jump between elements
STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1};
STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1};
STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8};
STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp};
// Simdgroup matrices
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
simdgroup_matrix<AccumType, 8, 8>(0)};
// Offsets within threadgroup
const short tm;
const short tn;
short sm;
short sn;
ushort sid;
ushort slid;
short As_offset;
short Bs_offset;
/* Constructor */
METAL_FUNC BlockMMAFA(
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
// Determine thread position in simdgroup matrix
short qid = simd_lane_id / 4;
slid = simd_lane_id;
sid = simd_group_id;
sm = (qid & 4) + (simd_lane_id / 2) % 4;
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
// Determine thread and simdgroup offset
As_offset =
transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp);
Bs_offset =
transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn));
}
/* (BM, BK) X (BK, BN) multiply accumulate function */
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
// Adjust for simdgroup and thread location
As += As_offset;
Bs += Bs_offset;
// Iterate over BK in blocks of 8
STEEL_PRAGMA_UNROLL
for (short kk = 0; kk < BK; kk += 8) {
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup A as simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
Asimd[i].thread_elements()[0] =
static_cast<AccumType>(As[i * simd_stride_a + 0]);
Asimd[i].thread_elements()[1] =
static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
}
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup B as simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
Bsimd[j].thread_elements()[0] =
static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
Bsimd[j].thread_elements()[1] =
static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
}
simdgroup_barrier(mem_flags::mem_none);
// Multiply and accumulate into result simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
short j_serp = (i % 2) ? (TN - 1 - j) : j;
simdgroup_multiply_accumulate(
results[i * TN + j_serp],
Asimd[i],
Bsimd[j_serp],
results[i * TN + j_serp]);
}
}
// Progress to next simdgroup tile
As += tile_stride_a;
Bs += tile_stride_b;
}
}
METAL_FUNC void rescale_output(const threadgroup float* Corrections) {
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
short row = sm + tm + i * TM_stride;
float scale_value = Corrections[row];
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto& accum = results[i * TN + j].thread_elements();
// int offset = (i * TM_stride) * ldc + (j * TN_stride);
accum[0] *= scale_value;
accum[1] *= scale_value;
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(device U* C, const int ldc) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + tn + sn;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldc + (j * TN_stride);
// Apply epilogue
U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
// Write out C
C[offset] = outs[0];
C[offset + 1] = outs[1];
}
}
}
METAL_FUNC void store_result_to_tgp_memory(
threadgroup U* C,
const int ldc,
short2 dst_tile_dims) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn);
dst_tile_dims -= short2(tn + sn, sm + tm);
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldc + (j * TN_stride);
// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
C[offset] = Epilogue::apply(accum[0]);
}
if (j * TN_stride + 1 < dst_tile_dims.x) {
C[offset + 1] = Epilogue::apply(accum[1]);
}
}
}
}
}
METAL_FUNC void
store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn);
dst_tile_dims -= short2(tn + sn, sm + tm);
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldc + (j * TN_stride);
// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
C[offset] = Epilogue::apply(accum[0]);
}
if (j * TN_stride + 1 < dst_tile_dims.x) {
C[offset + 1] = Epilogue::apply(accum[1]);
}
}
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
D += (sm + tm) * ldd + tn + sn;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
U outs[2] = {
epilogue_op.apply(accum[0], C[offset_c]),
epilogue_op.apply(accum[1], C[offset_c + fdc])};
// Write out D
D[offset_d] = outs[0];
D[offset_d + 1] = outs[1];
}
}
}
METAL_FUNC void store_result_safe(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
short2 dst_tile_dims,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
D += (sm + tm) * ldd + tn + sn;
dst_tile_dims -= short2(tn + sn, sm + tm);
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
}
if (j * TN_stride + 1 < dst_tile_dims.x) {
D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
}
}
}
}
}
METAL_FUNC void clear_results() {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
results[i * TN + j] = simdgroup_matrix<AccumType, 8, 8>(0);
}
}
}
};
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_q,
bool transpose_k,
bool transpose_v,
bool MN_aligned,
bool K_aligned,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<U, AccumType>>
struct FastAttentionKernel {
STEEL_CONST short tgp_padding = 16 / sizeof(T);
STEEL_CONST short float_padding = 16 / sizeof(float);
STEEL_CONST short tgp_mem_size_q =
transpose_q ? BK * (BM + tgp_padding) : BM * (BK + tgp_padding);
STEEL_CONST short tgp_mem_size_k =
transpose_k ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding);
STEEL_CONST short tgp_mem_size_v =
transpose_v ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding);
STEEL_CONST short tgp_mem_size_s = BM * (BN + tgp_padding);
// maxes, rowsums, rescale
STEEL_CONST short tgp_mem_size_corrections =
4 * (BM * sizeof(float) + float_padding);
STEEL_CONST bool share_kv_smem = transpose_k != transpose_v;
STEEL_CONST short tgp_mem_size = share_kv_smem
? tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s +
tgp_mem_size_corrections
: tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s +
tgp_mem_size_corrections + tgp_mem_size_v;
STEEL_CONST short tgp_size = WM * WN * 32;
static_assert(transpose_q == false, "Expected Q not transposed.");
static_assert(transpose_k == true, "Expected K transposed.");
static_assert(transpose_v == false, "Expected V not transposed.");
static_assert(tgp_mem_size <= 32768, "Excessive tgp memory requested.");
using loader_q_t = BlockLoaderFA<
T,
transpose_q ? BK : BM,
transpose_q ? BM : BK,
transpose_q ? BM + tgp_padding : BK + tgp_padding,
!transpose_q,
tgp_size>;
using loader_k_t = BlockLoaderFA<
T,
transpose_k ? BN : BK,
transpose_k ? BK : BN,
transpose_k ? BK + tgp_padding : BN + tgp_padding,
transpose_k,
tgp_size>;
using loader_v_t = BlockLoaderFA<
T,
transpose_v ? BK : BN,
transpose_v ? BN : BK,
transpose_v ? BN + tgp_padding : BK + tgp_padding,
transpose_v,
tgp_size>;
using mma_qk_t = BlockMMAFA<
T,
U,
BM,
BN,
BK,
WM,
WN,
transpose_q,
transpose_k,
transpose_q ? BM + tgp_padding : BK + tgp_padding,
transpose_k ? BK + tgp_padding : BN + tgp_padding,
AccumType,
Epilogue>;
using mma_sv_t = BlockMMAFA<
T,
U,
BM,
BK,
BN,
WM,
WN,
false,
transpose_v,
BN + tgp_padding,
BK + tgp_padding,
AccumType,
Epilogue>;
/* Main kernel function */
template <bool M_aligned, bool N_aligned, bool K_aligned_>
static METAL_FUNC void gemm_loop(
threadgroup T* As [[threadgroup(0)]],
threadgroup T* Bs [[threadgroup(1)]],
const int gemm_k_iterations,
thread loader_k_t& loader_b,
thread mma_qk_t& mma_op,
thread const short& tgp_bm,
thread const short& tgp_bn,
LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
// Appease the compiler
(void)l;
(void)tgp_bm;
short2 tile_dims_B = transpose_k ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
// not valid for gemm_k_iterations > 1 (so, BK == d_k)
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (N_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(tile_dims_B);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
}
}
static METAL_FUNC void initialize_corrections(
threadgroup float* C,
uint simd_lane_id,
uint simd_group_id) {
if (simd_group_id == 0) {
threadgroup float* maxes = C;
threadgroup float* sums = C + (BM + float_padding);
threadgroup float* o_rescale = sums + (BM + float_padding);
threadgroup float* output_rescale = o_rescale + (BM + float_padding);
if (simd_lane_id < BM) {
maxes[simd_lane_id] = -INFINITY; // m_i
sums[simd_lane_id] = 0.f; // l_i
o_rescale[simd_lane_id] = 1.f; // li * exp(mi - mi_new)
output_rescale[simd_lane_id] = 1.f; // 1.0 / l_i
}
}
}
static METAL_FUNC void rescale_ss(
threadgroup T* Ss,
threadgroup float* Corrections,
uint simd_group_id,
uint simd_lane_id,
short2 local_blocks,
float alpha) {
if (simd_group_id == 0) {
short row_offset = BM + float_padding;
threadgroup float* maxes = Corrections;
threadgroup float* sums = Corrections + row_offset;
threadgroup float* o_rescale = sums + row_offset;
threadgroup float* output_scales = o_rescale + row_offset;
if (simd_lane_id < uint(local_blocks.y)) {
float m_i_old = maxes[simd_lane_id];
float l_i_old = sums[simd_lane_id];
float m_i_new = m_i_old;
float l_i_new = l_i_old;
short offset = simd_lane_id * (BN + tgp_padding);
float m_ij = -INFINITY;
for (short j = 0; j < local_blocks.x; j++) {
float val = alpha * float(Ss[offset + j]);
m_ij = max(m_ij, val);
}
m_i_new = max(m_ij, m_i_new);
float rowsum = 0.f; // lij
for (short j = 0; j < local_blocks.x; j++) {
float val = alpha * float(Ss[offset + j]);
float P_i_j = exp(val - m_ij);
rowsum += P_i_j;
P_i_j = P_i_j * exp(m_ij - m_i_new);
Ss[offset + j] = T(P_i_j);
}
l_i_new =
exp(m_i_old - m_i_new) * l_i_old + exp(m_ij - m_i_new) * rowsum;
maxes[simd_lane_id] = m_i_new;
sums[simd_lane_id] = l_i_new;
float rescale = l_i_old * exp(m_i_old - m_i_new);
o_rescale[simd_lane_id] = rescale;
output_scales[simd_lane_id] = 1.0 / l_i_new;
}
}
}
/* Main kernel function */
static METAL_FUNC void run(
const device T* Q [[buffer(0)]],
const device T* K [[buffer(1)]],
const device T* V [[buffer(2)]],
device U* O [[buffer(3)]],
const constant MLXFastAttentionParams* params [[buffer(4)]],
threadgroup T* Qs [[threadgroup(0)]],
threadgroup T* Ks [[threadgroup(1)]],
threadgroup T* Ss [[threadgroup(2)]],
threadgroup T* Vs [[threadgroup(3)]],
threadgroup float* Corrections [[threadgroup(4)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Pacifying compiler
(void)lid;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
threadgroup_barrier(mem_flags::mem_none);
// Find block in Q, O; and head in K, V.
const int c_row = tid_y * BM;
Q += transpose_q ? c_row : c_row * params->ldq;
thread loader_q_t loader_q(Q, params->ldq, Qs, simd_group_id, simd_lane_id);
short tgp_bm = min(BM, params->M - c_row);
short2 tile_dims_Q = transpose_q ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
loader_q.load_safe(tile_dims_Q);
initialize_corrections(Corrections, simd_lane_id, simd_group_id);
O += c_row * params->ldo;
// Prepare threadgroup mma operation
thread mma_qk_t mma_qk_op(simd_group_id, simd_lane_id);
thread mma_sv_t mma_softmax_sv_op(simd_group_id, simd_lane_id);
thread loader_k_t loader_k(K, params->ldk, Ks, simd_group_id, simd_lane_id);
thread loader_v_t loader_v(V, params->ldv, Vs, simd_group_id, simd_lane_id);
for (short n_block = 0; n_block < params->gemm_n_iterations_aligned;
n_block++) {
short c_col = BN;
// Prepare threadgroup loading operations
short gemm_k_iterations = params->gemm_k_iterations_aligned;
short tgp_bn_qk = min(BN, params->N - c_col * n_block);
threadgroup_barrier(mem_flags::mem_none);
///////////////////////////////////////////////////////////////////////////////
{ // Loop over K - unaligned case
if (tgp_bm == BM && tgp_bn_qk == BN) {
gemm_loop<true, true, K_aligned>(
Qs,
Ks,
gemm_k_iterations,
loader_k,
mma_qk_op,
tgp_bm,
tgp_bn_qk);
} else if (tgp_bn_qk == BN) {
gemm_loop<false, true, K_aligned>(
Qs,
Ks,
gemm_k_iterations,
loader_k,
mma_qk_op,
tgp_bm,
tgp_bn_qk);
} else if (tgp_bm == BM) {
gemm_loop<true, false, K_aligned>(
Qs,
Ks,
gemm_k_iterations,
loader_k,
mma_qk_op,
tgp_bm,
tgp_bn_qk);
} else {
gemm_loop<false, false, K_aligned>(
Qs,
Ks,
gemm_k_iterations,
loader_k,
mma_qk_op,
tgp_bm,
tgp_bn_qk);
}
}
mma_qk_op.store_result_to_tgp_memory(
Ss, BN + tgp_padding, short2(BN, BM));
threadgroup_barrier(mem_flags::mem_threadgroup);
rescale_ss(
Ss,
Corrections,
simd_group_id,
simd_lane_id,
short2(tgp_bn_qk, tgp_bm),
params->alpha);
loader_v.load_safe(short2(BK, tgp_bn_qk));
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup float* o_scales = Corrections + 2 * (BM + float_padding);
mma_softmax_sv_op.rescale_output(o_scales);
mma_softmax_sv_op.mma(Ss, Vs);
threadgroup float* final_output_scales =
Corrections + 3 * (BM + float_padding);
mma_softmax_sv_op.rescale_output(final_output_scales);
loader_v.next();
loader_k.next(BN);
mma_qk_op.clear_results();
}
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_softmax_sv_op.store_result_safe(O, params->ldo, short2(BK, tgp_bm));
}
};
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_q,
bool transpose_k,
bool transpose_v,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void attention(
const device T* Q [[buffer(0)]],
const device T* K [[buffer(1)]],
const device T* V [[buffer(2)]],
device T* O [[buffer(3)]],
const constant MLXFastAttentionParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using attention_kernel = FastAttentionKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_q,
transpose_k,
transpose_v,
MN_aligned,
K_aligned>;
// Adjust for batch
if (params->batch_ndim > 1) {
const constant size_t* Q_bstrides = batch_strides;
const constant size_t* KV_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, Q_bstrides, KV_bstrides, params->batch_ndim);
Q += batch_offsets.x;
K += batch_offsets.y;
V += batch_offsets.y;
} else {
Q += params->batch_stride_q * tid.z;
K += params->batch_stride_k * tid.z;
V += params->batch_stride_v * tid.z;
}
// same shape as input
O += params->batch_stride_o * tid.z;
threadgroup T Qs[attention_kernel::tgp_mem_size_q];
threadgroup T Ss[attention_kernel::tgp_mem_size_s];
threadgroup float Corrections[attention_kernel::tgp_mem_size_corrections];
if (attention_kernel::share_kv_smem) {
threadgroup T Ks[attention_kernel::tgp_mem_size_k];
threadgroup T* Vs = Ks; //[attention_kernel::tgp_mem_size_v];
attention_kernel::run(
Q,
K,
V,
O,
params,
Qs,
Ks,
Ss,
Vs,
Corrections,
simd_lane_id,
simd_group_id,
tid,
lid);
} else {
threadgroup T Ks[attention_kernel::tgp_mem_size_k];
threadgroup T Vs[attention_kernel::tgp_mem_size_v];
attention_kernel::run(
Q,
K,
V,
O,
params,
Qs,
Ks,
Ss,
Vs,
Corrections,
simd_lane_id,
simd_group_id,
tid,
lid);
}
}
// clang-format off
// SDPA full instantiations
#define instantiate_fast_inference_self_attention_kernel( \
itype, otype, bm, bn, bk, wm, wn) \
template [[host_name("steel_gemm_attention_bm_" #bm "_bn_" #bn "_bk_" #bk \
"_itype_" #itype)]] [[kernel]] void \
attention<itype, bm, bn, bk, wm, wn, false, true, false, false, true>( \
const device itype* Q [[buffer(0)]], \
const device itype* K [[buffer(1)]], \
const device itype* V [[buffer(2)]], \
device otype* O [[buffer(3)]], \
const constant MLXFastAttentionParams* params [[buffer(4)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
instantiate_fast_inference_self_attention_kernel(
float,
float,
16,
16,
64,
2,
2);
instantiate_fast_inference_self_attention_kernel(
float,
float,
16,
16,
128,
2,
2);
instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 64, 2, 2);
instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2);
// SDPA vector instantiations
#define instantiate_sdpa_vector(type, head_dim) \
instantiate_kernel("sdpa_vector_" #type "_" #head_dim, sdpa_vector, type, head_dim) \

View File

@@ -1,42 +0,0 @@
//
// scaled_dot_product_attention_params.h
// mlx
#pragma once
struct MLXFastAttentionParams {
const int M;
const int N;
const int K;
const int ldq; // ldq == ldo
const int ldk;
const int ldv;
const int lds;
const int ldo;
const int tiles_n;
const int tiles_m;
const int batch_stride_q;
const int batch_stride_k;
const int batch_stride_v;
const int batch_stride_o;
const int swizzle_log;
const int gemm_n_iterations_aligned;
const int gemm_k_iterations_aligned;
const int gemm_sv_m_block_iterations;
const int batch_ndim;
const float alpha;
};
struct MLXScaledDotProductAttentionParams {
// Associated dimensions & transposition information
const uint QUERY_SEQUENCE_LENGTH = 1;
const uint N_Q_HEADS = 32;
const uint N_KV_HEADS = 32;
const uint KV_TILES = 1;
const float INV_ALPHA = 0.08838834764831843f;
};

View File

@@ -0,0 +1,296 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/attn/loader.h"
#include "mlx/backend/metal/kernels/steel/attn/mma.h"
#include "mlx/backend/metal/kernels/steel/attn/params.h"
#include "mlx/backend/metal/kernels/steel/attn/transforms.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel class
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <bool M_aligned, bool N_aligned, bool K_aligned>
struct LoopAlignment {};
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<U, AccumType>>
struct GEMMKernel {
STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
STEEL_CONST short tgp_padding_b = 16 / sizeof(T);
STEEL_CONST short tgp_mem_size_a =
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
STEEL_CONST short tgp_mem_size_b =
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
STEEL_CONST short tgp_size = WM * WN * 32;
using loader_a_t = BlockLoader<
T,
transpose_a ? BK : BM,
transpose_a ? BM : BK,
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
!transpose_a,
tgp_size>;
using loader_b_t = BlockLoader<
T,
transpose_b ? BN : BK,
transpose_b ? BK : BN,
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
transpose_b,
tgp_size>;
using mma_t = BlockMMA<
T,
U,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
AccumType,
Epilogue>;
/* Main kernel function */
template <bool M_aligned, bool N_aligned, bool K_aligned_>
static METAL_FUNC void gemm_loop(
threadgroup T* As [[threadgroup(0)]],
threadgroup T* Bs [[threadgroup(1)]],
const int gemm_k_iterations,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
thread mma_t& mma_op,
thread const short& tgp_bm,
thread const short& tgp_bn,
thread const short& lbk,
LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
// Appease the compiler
(void)l;
short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
if (M_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(tile_dims_A);
}
if (N_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(tile_dims_B);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
if (!K_aligned_) {
threadgroup_barrier(mem_flags::mem_threadgroup);
short2 tile_dims_A_last =
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
short2 tile_dims_B_last =
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
loader_a.load_safe(tile_dims_A_last);
loader_b.load_safe(tile_dims_B_last);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
}
/* Main kernel function */
static METAL_FUNC void run(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device U* D [[buffer(2)]],
const constant GEMMParams* params [[buffer(3)]],
threadgroup T* As [[threadgroup(0)]],
threadgroup T* Bs [[threadgroup(1)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Pacifying compiler
(void)lid;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Loop tail
if (!K_aligned) {
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
// Store results to device memory
mma_op.store_result(D, params->ldd);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
if (tgp_bm == BM && tgp_bn == BN) {
gemm_loop<true, true, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result(D, params->ldd);
return;
} else if (tgp_bn == BN) {
gemm_loop<false, true, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
return;
} else if (tgp_bm == BM) {
gemm_loop<true, false, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
return;
} else {
gemm_loop<false, false, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
return;
}
}
}
};
} // namespace steel
} // namespace mlx

View File

@@ -0,0 +1,349 @@
// Copyright © 2024 Apple Inc.
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
constant bool align_Q [[function_constant(200)]];
constant bool align_K [[function_constant(201)]];
template <typename T>
struct TransformScale {
T scale;
METAL_FUNC TransformScale(T scale_) : scale(scale_) {}
METAL_FUNC T apply(T x) const {
return scale * x;
}
};
struct MaxOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return metal::max(x, y);
}
};
struct SumOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return x + y;
}
};
struct MulOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return x * y;
}
};
struct SubOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return x - y;
}
};
struct ExpSubOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return fast::exp(x - y);
}
};
struct DivOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return x / y;
}
};
// clang-format off
template <
typename T,
int BQ,
int BK,
int BD,
int WM,
int WN,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention(
const device T* Q [[buffer(0)]],
const device T* K [[buffer(1)]],
const device T* V [[buffer(2)]],
device T* O [[buffer(3)]],
const constant AttnParams* params [[buffer(4)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
// Pacifying compiler
(void)lid;
// Move to correct block
ulong3 tidl{tid.x, tid.y, tid.z};
Q += tidl.z * params->Q_strides[0] + // Batch
tidl.y * params->Q_strides[1] + // Head
tidl.x * BQ * params->Q_strides[2]; // Seqeunce
ulong kv_head_idx = int(tid.y) / params->gqa_factor;
K += tidl.z * params->K_strides[0] + // Batch
kv_head_idx * params->K_strides[1]; // Head
V += tidl.z * params->V_strides[0] + // Batch
kv_head_idx * params->V_strides[1]; // Head
O += tidl.z * params->O_strides[0] + // Batch
tidl.y * params->O_strides[1] + // Head
tidl.x * BQ * params->O_strides[2]; // Seqeunce
// Prepare threadgroup memory
constexpr short padQ = 0; // 16 / sizeof(T);
constexpr short padK = 0; // 16 / sizeof(T);
constexpr short padV = 0; // 16 / sizeof(T);
constexpr short LDQ_tgp = BD + padQ;
constexpr short LDK_tgp = BK + padK;
constexpr short LDV_tgp = BD + padV;
threadgroup T Qs[BQ * (BD + padQ)];
threadgroup T Ks[(BK + padK) * BD];
threadgroup T Vs[BK * (BD + padV)];
// Prepare block loaders
using QBlockLoader = BlockLoaderT<
/* typename T = */ T,
/* short BROWS = */ BQ,
/* short BCOLS = */ BD,
/* short kDstStrRow = */ LDQ_tgp,
/* short kDstStrCol = */ 1,
/* short reduction_dim = */ 1,
/* short tgp_size = */ WM * WN * 32>;
// K is loaded in transposed
using KBlockLoader = BlockLoaderT<
/* typename T = */ T,
/* short BROWS = */ BK,
/* short BCOLS = */ BD,
/* short kDstStrRow = */ 1,
/* short kDstStrCol = */ LDK_tgp,
/* short reduction_dim = */ 0,
/* short tgp_size = */ WM * WN * 32>;
using VBlockLoader = BlockLoaderT<
/* typename T = */ T,
/* short BROWS = */ BK,
/* short BCOLS = */ BD,
/* short kDstStrRow = */ LDV_tgp,
/* short kDstStrCol = */ 1,
/* short reduction_dim = */ 0,
/* short tgp_size = */ WM * WN * 32>;
QBlockLoader loader_q(
Q, params->Q_strides[2], Qs, simd_group_id, simd_lane_id);
KBlockLoader loader_k(
K, params->K_strides[2], Ks, simd_group_id, simd_lane_id);
VBlockLoader loader_v(
V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
TransformScale<T> ts(static_cast<T>(params->scale));
// Prepare MMA tiles
constexpr short kFragSize = 8; // MMAFrag size
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
constexpr int kNWarps = WM * WN;
static_assert(
BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0,
"Each simdgroup must host atleast 1 simdgroup matrix along Q sequence.");
// Q seq frags per warp
constexpr int TQ = BQ / (kNWarps * kFragSize);
// KV sequence frags (all warps load the same frags)
constexpr int TK = BK / kFragSize;
// HeadDim frags (all warps load the same frags)
constexpr int TD = BD / kFragSize;
static_assert(TQ == 1, "Check TQ");
MMATile<AccumType, TQ, 1, MMAFrag_acc_t> Qtile;
MMATile<AccumType, 1, TK, MMAFrag_acc_t> Ktile;
MMATile<AccumType, TQ, TK, MMAFrag_acc_t> Stile;
MMATile<AccumType, TK, TD, MMAFrag_acc_t> Vtile;
MMATile<AccumType, TQ, TD, MMAFrag_acc_t> Otile;
Otile.clear();
// Prepare mma tile offsets
const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
const short sm = simd_coord.y;
const short sn = simd_coord.x;
const short tm = kFragSize * TQ * simd_group_id;
const short Qs_offset = (tm + sm) * LDQ_tgp + sn;
const short Ks_offset = sm * LDK_tgp + sn;
const short Vs_offset = sm * LDV_tgp + sn;
constexpr short Qs_tile_stride = kFragSize;
constexpr short Ks_tile_stride = kFragSize * LDK_tgp;
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load Q blocks apply scale
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
loader_q.load_safe(short2(BD, params->qL - params->NQ_aligned * BQ));
} else {
loader_q.load_unsafe();
}
loader_q.apply_inplace_op(ts);
// Init row reduction variables
constexpr short kRowsPT = decltype(Stile)::kRowsPerThread;
AccumType max_score[kRowsPT];
AccumType sum_score[kRowsPT] = {0};
// Init to -Inf
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = Limits<AccumType>::min;
}
// Loop over KV seq length
for (int kb = 0; kb < params->NK; kb++) {
// Load K block and apply scale
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!align_K && kb == (params->NK_aligned)) {
loader_k.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
} else {
loader_k.load_unsafe();
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do S = Q @ K.T
Stile.clear();
for (short dd = 0; dd < TD; dd++) {
simdgroup_barrier(mem_flags::mem_none);
Qtile.template load<T, 1, 1, LDQ_tgp, 1>(
&Qs[Qs_offset + dd * Qs_tile_stride]);
Ktile.template load<T, 1, 1, LDK_tgp, 1>(
&Ks[Ks_offset + dd * Ks_tile_stride]);
simdgroup_barrier(mem_flags::mem_none);
tile_matmad(Stile, Qtile, Ktile, Stile);
}
// Mask out of length sequence
if (!align_K && kb == (params->NK_aligned)) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
const short lim = params->kL - params->NK_aligned * BK;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < stile_t::kTileCols; j++) {
short col_pos = sn + (j * stile_t::kFragCols);
STEEL_PRAGMA_UNROLL
for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
if ((col_pos + jj) >= lim) {
Stile.frag_at(i, j)[jj] = neg_inf;
}
}
}
}
}
simdgroup_barrier(mem_flags::mem_none);
// Load V blocks
if (!align_K && kb == (params->NK_aligned)) {
loader_v.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
} else {
loader_v.load_unsafe();
}
// Do softmax
// Temp variables
AccumType new_max[kRowsPT];
AccumType factor[kRowsPT];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
new_max[i] = max_score[i];
}
// Row max
Stile.template row_reduce<MaxOp>(new_max);
// exp(Si - rowmax(Si))
Stile.template row_bin_op<ExpSubOp>(new_max);
// Factor exp(rowmax(Si) - rowmax(Si-1))
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
factor[i] = fast::exp(max_score[i] - new_max[i]);
}
// Save max for next iteration
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = new_max[i];
}
// Row Sum
AccumType sum_score_tmp[kRowsPT] = {0};
Stile.template row_reduce<SumOp>(sum_score_tmp);
// Update norm
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i];
}
// Update O
Otile.template row_bin_op<MulOp>(factor);
// Load V into registers
threadgroup_barrier(mem_flags::mem_threadgroup);
Vtile.template load<T, 1, 1, LDV_tgp, 1>(&Vs[Vs_offset]);
simdgroup_barrier(mem_flags::mem_none);
// Do O = S @ V
tile_matmad(Otile, Stile, Vtile, Otile);
// Prepare for next iteration
loader_k.next();
loader_v.next();
}
// Normalize output
Otile.template row_bin_op<DivOp>(sum_score);
threadgroup_barrier(mem_flags::mem_none);
// Store results
O += (tm + sm) * params->O_strides[2] + sn;
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
auto dst_tile_dims =
short2(BD - sn, params->qL - BQ * params->NQ_aligned - (tm + sm));
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
Otile.template store_safe<T, 1, 1>(O, params->O_strides[2], dst_tile_dims);
} else {
Otile.template store<T, 1, 1>(O, params->O_strides[2]);
}
}

View File

@@ -0,0 +1,31 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/attn/attn.h"
#include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h"
#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn) \
template [[host_name("steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd "_wm" #wm "_wn" #wn)]] \
[[kernel]] void attention<dtype, bq, bk, bd, wm, wn, float>( \
const device dtype* Q [[buffer(0)]], \
const device dtype* K [[buffer(1)]], \
const device dtype* V [[buffer(2)]], \
device dtype* O [[buffer(3)]],\
const constant AttnParams* params [[buffer(4)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_attn_shapes_helper(iname, itype) \
instantiate_attn(iname, itype, 32, 16, 128, 4, 1) \
instantiate_attn(iname, itype, 32, 32, 80, 4, 1) \
instantiate_attn(iname, itype, 32, 32, 64, 4, 1)
instantiate_attn_shapes_helper(float16, half);
instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
instantiate_attn_shapes_helper(float32, float);
// clang-format on

View File

@@ -0,0 +1,264 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/defines.h"
///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <
typename T,
short BROWS,
short BCOLS,
short dst_ld,
short reduction_dim,
short tgp_size,
short alignment = 1,
short n_reads = (BCOLS * BROWS) / (tgp_size),
short TCOLS = BCOLS / n_reads,
short TROWS = tgp_size / TCOLS>
struct BlockLoader {
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
STEEL_CONST short vec_size = n_reads;
// Leading dimension for src
const int src_ld;
const int tile_stride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
struct alignas(alignment * sizeof(T)) ReadVector {
uint8_t v[sizeof(T) * vec_size];
};
/* Constructor */
METAL_FUNC BlockLoader(
const device T* src_,
const int src_ld_,
threadgroup T* dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj) {}
/* Apply operation to threadgroup without bound checking */
template <typename UnaryOp>
METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]);
}
}
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
*((threadgroup ReadVector*)(&dst[i * dst_ld])) =
*((const device ReadVector*)(&src[i * src_ld]));
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(short2 src_tile_dim) const {
src_tile_dim = src_tile_dim - short2(bj, bi);
// Skip loading if thread has no valid reads
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
return;
}
// Use fast thread memory for bound checks
bool tmp_idx[vec_size];
T tmp_val[vec_size];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
// Make sure tmp_idx only contains valid indices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
}
// Read valid indices into tmp_val
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out uneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
}
// Copy values to threadgroup memory
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = tmp_val[j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
src += tile_stride;
}
};
template <int R, int C>
struct CShape {
STEEL_CONST int kRows = R;
STEEL_CONST int kCols = C;
};
template <
typename T,
short BROWS,
short BCOLS,
short kDstStrRow,
short kDstStrCol,
short reduction_dim,
short tgp_size,
short n_reads = (BCOLS * BROWS) / (tgp_size),
short TCOLS = BCOLS / n_reads,
short TROWS = tgp_size / TCOLS>
struct BlockLoaderT {
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
STEEL_CONST short vec_size = n_reads;
// Leading dimension for src
const int src_ld;
const int tile_stride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
/* Constructor */
METAL_FUNC BlockLoaderT(
const device T* src_,
const int src_ld_,
threadgroup T* dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * kDstStrRow + bj * kDstStrCol),
src(src_ + bi * src_ld + bj) {}
/* Apply operation to threadgroup without bound checking */
template <typename UnaryOp>
METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] =
op.apply(dst[i * kDstStrRow + j * kDstStrCol]);
}
}
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] = src[i * src_ld + j];
}
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(short2 src_tile_dim) const {
src_tile_dim = src_tile_dim - short2(bj, bi);
// Skip loading if thread has no valid reads
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] = T(0);
}
}
return;
}
// Use fast thread memory for bound checks
bool tmp_idx[vec_size];
T tmp_val[vec_size];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
// Make sure tmp_idx only contains valid indices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
}
// Read valid indices into tmp_val
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out uneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
}
// Copy values to threadgroup memory
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] = tmp_val[j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
src += tile_stride;
}
};
} // namespace steel
} // namespace mlx

View File

@@ -0,0 +1,726 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/attn/transforms.h"
#include "mlx/backend/metal/kernels/steel/defines.h"
#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <typename RInt, typename CInt>
struct Shape2D {
RInt r;
CInt c;
Shape2D(RInt r_, CInt c_) : r(r_), c(c_) {}
};
template <typename Shape, typename Layout>
struct Layout2D {
Shape shape;
Layout layout;
};
template <typename T, int kFragRows_, int kFragCols_>
struct BaseMMAFrag {
static_assert(
kFragRows_ == 8,
"Only 8 x 8 fragment matrices are currently supported");
static_assert(
kFragCols_ == 8,
"Only 8 x 8 fragment matrices are currently supported");
};
template <typename T>
struct BaseMMAFrag<T, 8, 8> {
STEEL_CONST int kFragRows = 8;
STEEL_CONST int kFragCols = 8;
STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32;
STEEL_CONST int kElemRows = 1;
STEEL_CONST int kElemCols = 2;
static_assert(
kElemRows * kElemCols == kElemsPerFrag,
"MMAFrag shape is not consistent with MMAFrag size");
typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
typedef metal::vec<T, kElemsPerFrag> frag_type;
typedef metal::vec<T, kElemRows> row_frag_type;
typedef metal::vec<T, kElemCols> col_frag_type;
METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
[[thread_index_in_simdgroup]]) {
const short qid = simd_lane_id / 4;
const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
return short2{fn, fm};
}
template <typename SrcPtrType, typename StrX, typename StrY>
METAL_FUNC static constexpr void
load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
dst[i * kElemCols + j] = static_cast<T>(src[i * str_x + j * str_y]);
}
}
}
template <
typename SrcPtrType,
typename StrX,
typename StrY,
typename LimX,
typename LimY,
typename OffX,
typename OffY>
METAL_FUNC static constexpr void load_safe(
thread frag_type& dst,
SrcPtrType src,
StrX str_x,
StrY str_y,
LimX lim_x,
LimY lim_y,
OffX off_x = Int<0>{},
OffY off_y = Int<0>{}) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
dst[i * kElemCols + j] =
static_cast<T>(src[(off_x + i) * str_x + (off_x + j) * str_y]);
} else {
dst[i * kElemCols + j] = T(0);
}
}
}
}
template <typename DstPtrType, typename StrX, typename StrY>
METAL_FUNC static constexpr void
store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) {
using U = pointer_element_t<DstPtrType>;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
dst[i * str_x + j * str_y] = static_cast<U>(src[i * kElemCols + j]);
}
}
}
template <
typename DstPtrType,
typename StrX,
typename StrY,
typename LimX,
typename LimY,
typename OffX,
typename OffY>
METAL_FUNC static constexpr void store_safe(
const thread frag_type& src,
DstPtrType dst,
StrX str_x,
StrY str_y,
LimX lim_x,
LimY lim_y,
OffX off_x = Int<0>{},
OffY off_y = Int<0>{}) {
using U = pointer_element_t<DstPtrType>;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
dst[(off_x + i) * str_x + (off_y + j) * str_y] =
static_cast<U>(src[i * kElemCols + j]);
}
}
}
}
METAL_FUNC static constexpr void mma(
thread frag_type& D,
thread frag_type& A,
thread frag_type& B,
thread frag_type& C) {
mat_type D_mat;
mat_type A_mat;
mat_type B_mat;
mat_type C_mat;
reinterpret_cast<thread frag_type&>(A_mat.thread_elements()) = A;
reinterpret_cast<thread frag_type&>(B_mat.thread_elements()) = B;
reinterpret_cast<thread frag_type&>(C_mat.thread_elements()) = C;
mma(D_mat, A_mat, B_mat, C_mat);
D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());
}
METAL_FUNC static constexpr void mma(
thread mat_type& D,
thread mat_type& A,
thread mat_type& B,
thread mat_type& C) {
simdgroup_multiply_accumulate(D, A, B, C);
}
template <typename Op>
METAL_FUNC static constexpr void row_reduce(
thread const frag_type& inp_vals,
thread T* reduced_vals) {
T thr_reduce = Op::apply(inp_vals.x, inp_vals.y);
T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1));
qgr_reduce = Op::apply(thr_reduce, qgr_reduce);
T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8));
sgr_reduce = Op::apply(qgr_reduce, sgr_reduce);
reduced_vals[0] = Op::apply(reduced_vals[0], sgr_reduce);
}
template <typename Op>
METAL_FUNC static constexpr void row_bin_op(
thread frag_type& inp_vals,
thread T* row_vals) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
inp_vals[i * kElemCols + j] =
Op::apply(inp_vals[i * kElemCols + j], row_vals[i]);
}
}
}
};
template <
typename T,
int kTileRows_,
int kTileCols_,
class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
struct MMATile {
using MMAFrag_t = MMAFrag_;
using elem_type = T;
STEEL_CONST int kFragRows = MMAFrag_t::kFragRows;
STEEL_CONST int kFragCols = MMAFrag_t::kFragCols;
STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag;
STEEL_CONST int kTileRows = kTileRows_;
STEEL_CONST int kTileCols = kTileCols_;
STEEL_CONST int kRows = kTileRows * kFragRows;
STEEL_CONST int kCols = kTileCols * kFragCols;
STEEL_CONST int kNumFrags = kTileRows * kTileCols;
STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;
STEEL_CONST int kRowsPerThread = kTileRows * MMAFrag_t::kElemRows;
STEEL_CONST int kColsPerThread = kTileCols * MMAFrag_t::kElemCols;
typedef typename MMAFrag_t::mat_type mat_type;
typedef typename MMAFrag_t::frag_type frag_type;
frag_type val_frags[kNumFrags] = {frag_type(0)};
METAL_FUNC MMATile() thread {}
METAL_FUNC constexpr void clear() {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kNumFrags; ++i) {
val_frags[i] = frag_type(0);
}
}
METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {
return val_frags[i * kTileCols + j];
}
METAL_FUNC constexpr const thread frag_type& frag_at(
const short i,
const short j) const {
return val_frags[i * kTileCols + j];
}
METAL_FUNC mat_type mat_at(const short i, const short j) {
mat_type val_mat;
STEEL_PRAGMA_UNROLL
for (short ii = 0; ii < kElemsPerFrag; ++ii) {
val_mat.thread_elements()[ii] = frag_at(i, j)[ii];
}
return val_mat;
}
METAL_FUNC thread elem_type* elems() {
return reinterpret_cast<thread elem_type*>(val_frags);
}
METAL_FUNC const thread elem_type* elems() const {
return reinterpret_cast<const thread elem_type*>(val_frags);
}
template <typename Op>
METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::template row_reduce<Op>(
frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
}
}
}
template <typename Op>
METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread]) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::template row_bin_op<Op>(
frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
}
}
}
template <typename U, int w_x, int w_y, int str_x, int str_y>
METAL_FUNC void load(const threadgroup U* src) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::load(
frag_at(i, j),
&(
src[(i * kFragRows) * w_x * str_x +
(j * kFragCols) * w_y * str_y]),
Int<str_x>{},
Int<str_y>{});
}
}
}
template <typename U, int w_x, int w_y, int str_x, int str_y>
METAL_FUNC void store(threadgroup U* dst) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::store(
frag_at(i, j),
&(
dst[(i * kFragRows) * w_x * str_x +
(j * kFragCols) * w_y * str_y]),
Int<str_x>{},
Int<str_y>{});
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void load(const device U* src, const int ld) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::load(
frag_at(i, j),
&(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
ld,
Int<1>{});
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void store(device U* dst, const int ld) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::store(
frag_at(i, j),
&(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
ld,
Int<1>{});
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void
load_safe(const device U* src, const int ld, const short2 src_tile_dims) {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < kTileCols; ++j) {
MMAFrag_t::load_safe(
frag_at(i, j),
src,
ld,
Int<1>{},
src_tile_dims.y,
src_tile_dims.x,
(i * kFragRows) * w_x,
(j * kFragCols) * w_y);
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void
store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < kTileCols; ++j) {
MMAFrag_t::store_safe(
frag_at(i, j),
dst,
ld,
Int<1>{},
dst_tile_dims.y,
dst_tile_dims.x,
(i * kFragRows) * w_x,
(j * kFragCols) * w_y);
}
}
}
};
template <typename T, typename U, int M, int N, int K>
METAL_FUNC void tile_matmad(
thread MMATile<T, M, N>& D,
thread MMATile<U, M, K>& A,
thread MMATile<U, K, N>& B,
thread MMATile<T, M, N>& C) {
STEEL_PRAGMA_UNROLL
for (short k = 0; k < K; ++k) {
STEEL_PRAGMA_UNROLL
for (short m = 0; m < M; ++m) {
STEEL_PRAGMA_UNROLL
for (short n = 0; n < N; ++n) {
short n_serp = (m % 2) ? (N - 1 - n) : n;
MMATile<T, M, N>::MMAFrag_t::mma(
D.frag_at(m, n_serp),
A.frag_at(m, k),
B.frag_at(k, n_serp),
C.frag_at(m, n_serp));
}
}
}
}
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
short lda_tgp,
short ldb_tgp,
typename AccumType = float,
typename Epilogue = TransformNone<U, AccumType>>
struct BlockMMA {
// MMAFrag size
STEEL_CONST short kFragSize = 8;
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TM_stride = kFragSize * WM;
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TN_stride = kFragSize * WN;
// Warp tile size along M
STEEL_CONST short TM = BM / TM_stride;
// Warp tile size along N
STEEL_CONST short TN = BN / TN_stride;
// Threadgroup A strides
STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K
// Threadgroup B strides
STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N
// Threadgroup strides along K
STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
STEEL_CONST short tile_stride_b = kFragSize * B_str_k;
// Simdgroup matrices
MMATile<AccumType, TM, 1, MMAFrag_acc_t> Atile;
MMATile<AccumType, 1, TN, MMAFrag_acc_t> Btile;
MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile;
// Offsets within threadgroup
short sm;
short sn;
short As_offset;
short Bs_offset;
/* Constructor */
METAL_FUNC BlockMMA(
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]]) {
// Determine thread position in simdgroup matrix
short tm = kFragSize * (simd_group_id / WN);
short tn = kFragSize * (simd_group_id % WN);
short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
sm = simd_coord.y;
sn = simd_coord.x;
// Determine thread and simdgroup offset
As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K
Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N
sm += tm;
sn += tn;
}
/* (BM, BK) X (BK, BN) multiply accumulate function */
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
// Adjust for simdgroup and thread location
As += As_offset;
Bs += Bs_offset;
// Iterate over BK in blocks of kFragSize
STEEL_PRAGMA_UNROLL
for (short kk = 0; kk < BK; kk += kFragSize) {
simdgroup_barrier(mem_flags::mem_none);
Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);
simdgroup_barrier(mem_flags::mem_none);
Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);
simdgroup_barrier(mem_flags::mem_none);
tile_matmad(Ctile, Atile, Btile, Ctile);
// Progress to next simdgroup tile
As += tile_stride_a;
Bs += tile_stride_b;
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(device U* D, const int ldd) {
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
}
// Adjust for simdgroup and thread location
D += sm * ldd + sn;
Ctile.template store<U, WM, WN>(D, ldd);
}
METAL_FUNC void
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
}
// Adjust for simdgroup and thread location
D += sm * ldd + sn;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims);
}
/* Apply epilogue */
template <typename UnaryEpilogue>
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]);
}
}
/* Apply epilogue */
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue(
const device U* C,
const int ldc,
const int fdc,
thread const BinaryEpilogue& epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) {
accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
}
}
}
}
/* Apply epilogue */
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue_safe(
const device U* C,
const int ldc,
const int fdc,
short2 dst_tile_dims,
thread const BinaryEpilogue& epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
// Read C
U c_elems[kelems] = {0};
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x) {
c_elems[k] = C[offset_c + k * fdc];
}
}
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
accum[k] = epilogue_op.apply(accum[k], c_elems[k]);
}
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
D += (sm)*ldd + sn;
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
}
}
}
}
METAL_FUNC void store_result_safe(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
short2 dst_tile_dims,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
D += (sm)*ldd + sn;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x) {
D[offset_d + k] =
epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
}
}
}
}
}
}
};
} // namespace steel
} // namespace mlx

View File

@@ -0,0 +1,36 @@
// Copyright © 2024 Apple Inc.
#pragma once
///////////////////////////////////////////////////////////////////////////////
// Attn param classes
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
struct AttnParams {
int B; ///< Batch Size
int H; ///< Heads
int D; ///< Head Dim
int qL; ///< Query Sequence Length
int kL; ///< Key Sequence Length
int gqa_factor; ///< Group Query factor
float scale; ///< Attention scale
int NQ; ///< Number of query blocks
int NK; ///< Number of key/value blocks
int NQ_aligned; ///< Number of full query blocks
int NK_aligned; ///< Number of full key/value blocks
size_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
size_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
size_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
size_t O_strides[3]; ///< Output strides (B, H, L, D = 1)
};
} // namespace steel
} // namespace mlx

View File

@@ -0,0 +1,71 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/utils.h"
///////////////////////////////////////////////////////////////////////////////
// Transforms and Epilogues
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <typename OutT, typename InT>
struct TransformNone {
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
static METAL_FUNC OutT apply(InT x, OutT) {
return static_cast<OutT>(x);
}
};
template <typename OutT, typename InT>
struct TransformAdd {
TransformAdd(const float, const float) {}
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
static METAL_FUNC OutT apply(InT x, OutT c) {
return static_cast<OutT>(x) + c;
}
};
template <typename OutT, typename InT>
struct TransformAxpby {
const float alpha;
const float beta;
TransformAxpby(const float alpha_, const float beta_)
: alpha(alpha_), beta(beta_) {}
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
METAL_FUNC OutT apply(InT x, OutT c) const {
return static_cast<OutT>(x * alpha + (beta * c));
}
};
template <typename T>
struct AccumHelper {
typedef float accum_type;
};
struct BlockSwizzle {
static METAL_FUNC int2
swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
const int tid_x = (tid.x) >> swizzle_log;
const int tid_y =
((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));
return int2(tid_x, tid_y);
}
};
} // namespace steel
} // namespace mlx

View File

@@ -385,9 +385,9 @@ struct BlockMMA {
STEEL_CONST short TN_stride = kFragSize * WN;
// Warp tile size along M
STEEL_CONST short TM = BM / TM_stride;
STEEL_CONST short TM = BM / (kFragSize * WM);
// Warp tile size along N
STEEL_CONST short TN = BN / TN_stride;
STEEL_CONST short TN = BN / (kFragSize * WN);
// Threadgroup A strides
STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M

View File

@@ -10,6 +10,7 @@
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
@@ -298,7 +299,7 @@ void qmm_op(
bool quad = false;
if (transpose) {
if (B < 6 && (D == 128 || D == 64)) {
if (B < 6 && (D == 128 || D == 64) && is_power_of_2(bits)) {
name += "qmv_quad";
constexpr int quads_per_simd = 8;
constexpr int results_per_quadgroup = 8;
@@ -391,8 +392,6 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
void fast::AffineQuantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
bool compute_scale_bias = inputs.size() == 1;
auto& w_pre = inputs[0];
auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
@@ -415,7 +414,7 @@ void fast::AffineQuantize::eval_gpu(
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_input_array(w, 0);
if (!compute_scale_bias) {
if (dequantize_) {
auto& scales_pre = inputs[1];
auto& biases_pre = inputs[2];
auto scales = ensure_row_contiguous(scales_pre);
@@ -436,12 +435,7 @@ void fast::AffineQuantize::eval_gpu(
std::ostringstream kname;
auto type_string = dequantize_ ? get_type_string(out.dtype())
: get_type_string(w_pre.dtype());
auto kernel_func = "affine_quantize_scales_biases";
if (dequantize_) {
kernel_func = "affine_dequantize";
} else if (compute_scale_bias) {
kernel_func = "affine_quantize";
}
auto kernel_func = dequantize_ ? "affine_dequantize" : "affine_quantize";
kname << kernel_func << "_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
auto template_def = get_template_definition(
@@ -452,10 +446,10 @@ void fast::AffineQuantize::eval_gpu(
// Treat uint32 as uint8 in kernel
constexpr int uint8_per_uint32 = 4;
constexpr int simd_size = 32;
int packs_per_int = 8 / bits_;
int per_thread = compute_scale_bias ? group_size_ / simd_size : packs_per_int;
int packs_per_int = bits_ == 3 ? 8 : bits_ == 6 ? 4 : 8 / bits_;
int per_thread = dequantize_ ? packs_per_int : group_size_ / simd_size;
size_t nthreads =
dequantize_ ? w.size() * uint8_per_uint32 : w.size() / per_thread;
dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {

View File

@@ -6,7 +6,9 @@
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
#include "mlx/backend/metal/kernels/steel/attn/params.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/utils.h"
@@ -19,122 +21,89 @@ void sdpa_full_self_attention_metal(
const array& q,
const array& k,
const array& v,
const float alpha,
array& out) {
std::ostringstream kname_self_attention;
kname_self_attention << "steel_gemm_attention_";
const float scale,
array& o) {
using namespace mlx::steel;
constexpr const int bm = 16;
constexpr const int bn = 16;
const int bk = q.shape(-1); // already forced to be 64 or 128
int wm = 4;
int wn = 1;
if (bk != 64 && bk != 128) {
throw std::runtime_error(
"[ScaledDotProductAttention::eval_gpu]: hidden dim: expected either 64, 128");
}
int bd = q.shape(-1);
int bq = 32;
int bk = bd < 128 ? 32 : 16;
constexpr const int wm = 2;
constexpr const int wn = 2;
int B = q.shape(0);
int H = q.shape(1);
int D = q.shape(3);
int gqa_factor = q.shape(1) / k.shape(1);
std::string delimiter = "_";
int qL = q.shape(2);
int kL = k.shape(2);
kname_self_attention << "bm_" + std::to_string(bm) + delimiter;
kname_self_attention << "bn_" + std::to_string(bn) + delimiter;
kname_self_attention << "bk_" + std::to_string(bk) + delimiter;
const bool align_Q = (qL % bq) == 0;
const bool align_K = (kL % bk) == 0;
for (const auto& arr : {k, v, out}) {
if (arr.dtype() != q.dtype()) {
throw std::runtime_error(
"[ScaledDotProductAttention::eval_gpu]: expected matching dtypes for q,k,v,o");
}
}
metal::MTLFCList func_consts = {
{&align_Q, MTL::DataType::DataTypeBool, 200},
{&align_K, MTL::DataType::DataTypeBool, 201},
};
if (q.dtype() == float32) {
kname_self_attention << "itype" + delimiter + "float";
} else if (q.dtype() == float16) {
kname_self_attention << "itype" + delimiter + "half";
} else {
throw std::runtime_error(
"[ScaledDotProductAttention::eval_gpu]: unexpected dtype found for queries: expected either float32 or float16.");
}
std::ostringstream kname;
// clang-format off
kname << "steel_attention_"
<< type_to_name(q)
<< "_bq" << bq
<< "_bk" << bk
<< "_bd" << bd
<< "_wm" << wm << "_wn" << wn; // clang-format on
std::string base_name = kname.str();
// clang-format off
kname << "_align_Q_" << (align_Q ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname_self_attention.str());
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);
uint hidden_dim = q.shape(-1);
uint qseq = q.shape(-2);
uint qheads = q.shape(-3);
const int NQ = (qL + bq - 1) / bq;
const int NK = (kL + bk - 1) / bk;
const uint64_t KV_sequence_length = k.shape(-2);
const uint query_sequence_length = q.shape(-2);
const uint n_q_heads = q.shape(1);
const uint n_kv_heads = k.shape(1);
const int NQ_aligned = qL / bq;
const int NK_aligned = kL / bk;
const int M = q.shape(-2);
const int N = M;
const int K = q.shape(-1);
const size_t batch_size_out = q.shape(0) * q.shape(1);
AttnParams params{
/* int B = */ B,
/* int H = */ H,
/* int D = */ D,
const std::vector<int> batch_shape = {q.shape(0) * q.shape(1)};
const int dk = q.shape(-1);
const int ldq = dk;
const int ldk = dk;
const int ldv = dk;
const int lds = bn;
const int ldo = dk;
/* int qL = */ qL,
/* int kL = */ kL,
int tn = 1;
int tm = (M + bm - 1) / bm;
/* int gqa_factor = */ gqa_factor,
/* float scale = */ scale,
const int batch_stride_q = dk * query_sequence_length;
const int batch_stride_k = dk * query_sequence_length;
const int batch_stride_v = dk * query_sequence_length;
const int batch_stride_o = dk * query_sequence_length;
const int swizzle_log = 0;
const int gemm_n_iterations_aligned = (N + bn - 1) / bn;
const int gemm_k_iterations_aligned = (K + bk - 1) / bk;
const int gemm_sv_m_block_iterations = (M + bm - 1) / bm;
const int batch_ndim = int(batch_shape.size());
/* int NQ = */ NQ,
/* int NK = */ NK,
MLXFastAttentionParams params{
(int)M,
(int)N,
(int)K,
ldq,
ldk,
ldv,
lds,
ldo,
tn,
tm,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_o,
swizzle_log,
gemm_n_iterations_aligned,
gemm_k_iterations_aligned,
gemm_sv_m_block_iterations,
batch_ndim,
alpha};
/* int NQ_aligned = */ NQ_aligned,
/* int NK_aligned = */ NK_aligned,
const std::vector<size_t> batch_strides = {
(size_t)batch_stride_q,
(size_t)batch_stride_k,
(size_t)batch_stride_v,
(size_t)batch_stride_o};
/* size_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
/* size_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
/* size_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
/* size_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};
compute_encoder.set_input_array(q, 0);
compute_encoder.set_input_array(k, 1);
compute_encoder.set_input_array(v, 2);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_output_array(o, 3);
compute_encoder.set_bytes(params, 4);
compute_encoder.set_vector_bytes(batch_shape, 6);
compute_encoder.set_vector_bytes(batch_strides, 7);
MTL::Size grid_dims = MTL::Size(1, tm, batch_size_out);
MTL::Size grid_dims = MTL::Size(NQ, H, B);
MTL::Size group_dims = MTL::Size(32, wm, wn);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
@@ -356,7 +325,24 @@ void ScaledDotProductAttention::eval_gpu(
const auto& q = copy_unless(is_matrix_contiguous, q_pre);
const auto& k = copy_unless(is_matrix_contiguous, k_pre);
const auto& v = copy_unless(is_matrix_contiguous, v_pre);
o.set_data(allocator::malloc_or_wait(o.nbytes()));
size_t str_oD = 1;
size_t str_oH = o.shape(3);
size_t str_oL = o.shape(1) * str_oH;
size_t str_oB = o.shape(2) * str_oL;
size_t data_size = o.shape(0) * str_oB;
array::Flags flags{
/* bool contiguous = */ 1,
/* bool row_contiguous = */ 0,
/* bool col_contiguous = */ 0,
};
o.set_data(
allocator::malloc_or_wait(o.nbytes()),
data_size,
{str_oB, str_oH, str_oL, str_oD},
flags);
sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
}

View File

@@ -600,7 +600,7 @@ array scaled_dot_product_attention(
* * dtype is not fp32 or fp16
*/
int threshold = 1e6;
int threshold = 32; // TODO: Fix after dev
if (memory_efficient_threshold.has_value()) {
threshold = std::max(1, memory_efficient_threshold.value());
}
@@ -644,11 +644,10 @@ array scaled_dot_product_attention(
const bool sdpa_vector_supported_head_dim =
query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128;
const bool sdpa_full_supported_head_dim =
query_head_dim == 64 || query_head_dim == 128;
query_head_dim == 64 || query_head_dim == 80;
const bool supports_sdpa_full = query_sequence_length >= threshold &&
!mask.has_value() && sdpa_full_supported_head_dim &&
n_q_heads == n_kv_heads && final_type != bfloat16 &&
stream.device == Device::gpu;
const bool supports_sdpa_vector = query_sequence_length == 1 &&
@@ -686,13 +685,11 @@ array pack_and_quantize(
array& packed_w,
const array& scales,
const array& biases,
int group_size,
int bits,
const Stream& s) {
int el_per_int = 32 / bits;
array zero(0, packed_w.dtype());
array n_bins((1 << bits) - 1, packed_w.dtype()); // 2**bits - 1
array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
packed_w = astype(
clip(
round(divide(subtract(packed_w, biases, s), scales, s), s),
@@ -701,9 +698,30 @@ array pack_and_quantize(
s),
uint32,
s);
packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
packed_w = sum(
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
if (is_power_of_2(bits)) {
array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
packed_w = sum(
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
} else {
// This is slow but we have fast GPU/CPU versions of this function so we
// shouldn't be here often.
packed_w = expand_dims(packed_w, /* axis= */ -1, s);
packed_w = bitwise_and(
right_shift(packed_w, arange(bits, uint32, s), s),
array({1}, uint32),
s);
auto new_shape = packed_w.shape();
new_shape[new_shape.size() - 2] = -1;
new_shape.back() = 32;
packed_w = reshape(packed_w, new_shape, s);
array shifts = arange(32, uint32, s);
packed_w =
sum(left_shift(packed_w, shifts, s),
/* axis= */ -1,
/* keepdims= */ false,
s);
}
return packed_w;
}
@@ -718,10 +736,10 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
throw std::invalid_argument(msg.str());
}
if (bits != 2 && bits != 4 && bits != 8) {
if (bits != 2 && bits != 3 && bits != 4 && bits != 6 && bits != 8) {
std::ostringstream msg;
msg << "[quantize] The requested number of bits " << bits
<< " is not supported. The supported bits are 2, 4 and 8.";
<< " is not supported. The supported bits are 2, 3, 4, 6 and 8.";
throw std::invalid_argument(msg.str());
}
@@ -740,9 +758,7 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
throw std::invalid_argument(msg.str());
}
int el_per_int = 32 / bits;
auto fallback = [group_size, bits, el_per_int, s](
auto fallback = [group_size, bits, s](
const std::vector<array>& inputs) -> std::vector<array> {
auto& w = inputs[0];
auto wshape = w.shape();
@@ -765,7 +781,7 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
array biases = where(equal(q0, zero, s), zero, edge, s);
packed_w = pack_and_quantize(packed_w, scales, biases, group_size, bits, s);
packed_w = pack_and_quantize(packed_w, scales, biases, bits, s);
return {
reshape(packed_w, wshape, s),
reshape(scales, wshape, s),
@@ -774,7 +790,7 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
};
auto wq_shape = w.shape();
wq_shape.back() = w.shape(-1) / el_per_int;
wq_shape.back() = w.shape(-1) * bits / 32;
auto sshape = w.shape();
sshape.back() = w.shape(-1) / group_size;
auto outputs = array::make_arrays(
@@ -785,39 +801,6 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
return {outputs[0], outputs[1], outputs[2]};
}
array affine_quantize(
const array& w,
const array& scales,
const array& biases,
int group_size,
int bits,
StreamOrDevice s_) {
auto s = to_stream(s_);
int el_per_int = 32 / bits;
auto fallback = [group_size, bits, el_per_int, s](
const std::vector<array>& inputs) -> std::vector<array> {
auto& w = inputs[0];
auto scales = expand_dims(inputs[1], -1, s);
auto biases = expand_dims(inputs[2], -1, s);
auto wshape = w.shape();
wshape.back() = -1;
array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
packed_w = pack_and_quantize(packed_w, scales, biases, group_size, bits, s);
return {reshape(packed_w, wshape, s)};
};
auto out_shape = w.shape();
out_shape.back() = w.shape(-1) / el_per_int;
return array(
std::move(out_shape),
uint32,
std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
{w, scales, biases});
}
array affine_dequantize(
const array& w,
const array& scales,
@@ -860,9 +843,9 @@ array affine_dequantize(
}
// Packing into uint32
int el_per_int = 32 / bits;
int out_size = w.shape(-1) * 32 / bits;
if (w.shape(-1) * el_per_int != scales.shape(-1) * group_size) {
if (out_size != scales.shape(-1) * group_size) {
std::ostringstream msg;
msg << "[dequantize] Shape of scales and biases does not match the matrix "
<< "given the quantization parameters. Provided matrix of shape "
@@ -873,40 +856,52 @@ array affine_dequantize(
auto s = to_stream(s_);
auto fallback =
[&wshape, &sshape, &scales, &biases, group_size, bits, el_per_int, s](
const std::vector<array>& inputs) -> std::vector<array> {
auto& w = inputs[0];
auto fallback = [&wshape, &sshape, &scales, &biases, group_size, bits, s](
const std::vector<array>& inputs) -> std::vector<array> {
auto w = inputs[0];
auto& scales = inputs[1];
auto& biases = inputs[2];
std::vector<array> parts;
for (int start = 0; start < 32; start += bits) {
int shift_left = 32 - (start + bits);
int shift_right = shift_left + start;
if (is_power_of_2(bits)) {
std::vector<array> parts;
for (int start = 0; start < 32; start += bits) {
int shift_left = 32 - (start + bits);
int shift_right = shift_left + start;
parts.push_back(expand_dims(
right_shift(
left_shift(w, array(32 - (start + bits), uint32), s),
array(32 - bits, uint32),
s),
-1,
s));
parts.push_back(expand_dims(
right_shift(
left_shift(w, array(32 - (start + bits), uint32), s),
array(32 - bits, uint32),
s),
-1,
s));
}
w = concatenate(parts, -1, s);
} else {
w = expand_dims(w, /* axis= */ -1, s);
w = bitwise_and(
right_shift(w, arange(32, uint32, s), s), array({1}, uint32), s);
auto new_shape = w.shape();
new_shape[new_shape.size() - 2] = -1;
new_shape.back() = bits;
w = reshape(w, new_shape, s);
array shifts = arange(bits, uint32, s);
w = sum(
left_shift(w, shifts, s), /* axis= */ -1, /* keepdims= */ false, s);
}
array w_full = concatenate(parts, -1, s);
// Dequantize
wshape.push_back(group_size);
w_full = reshape(w_full, wshape, s);
w_full = multiply(w_full, expand_dims(scales, -1, s), s);
w_full = add(w_full, expand_dims(biases, -1, s), s);
w_full = reshape(w_full, sshape, s);
w = reshape(w, wshape, s);
w = multiply(w, expand_dims(scales, -1, s), s);
w = add(w, expand_dims(biases, -1, s), s);
w = reshape(w, sshape, s);
return {w_full};
return {w};
};
if (s.device == Device::gpu) {
auto out_shape = w.shape();
out_shape.back() = w.shape(-1) * el_per_int;
out_shape.back() = out_size;
return array(
std::move(out_shape),
scales.dtype(),

View File

@@ -47,14 +47,6 @@ std::tuple<array, array, array> affine_quantize(
int bits = 4,
StreamOrDevice s = {});
array affine_quantize(
const array& w,
const array& scales,
const array& biases,
int group_size = 64,
int bits = 4,
StreamOrDevice s = {});
array affine_dequantize(
const array& w,
const array& scales,

View File

@@ -3683,7 +3683,7 @@ std::tuple<array, array, array> quantize(
int group_size /* = 64 */,
int bits /* = 4 */,
StreamOrDevice s /* = {} */) {
return fast::affine_quantize(w, group_size, bits);
return fast::affine_quantize(w, group_size, bits, s);
}
array dequantize(

View File

@@ -185,16 +185,8 @@ class _Pool3d(_Pool):
class MaxPool1d(_Pool1d):
r"""Applies 1-dimensional max pooling.
Assuming an input of shape :math:`(N, L, C)` and ``kernel_size`` is
:math:`k`, the output is a tensor of shape :math:`(N, L_{out}, C)`, given
by:
.. math::
\text{out}(N_i, t, C_j) = \max_{m=0, \ldots, k - 1}
\text{input}(N_i, \text{stride} \times t + m, C_j),
where :math:`L_{out} = \left\lfloor \frac{L + 2 \times \text{padding} -
\text{kernel\_size}}{\text{stride}}\right\rfloor + 1`.
Spatially downsamples the input by taking the maximum of a sliding window
of size ``kernel_size`` and sliding stride ``stride``.
Args:
kernel_size (int or tuple(int)): The size of the pooling window kernel.
@@ -224,16 +216,8 @@ class MaxPool1d(_Pool1d):
class AvgPool1d(_Pool1d):
r"""Applies 1-dimensional average pooling.
Assuming an input of shape :math:`(N, L, C)` and ``kernel_size`` is
:math:`k`, the output is a tensor of shape :math:`(N, L_{out}, C)`, given
by:
.. math::
\text{out}(N_i, t, C_j) = \frac{1}{k} \sum_{m=0, \ldots, k - 1}
\text{input}(N_i, \text{stride} \times t + m, C_j),
where :math:`L_{out} = \left\lfloor \frac{L + 2 \times \text{padding} -
\text{kernel\_size}}{\text{stride}}\right\rfloor + 1`.
Spatially downsamples the input by taking the average of a sliding window
of size ``kernel_size`` and sliding stride ``stride``.
Args:
kernel_size (int or tuple(int)): The size of the pooling window kernel.
@@ -263,26 +247,15 @@ class AvgPool1d(_Pool1d):
class MaxPool2d(_Pool2d):
r"""Applies 2-dimensional max pooling.
Assuming an input of shape :math:`(N, H, W, C)` and ``kernel_size`` is
:math:`(k_H, k_W)`, the output is a tensor of shape :math:`(N, H_{out},
W_{out}, C)`, given by:
Spatially downsamples the input by taking the maximum of a sliding window
of size ``kernel_size`` and sliding stride ``stride``.
.. math::
\begin{aligned}
\text{out}(N_i, h, w, C_j) = & \max_{m=0, \ldots, k_H-1} \max_{n=0, \ldots, k_W-1} \\
& \text{input}(N_i, \text{stride[0]} \times h + m,
\text{stride[1]} \times w + n, C_j),
\end{aligned}
The parameters ``kernel_size``, ``stride``, and ``padding`` can either be:
where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
:math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`.
The parameters ``kernel_size``, ``stride``, ``padding``, can either be:
- a single ``int`` -- in which case the same value is used for both the
height and width axis;
- a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is
used for the height axis, the second ``int`` for the width axis.
* a single ``int`` -- in which case the same value is used for both the
height and width axis.
* a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is
used for the height axis, the second ``int`` for the width axis.
Args:
kernel_size (int or tuple(int, int)): The size of the pooling window.
@@ -312,26 +285,15 @@ class MaxPool2d(_Pool2d):
class AvgPool2d(_Pool2d):
r"""Applies 2-dimensional average pooling.
Assuming an input of shape :math:`(N, H, W, C)` and ``kernel_size`` is
:math:`(k_H, k_W)`, the output is a tensor of shape :math:`(N, H_{out},
W_{out}, C)`, given by:
Spatially downsamples the input by taking the average of a sliding window
of size ``kernel_size`` and sliding stride ``stride``.
.. math::
\begin{aligned}
\text{out}(N_i, h, w, C_j) = & \frac{1}{k_H k_W} \sum_{m=0, \ldots, k_H-1} \sum_{n=0, \ldots, k_W-1} \\
& \text{input}(N_i, \text{stride[0]} \times h + m,
\text{stride[1]} \times w + n, C_j),
\end{aligned}
The parameters ``kernel_size``, ``stride``, and ``padding`` can either be:
where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
:math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`.
The parameters ``kernel_size``, ``stride``, ``padding``, can either be:
- a single ``int`` -- in which case the same value is used for both the
height and width axis;
- a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is
used for the height axis, the second ``int`` for the width axis.
* a single ``int`` -- in which case the same value is used for both the
height and width axis.
* a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is
used for the height axis, the second ``int`` for the width axis.
Args:
kernel_size (int or tuple(int, int)): The size of the pooling window.
@@ -359,30 +321,18 @@ class AvgPool2d(_Pool2d):
class MaxPool3d(_Pool3d):
"""
Assuming an input of shape :math:`(N, D, H, W, C)` and ``kernel_size`` is
:math:`(k_D, k_H, k_W)`, the output is a tensor of shape :math:`(N, D_{out},
H_{out}, W_{out}, C)`, given by:
r"""Applies 3-dimensional max pooling.
.. math::
\begin{aligned}
\text{out}(N_i, d, h, w, C_j) = & \max_{l=0, \ldots, k_D-1} \max_{m=0, \ldots, k_H-1} \max_{n=0, \ldots, k_W-1} \\
& \text{input}(N_i, \text{stride[0]} \times d + l,
\text{stride[1]} \times h + m,
\text{stride[2]} \times w + n, C_j),
\end{aligned}
Spatially downsamples the input by taking the maximum of a sliding window
of size ``kernel_size`` and sliding stride ``stride``.
where :math:`D_{out} = \left\lfloor\frac{D + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
:math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`,
:math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[2]} - \text{kernel\_size[2]}}{\text{stride[2]}}\right\rfloor + 1`.
The parameters ``kernel_size``, ``stride``, and ``padding`` can either be:
The parameters ``kernel_size``, ``stride``, ``padding``, can either be:
- a single ``int`` -- in which case the same value is used for the depth,
height and width axis;
- a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
for the depth axis, the second ``int`` for the height axis, and the third
``int`` for the width axis.
* a single ``int`` -- in which case the same value is used for the depth,
height, and width axis.
* a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
for the depth axis, the second ``int`` for the height axis, and the third
``int`` for the width axis.
Args:
kernel_size (int or tuple(int, int, int)): The size of the pooling window.
@@ -410,32 +360,20 @@ class MaxPool3d(_Pool3d):
class AvgPool3d(_Pool3d):
"""
Assuming an input of shape :math:`(N, D, H, W, C)` and ``kernel_size`` is
:math:`(k_D, k_H, k_W)`, the output is a tensor of shape :math:`(N, D_{out},
H_{out}, W_{out}, C)`, given by:
r"""Applies 3-dimensional average pooling.
.. math::
\begin{aligned}
\text{out}(N_i, d, h, w, C_j) = & \frac{1}{k_D k_H k_W} \sum_{l=0, \ldots, k_D-1} \sum_{m=0, \ldots, k_H-1} \sum_{n=0, \ldots, k_W-1} \\
& \text{input}(N_i, \text{stride[0]} \times d + l,
\text{stride[1]} \times h + m,
\text{stride[2]} \times w + n, C_j),
\end{aligned}
Spatially downsamples the input by taking the average of a sliding window
of size ``kernel_size`` and sliding stride ``stride``.
where :math:`D_{out} = \left\lfloor\frac{D + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
:math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`,
:math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[2]} - \text{kernel\_size[2]}}{\text{stride[2]}}\right\rfloor + 1`.
The parameters ``kernel_size``, ``stride``, and ``padding`` can either be:
The parameters ``kernel_size``, ``stride``, ``padding``, can either be:
* a single ``int`` -- in which case the same value is used for the depth,
height, and width axis.
* a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
for the depth axis, the second ``int`` for the height axis, and the third
``int`` for the width axis.
- a single ``int`` -- in which case the same value is used for the depth,
height and width axis;
- a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
for the depth axis, the second ``int`` for the height axis, and the third
``int`` for the width axis.
Args:
Args:
kernel_size (int or tuple(int, int, int)): The size of the pooling window.
stride (int or tuple(int, int, int), optional): The stride of the pooling
window. Default: ``kernel_size``.
@@ -443,7 +381,7 @@ class AvgPool3d(_Pool3d):
padding to apply to the input. The padding is applied on both sides
of the depth, height and width axis. Default: ``0``.
Examples:
Examples:
>>> import mlx.core as mx
>>> import mlx.nn.layers as nn
>>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))

View File

@@ -161,49 +161,6 @@ void init_fast(nb::module_& parent_module) {
array: The output array.
)pbdoc");
m.def(
"affine_quantize",
nb::overload_cast<
const array&,
const array&,
const array&,
int,
int,
StreamOrDevice>(&fast::affine_quantize),
"w"_a,
"scales"_a,
"biases"_a,
"group_size"_a = 64,
"bits"_a = 4,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def affine_quantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Quantize the matrix ``w`` using the provided ``scales`` and
``biases`` and the ``group_size`` and ``bits`` configuration.
Formally, given the notation in :func:`quantize`, we compute
:math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` and
:math:`\beta` as follows
.. math::
w_i = s (\hat{w_i} + \beta)
Args:
w (array): Matrix to be quantize
scales (array): The scales to use per ``group_size`` elements of ``w``
biases (array): The biases to use per ``group_size`` elements of ``w``
group_size (int, optional): The size of the group in ``w`` that shares a
scale and bias. (default: ``64``)
bits (int, optional): The number of bits occupied by each element in
``w``. (default: ``4``)
Returns:
array: The quantized version of ``w``
)pbdoc");
m.def(
"metal_kernel",
[](const std::string& name,

View File

@@ -549,18 +549,6 @@ class TestFast(mlx_tests.MLXTestCase):
)(x)
self.assertTrue(mx.allclose(vmap_out, vmap_fast_out))
def test_affine_quantize(self):
mx.random.seed(7)
x = mx.random.uniform(shape=(4, 1024))
for bits in (2, 4, 8):
for group_size in (32, 64, 128):
with self.subTest(bits=bits, group_size=group_size):
w, scales, biases = mx.quantize(x, bits=bits, group_size=group_size)
w_p = mx.fast.affine_quantize(
x, scales, biases, bits=bits, group_size=group_size
)
self.assertTrue(mx.allclose(w, w_p))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_custom_kernel_basic(self):
mx.random.seed(7)

View File

@@ -11,7 +11,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
def test_quantize_dequantize(self):
w = mx.random.normal(shape=(128, 512))
for gs in [32, 64, 128]:
for b in [2, 4, 8]:
for b in [2, 3, 6, 4, 8]:
with self.subTest(gs=gs, b=b):
w_q, scales, biases = mx.quantize(w, group_size=gs, bits=b)
w_hat = mx.dequantize(w_q, scales, biases, gs, b)
@@ -22,7 +22,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
# test quantize/dequantize 0s
a = mx.zeros((256, 512))
for gs in [32, 64, 128]:
for b in [2, 4, 8]:
for b in [2, 3, 4, 6, 8]:
w_q, scales, biases = mx.quantize(a, gs, b)
a_hat = mx.dequantize(w_q, scales, biases, gs, b)
self.assertTrue(mx.all(a_hat == 0))
@@ -116,7 +116,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
k1, k2 = mx.random.split(key)
tests = product(
[128, 64, 32], # group_size
[2, 4, 8], # bits
[2, 3, 4, 6, 8], # bits
[512, 1024, 67], # M
[64, 128, 512, 1024], # N
[0, 1, 3, 8], # B
@@ -143,7 +143,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
k1, k2 = mx.random.split(key)
tests = product(
[128, 64, 32], # group_size
[2, 4, 8], # bits
[2, 3, 4, 6, 8], # bits
[512, 1024], # M
[512, 1024, 67], # N
[0, 1, 3, 8], # B

View File

@@ -165,7 +165,7 @@ if __name__ == "__main__":
setup(
name="mlx",
version=get_version("0.20.0"),
version=get_version("0.21.0"),
author="MLX Contributors",
author_email="mlx@group.apple.com",
description="A framework for machine learning on Apple silicon.",