mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
2-Pass Sdpa Inference Kernel (#1597)
This commit is contained in:
parent
9bd03dd9b4
commit
073076ac7d
@ -4,42 +4,51 @@ import math
|
||||
import mlx.core as mx
|
||||
from time_utils import time_fn
|
||||
|
||||
L = 1024
|
||||
L = 16384
|
||||
H = 32
|
||||
H_k = 32 // 4
|
||||
H_k = H // 4
|
||||
D = 128
|
||||
dtype = mx.float16
|
||||
loops = 10
|
||||
|
||||
|
||||
def attention(q, k, v):
|
||||
B, Hq, L, D = q.shape
|
||||
_, Hk, S, _ = k.shape
|
||||
q = q.reshape(B, Hk, Hq // Hk, L, D)
|
||||
k = k[:, :, None, :, :]
|
||||
v = v[:, :, None, :, :]
|
||||
s = q @ k.transpose(0, 1, 2, 4, 3)
|
||||
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
|
||||
o = p @ v
|
||||
return o.reshape(B, Hq, L, D)
|
||||
def _sdpa(q, k, v):
|
||||
B, Hq, L, D = q.shape
|
||||
_, Hk, S, _ = k.shape
|
||||
q = q.reshape(B, Hk, Hq // Hk, L, D)
|
||||
k = k[:, :, None, :, :]
|
||||
v = v[:, :, None, :, :]
|
||||
s = q @ k.transpose(0, 1, 2, 4, 3)
|
||||
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
|
||||
o = p @ v
|
||||
return o.reshape(B, Hq, L, D)
|
||||
|
||||
for i in range(loops):
|
||||
q = _sdpa(q, k, v)
|
||||
return q
|
||||
|
||||
|
||||
def sdpa(q, k, v):
|
||||
return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
|
||||
for i in range(loops):
|
||||
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
|
||||
return q
|
||||
|
||||
|
||||
def time_self_attention_primitives():
|
||||
mx.random.seed(3)
|
||||
q = mx.random.uniform(shape=(1, H, 1, D))
|
||||
k = mx.random.uniform(shape=(1, H_k, L, D))
|
||||
v = mx.random.uniform(shape=(1, H_k, L, D))
|
||||
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
|
||||
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
||||
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
||||
mx.eval(q, k, v)
|
||||
time_fn(attention, q, k, v)
|
||||
|
||||
|
||||
def time_self_attention_sdpa():
|
||||
mx.random.seed(3)
|
||||
q = mx.random.uniform(shape=(1, H, 1, D))
|
||||
k = mx.random.uniform(shape=(1, H_k, L, D))
|
||||
v = mx.random.uniform(shape=(1, H_k, L, D))
|
||||
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
|
||||
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
||||
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
||||
mx.eval(q, k, v)
|
||||
time_fn(sdpa, q, k, v)
|
||||
|
||||
|
@ -926,21 +926,10 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 64, 2, 2);
|
||||
instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2);
|
||||
|
||||
// SDPA vector instantiations
|
||||
#define instantiate_sdpa_vector(type, head_dim) \
|
||||
template [[host_name("sdpa_vector_" #type "_" #head_dim)]] \
|
||||
[[kernel]] void sdpa_vector<type, head_dim>( \
|
||||
const device type* queries [[buffer(0)]], \
|
||||
const device type* keys [[buffer(1)]], \
|
||||
const device type* values [[buffer(2)]], \
|
||||
device type* out [[buffer(3)]], \
|
||||
const constant int& gqa_factor, \
|
||||
const constant int& N, \
|
||||
const constant size_t& k_stride, \
|
||||
const constant size_t& v_stride, \
|
||||
const constant float& scale, \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
#define instantiate_sdpa_vector(type, head_dim) \
|
||||
instantiate_kernel("sdpa_vector_" #type "_" #head_dim, sdpa_vector, type, head_dim) \
|
||||
instantiate_kernel("sdpa_vector_2pass_1_" #type "_" #head_dim, sdpa_vector_2pass_1, type, head_dim) \
|
||||
instantiate_kernel("sdpa_vector_2pass_2_" #type "_" #head_dim, sdpa_vector_2pass_2, type, head_dim)
|
||||
|
||||
#define instantiate_sdpa_vector_heads(type) \
|
||||
instantiate_sdpa_vector(type, 64) \
|
||||
|
@ -21,8 +21,7 @@ template <typename T, int D>
|
||||
constexpr int BN = 32;
|
||||
constexpr int BD = 32;
|
||||
constexpr int elem_per_thread = D / BD;
|
||||
|
||||
const int stride = BN * D;
|
||||
constexpr int stride = BN * D;
|
||||
|
||||
typedef float U;
|
||||
|
||||
@ -84,7 +83,6 @@ template <typename T, int D>
|
||||
keys += stride;
|
||||
values += stride;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Each thread has a partial part of the output so we need to combine them.
|
||||
|
||||
@ -114,3 +112,181 @@ template <typename T, int D>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int D>
|
||||
[[kernel]] void sdpa_vector_2pass_1(
|
||||
const device T* queries [[buffer(0)]],
|
||||
const device T* keys [[buffer(1)]],
|
||||
const device T* values [[buffer(2)]],
|
||||
device float* out [[buffer(3)]],
|
||||
device float* sums [[buffer(4)]],
|
||||
device float* maxs [[buffer(5)]],
|
||||
const constant int& gqa_factor,
|
||||
const constant int& N,
|
||||
const constant size_t& k_stride,
|
||||
const constant size_t& v_stride,
|
||||
const constant float& scale,
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
constexpr int BN = 8;
|
||||
constexpr int BD = 32;
|
||||
constexpr int elem_per_thread = D / BD;
|
||||
constexpr int stride = BN * D;
|
||||
constexpr int blocks = 32;
|
||||
|
||||
typedef float U;
|
||||
|
||||
thread U q[elem_per_thread];
|
||||
thread U k[elem_per_thread];
|
||||
thread U o[elem_per_thread];
|
||||
|
||||
threadgroup U outputs[BN * BD];
|
||||
threadgroup U max_scores[BN];
|
||||
threadgroup U sum_exp_scores[BN];
|
||||
|
||||
// Adjust positions
|
||||
const int block_idx = tid.z;
|
||||
const int head_idx = tid.y;
|
||||
const int kv_head_idx = head_idx / gqa_factor;
|
||||
queries += head_idx * D + simd_lid * elem_per_thread;
|
||||
keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D +
|
||||
simd_lid * elem_per_thread;
|
||||
values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D +
|
||||
simd_lid * elem_per_thread;
|
||||
out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread;
|
||||
sums += head_idx * blocks + block_idx;
|
||||
maxs += head_idx * blocks + block_idx;
|
||||
|
||||
// Read the query and 0 the output accumulator
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
q[i] = static_cast<U>(scale) * queries[i];
|
||||
}
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
o[i] = 0;
|
||||
}
|
||||
|
||||
U max_score = -1e9;
|
||||
U sum_exp_score = 0;
|
||||
|
||||
// For each key
|
||||
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
|
||||
// Read the key
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
k[i] = keys[i];
|
||||
}
|
||||
|
||||
// Compute the i-th score
|
||||
U score = 0;
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
score += q[i] * k[i];
|
||||
}
|
||||
score = simd_sum(score);
|
||||
|
||||
// Update the accumulators
|
||||
U new_max = max(max_score, score);
|
||||
U factor = fast::exp(max_score - new_max);
|
||||
U exp_score = fast::exp(score - new_max);
|
||||
|
||||
max_score = new_max;
|
||||
sum_exp_score = sum_exp_score * factor + exp_score;
|
||||
|
||||
// Update the output accumulator
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
o[i] = o[i] * factor + exp_score * values[i];
|
||||
}
|
||||
|
||||
// Move the pointers to the next kv
|
||||
keys += blocks * stride;
|
||||
values += blocks * stride;
|
||||
}
|
||||
|
||||
// Each thread has a partial part of the output so we need to combine them.
|
||||
|
||||
// First let's communicate the max and sum_exp
|
||||
if (simd_lid == 0) {
|
||||
max_scores[simd_gid] = max_score;
|
||||
sum_exp_scores[simd_gid] = sum_exp_score;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
|
||||
U new_max = simd_max(max_score);
|
||||
U factor = fast::exp(max_score - new_max);
|
||||
sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
|
||||
sum_exp_score = simd_sum(sum_exp_score * factor);
|
||||
|
||||
// Write the sum and new max
|
||||
if (simd_gid == 0) {
|
||||
sums[0] = sum_exp_score;
|
||||
maxs[0] = new_max;
|
||||
}
|
||||
|
||||
// Now we need to aggregate all the outputs
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
outputs[simd_lid * BN + simd_gid] =
|
||||
o[i] * fast::exp(max_scores[simd_gid] - new_max);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// And write the output
|
||||
if (simd_gid == 0) {
|
||||
U output = outputs[simd_lid * BN];
|
||||
for (int j = 1; j < BN; j++) {
|
||||
output += outputs[simd_lid * BN + j];
|
||||
}
|
||||
out[i] = static_cast<T>(output);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int D>
|
||||
[[kernel]] void sdpa_vector_2pass_2(
|
||||
const device float* partials [[buffer(0)]],
|
||||
const device float* sums [[buffer(1)]],
|
||||
const device float* maxs [[buffer(2)]],
|
||||
device T* out [[buffer(3)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
constexpr int BN = 32;
|
||||
constexpr int BD = 32;
|
||||
constexpr int elem_per_thread = D / BD;
|
||||
constexpr int blocks = 32;
|
||||
|
||||
typedef float U;
|
||||
|
||||
thread U o[elem_per_thread];
|
||||
threadgroup U outputs[BN * BD];
|
||||
|
||||
// Adjust positions
|
||||
const int head_idx = tid.y;
|
||||
partials += head_idx * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
|
||||
sums += head_idx * blocks;
|
||||
maxs += head_idx * blocks;
|
||||
out += head_idx * D + simd_gid * elem_per_thread;
|
||||
|
||||
// First everybody reads the max and sum_exp
|
||||
U max_score = maxs[simd_lid];
|
||||
U new_max = simd_max(max_score);
|
||||
U factor = fast::exp(max_score - new_max);
|
||||
U sum_exp_score = simd_sum(sums[simd_lid] * factor);
|
||||
|
||||
// Now read the block into registers and then use shared memory to transpose
|
||||
// it
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
o[i] = partials[i];
|
||||
}
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
outputs[simd_lid * BD + simd_gid] = o[i];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
// And write the output
|
||||
if (simd_lid == 0) {
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
out[i] = static_cast<T>(o[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -8,6 +8,7 @@
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::fast {
|
||||
|
||||
@ -184,6 +185,94 @@ void sdpa_vector(
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void sdpa_vector_2pass(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
array& out,
|
||||
float scale) {
|
||||
// Set the kernel name
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
kname += "sdpa_vector_2pass_1_";
|
||||
kname += get_type_string(q.dtype());
|
||||
kname += "_";
|
||||
kname += std::to_string(q.shape(-1));
|
||||
|
||||
// Compute the necessary sizes
|
||||
int gqa_factor = q.shape(1) / k.shape(1);
|
||||
int N = k.shape(2);
|
||||
int blocks = 32;
|
||||
int B = q.shape(0) * q.shape(1);
|
||||
size_t k_stride = k.strides()[1];
|
||||
size_t v_stride = v.strides()[1];
|
||||
MTL::Size group_dims(8 * 32, 1, 1);
|
||||
MTL::Size grid_dims(1, B, blocks);
|
||||
|
||||
// Allocate the intermediates
|
||||
std::vector<int> intermediate_shape;
|
||||
intermediate_shape.reserve(out.ndim() + 1);
|
||||
intermediate_shape.insert(
|
||||
intermediate_shape.end(), out.shape().begin(), out.shape().end() - 1);
|
||||
intermediate_shape.push_back(blocks);
|
||||
intermediate_shape.push_back(out.shape().back());
|
||||
array intermediate(intermediate_shape, float32, nullptr, {});
|
||||
intermediate_shape.pop_back();
|
||||
array sums(intermediate_shape, float32, nullptr, {});
|
||||
array maxs(std::move(intermediate_shape), float32, nullptr, {});
|
||||
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
|
||||
sums.set_data(allocator::malloc_or_wait(sums.nbytes()));
|
||||
maxs.set_data(allocator::malloc_or_wait(maxs.nbytes()));
|
||||
d.add_temporary(intermediate, s.index);
|
||||
d.add_temporary(sums, s.index);
|
||||
d.add_temporary(maxs, s.index);
|
||||
|
||||
// Get the kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Set its arguments
|
||||
compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0);
|
||||
compute_encoder.set_input_array(k, 1);
|
||||
compute_encoder.set_input_array(v, 2);
|
||||
compute_encoder.set_output_array(intermediate, 3);
|
||||
compute_encoder.set_output_array(sums, 4);
|
||||
compute_encoder.set_output_array(maxs, 5);
|
||||
compute_encoder.set_bytes(gqa_factor, 6);
|
||||
compute_encoder.set_bytes(N, 7);
|
||||
compute_encoder.set_bytes(k_stride, 8);
|
||||
compute_encoder.set_bytes(v_stride, 9);
|
||||
compute_encoder.set_bytes(scale, 10);
|
||||
|
||||
// Launch
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
|
||||
// Final pass
|
||||
kname.clear();
|
||||
kname += "sdpa_vector_2pass_2_";
|
||||
kname += get_type_string(q.dtype());
|
||||
kname += "_";
|
||||
kname += std::to_string(q.shape(-1));
|
||||
|
||||
// Get the kernel
|
||||
kernel = d.get_kernel(kname);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Set its arguments
|
||||
compute_encoder.set_input_array(intermediate, 0);
|
||||
compute_encoder.set_input_array(sums, 1);
|
||||
compute_encoder.set_input_array(maxs, 2);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
|
||||
// Launch
|
||||
group_dims = MTL::Size(1024, 1, 1);
|
||||
grid_dims = MTL::Size(1, B, 1);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void ScaledDotProductAttention::eval_gpu(
|
||||
@ -249,7 +338,17 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
} else {
|
||||
o.set_data(allocator::malloc_or_wait(o.nbytes()));
|
||||
}
|
||||
sdpa_vector(s, d, q, k, v, o, scale_);
|
||||
|
||||
// We route to the 2 pass fused attention if
|
||||
// - The device is large and the sequence length long
|
||||
// - The sequence length is even longer and we have gqa
|
||||
char devc = d.get_architecture().back();
|
||||
if ((devc == 'd' && k.shape(2) >= 1024) ||
|
||||
(k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
|
||||
sdpa_vector_2pass(s, d, q, k, v, o, scale_);
|
||||
} else {
|
||||
sdpa_vector(s, d, q, k, v, o, scale_);
|
||||
}
|
||||
}
|
||||
|
||||
// Full attention mode
|
||||
|
Loading…
Reference in New Issue
Block a user