mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Fast Inference SDPA op (#735)
* Fast Inference SDPA op Implements metal shaders for: o = mx.fast_inference_sdpa(queries, keys, values, scale, mask) Supports fp16, fp32 dtypes; assumes d_k = 128. Generic op support / prompt encoding supported via mlx primitives. Metal implementation is for the inference use case only. Majority of performance benefits appears to results from GQA & reduced bandwidth requirements; there is approximate performance parity for the MHA use case (from some measurements on M3 Max). * Flush shared memory to zero before unprotected reads for (scores @ values) * Move to fast:: namespace, address reviewer comments ... also attempt to revert formatter auto-change for files not relevant to this change * Shared memory flush to top of kernel * Resolve compiler warnings * Update python/src/fast.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/fast.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/fast.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/fast.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update docstring per PR feedback * Softmax in higher precision, ... * route to fallback for more use cases - batch size > 1, head_dim other than 128, etc. * Address linux build failure * Address other reviewer comments * Remove extraneous eval_cpu function per review --------- Co-authored-by: Atila Orhon <64497909+atiorh@users.noreply.github.com> Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Co-authored-by: atila <atiorh@icloud.com>
This commit is contained in:
parent
7b463ffb07
commit
0787724c44
@ -13,7 +13,7 @@ MLX was developed with contributions from the following individuals:
|
||||
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
|
||||
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
|
||||
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
|
||||
|
||||
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
|
||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||
</a>
|
||||
|
@ -29,6 +29,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
|
@ -25,6 +25,7 @@ set(
|
||||
"random"
|
||||
"rope"
|
||||
"scan"
|
||||
"scaled_dot_product_attention"
|
||||
"softmax"
|
||||
"sort"
|
||||
"ternary"
|
||||
|
451
mlx/backend/metal/kernels/scaled_dot_product_attention.metal
Normal file
451
mlx/backend/metal/kernels/scaled_dot_product_attention.metal
Normal file
@ -0,0 +1,451 @@
|
||||
#include <metal_stdlib>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
|
||||
using namespace metal;
|
||||
|
||||
template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_t NSIMDGROUPS>
|
||||
[[kernel]] void fast_inference_sdpa_compute_partials_template(const device T *Q [[buffer(0)]],
|
||||
const device T *K [[buffer(1)]],
|
||||
const device T *V [[buffer(2)]],
|
||||
const device uint64_t& L [[buffer(3)]],
|
||||
const device MLXScaledDotProductAttentionParams& params [[buffer(4)]],
|
||||
device float* O_partials [[buffer(5)]],
|
||||
device float* p_lse [[buffer(6)]],
|
||||
device float* p_maxes [[buffer(7)]],
|
||||
threadgroup T* threadgroup_block [[threadgroup(0)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]]) {
|
||||
constexpr const size_t DK = 128;
|
||||
constexpr const ulong SIMDGROUP_MATRIX_LOAD_FACTOR = 8;
|
||||
constexpr const size_t THREADS_PER_SIMDGROUP = 32;
|
||||
constexpr const uint iter_offset = NSIMDGROUPS * 4;
|
||||
const bool is_gqa = params.N_KV_HEADS != params.N_Q_HEADS;
|
||||
uint kv_head_offset_factor = tid.x;
|
||||
if(is_gqa) {
|
||||
int q_kv_head_ratio = params.N_Q_HEADS / params.N_KV_HEADS;
|
||||
kv_head_offset_factor = tid.x / q_kv_head_ratio;
|
||||
}
|
||||
constexpr const uint16_t P_VEC4 = TILE_SIZE_CONST / NSIMDGROUPS / 4;
|
||||
constexpr const size_t MATRIX_LOADS_PER_SIMDGROUP = TILE_SIZE_CONST / (SIMDGROUP_MATRIX_LOAD_FACTOR * NSIMDGROUPS);
|
||||
constexpr const size_t MATRIX_COLS = DK / SIMDGROUP_MATRIX_LOAD_FACTOR;
|
||||
constexpr const uint totalSmemV = SIMDGROUP_MATRIX_LOAD_FACTOR * SIMDGROUP_MATRIX_LOAD_FACTOR * (MATRIX_LOADS_PER_SIMDGROUP + 1) * NSIMDGROUPS;
|
||||
|
||||
threadgroup T4* smemFlush = (threadgroup T4*)threadgroup_block;
|
||||
#pragma clang loop unroll(full)
|
||||
for(uint i = 0; i < 8; i++) {
|
||||
smemFlush[simd_lane_id + simd_group_id * THREADS_PER_SIMDGROUP + i * NSIMDGROUPS * THREADS_PER_SIMDGROUP] = T4(0.f);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// TODO: multiple query sequence length for speculative decoding
|
||||
const uint tgroup_query_head_offset = tid.x * DK + tid.z * (params.N_Q_HEADS * DK);
|
||||
|
||||
const uint tgroup_k_head_offset = kv_head_offset_factor * DK * L;
|
||||
const uint tgroup_k_tile_offset = tid.y * TILE_SIZE_CONST * DK;
|
||||
const uint tgroup_k_batch_offset = tid.z * L * params.N_KV_HEADS * DK;
|
||||
|
||||
const device T* baseK = K + tgroup_k_batch_offset + tgroup_k_tile_offset + tgroup_k_head_offset;
|
||||
const device T* baseQ = Q + tgroup_query_head_offset;
|
||||
|
||||
device T4* simdgroupQueryData = (device T4*)baseQ;
|
||||
|
||||
constexpr const size_t ACCUM_PER_GROUP = TILE_SIZE_CONST / NSIMDGROUPS;
|
||||
float threadAccum[ACCUM_PER_GROUP];
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(size_t threadAccumIndex = 0; threadAccumIndex < ACCUM_PER_GROUP; threadAccumIndex++) {
|
||||
threadAccum[threadAccumIndex] = -INFINITY;
|
||||
}
|
||||
|
||||
uint KROW_ACCUM_INDEX = 0;
|
||||
|
||||
const int32_t SEQUENCE_LENGTH_LESS_TILE_SIZE = L - TILE_SIZE_CONST;
|
||||
const bool LAST_TILE = (tid.y + 1) * TILE_SIZE_CONST >= L;
|
||||
const bool LAST_TILE_ALIGNED = (SEQUENCE_LENGTH_LESS_TILE_SIZE == int32_t(tid.y * TILE_SIZE_CONST));
|
||||
|
||||
T4 thread_data_x4;
|
||||
T4 thread_data_y4;
|
||||
if(!LAST_TILE || LAST_TILE_ALIGNED) {
|
||||
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
|
||||
#pragma clang loop unroll(full)
|
||||
for(size_t KROW = simd_group_id; KROW < TILE_SIZE_CONST; KROW += NSIMDGROUPS) {
|
||||
const uint KROW_OFFSET = KROW * DK;
|
||||
const device T* baseKRow = baseK + KROW_OFFSET;
|
||||
device T4* keysData = (device T4*)baseKRow;
|
||||
thread_data_y4 = *(keysData + simd_lane_id);
|
||||
T kq_scalar = dot(thread_data_x4, thread_data_y4);
|
||||
threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
|
||||
KROW_ACCUM_INDEX++;
|
||||
}
|
||||
} else {
|
||||
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
|
||||
const uint START_ROW = tid.y * TILE_SIZE_CONST;
|
||||
const device T* baseKThisHead = K + tgroup_k_batch_offset + tgroup_k_head_offset;
|
||||
|
||||
for(size_t KROW = START_ROW + simd_group_id; KROW < L; KROW += NSIMDGROUPS) {
|
||||
const uint KROW_OFFSET = KROW * DK;
|
||||
const device T* baseKRow = baseKThisHead + KROW_OFFSET;
|
||||
device T4* keysData = (device T4*)baseKRow;
|
||||
thread_data_y4 = *(keysData + simd_lane_id);
|
||||
T kq_scalar = dot(thread_data_x4, thread_data_y4);
|
||||
threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
|
||||
KROW_ACCUM_INDEX++;
|
||||
}
|
||||
}
|
||||
threadgroup float* smemP = (threadgroup float*)threadgroup_block;
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(size_t i = 0; i < P_VEC4; i++) {
|
||||
thread_data_x4 = T4(threadAccum[4 * i], threadAccum[4 * i + 1], threadAccum[4 * i + 2], threadAccum[4 * i + 3]);
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
thread_data_y4 = simd_sum(thread_data_x4);
|
||||
if(simd_lane_id == 0) {
|
||||
const uint base_smem_p_offset = i * iter_offset + simd_group_id;
|
||||
smemP[base_smem_p_offset + NSIMDGROUPS * 0] = float(thread_data_y4.x);
|
||||
smemP[base_smem_p_offset + NSIMDGROUPS * 1] = float(thread_data_y4.y);
|
||||
smemP[base_smem_p_offset + NSIMDGROUPS * 2] = float(thread_data_y4.z);
|
||||
smemP[base_smem_p_offset + NSIMDGROUPS * 3] = float(thread_data_y4.w);
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float groupMax;
|
||||
float lse = 0.f;
|
||||
|
||||
constexpr const size_t THREADS_PER_THREADGROUP_TIMES_4 = 4 * 32;
|
||||
constexpr const size_t ACCUM_ARRAY_LENGTH = TILE_SIZE_CONST / THREADS_PER_THREADGROUP_TIMES_4 + 1;
|
||||
float4 pvals[ACCUM_ARRAY_LENGTH];
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(uint accum_array_iter = 0; accum_array_iter < ACCUM_ARRAY_LENGTH; accum_array_iter++) {
|
||||
pvals[accum_array_iter] = float4(-INFINITY);
|
||||
}
|
||||
|
||||
if (TILE_SIZE_CONST == 64) {
|
||||
threadgroup float2* smemPtrFlt2 = (threadgroup float2*)threadgroup_block;
|
||||
float2 vals = smemPtrFlt2[simd_lane_id];
|
||||
vals *= params.INV_ALPHA;
|
||||
float maxval = max(vals.x, vals.y);
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
groupMax = simd_max(maxval);
|
||||
|
||||
float2 expf_shifted = exp(vals - groupMax);
|
||||
float sumExpLocal = expf_shifted.x + expf_shifted.y;
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
float tgroupExpSum = simd_sum(sumExpLocal);
|
||||
|
||||
lse = log(tgroupExpSum);
|
||||
float2 local_p_hat = expf_shifted / tgroupExpSum;
|
||||
pvals[0].x = local_p_hat.x;
|
||||
pvals[0].y = local_p_hat.y;
|
||||
smemPtrFlt2[simd_lane_id] = float2(0.f);
|
||||
}
|
||||
constexpr const bool TILE_SIZE_LARGER_THAN_64 = TILE_SIZE_CONST > 64;
|
||||
constexpr const int TILE_SIZE_ITERS_128 = TILE_SIZE_CONST / 128;
|
||||
|
||||
if (TILE_SIZE_LARGER_THAN_64) {
|
||||
float maxval = -INFINITY;
|
||||
threadgroup float4* smemPtrFlt4 = (threadgroup float4*)threadgroup_block;
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < TILE_SIZE_ITERS_128; i++) {
|
||||
float4 vals = smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP];
|
||||
vals *= params.INV_ALPHA;
|
||||
pvals[i] = vals;
|
||||
maxval = fmax3(vals.x, vals.y, maxval);
|
||||
maxval = fmax3(vals.z, vals.w, maxval);
|
||||
}
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
groupMax = simd_max(maxval);
|
||||
|
||||
float sumExpLocal = 0.f;
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < TILE_SIZE_ITERS_128; i++) {
|
||||
pvals[i] = exp(pvals[i] - groupMax);
|
||||
sumExpLocal += pvals[i].x + pvals[i].y + pvals[i].z + pvals[i].w;
|
||||
}
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
float tgroupExpSum = simd_sum(sumExpLocal);
|
||||
lse = log(tgroupExpSum);
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < TILE_SIZE_ITERS_128; i++) {
|
||||
pvals[i] = pvals[i] / tgroupExpSum;
|
||||
smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP] = float4(0.f);
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup T* smemV = (threadgroup T*)threadgroup_block;
|
||||
|
||||
const size_t v_batch_offset = tid.z * params.N_KV_HEADS * L * DK;
|
||||
const size_t v_head_offset = kv_head_offset_factor * L * DK;
|
||||
|
||||
const size_t v_tile_offset = tid.y * TILE_SIZE_CONST * DK;
|
||||
const size_t v_offset = v_batch_offset + v_head_offset + v_tile_offset;
|
||||
device T* baseV = (device T*)V + v_offset;
|
||||
|
||||
threadgroup float* smemOpartial = (threadgroup float*)(smemV + totalSmemV);
|
||||
|
||||
if (!LAST_TILE || LAST_TILE_ALIGNED) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(size_t col = 0; col < MATRIX_COLS; col++) {
|
||||
uint matrix_load_loop_iter = 0;
|
||||
constexpr const size_t TILE_SIZE_CONST_DIV_8 = TILE_SIZE_CONST / 8;
|
||||
|
||||
for(size_t tile_start = simd_group_id; tile_start < TILE_SIZE_CONST_DIV_8; tile_start += NSIMDGROUPS) {
|
||||
simdgroup_matrix<T, 8, 8> tmp;
|
||||
ulong simdgroup_matrix_offset = matrix_load_loop_iter * NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
|
||||
ulong2 matrixOrigin = ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, simdgroup_matrix_offset);
|
||||
simdgroup_load(tmp, baseV, DK, matrixOrigin, true);
|
||||
const ulong2 matrixOriginSmem = ulong2(simdgroup_matrix_offset, 0);
|
||||
const ulong elemsPerRowSmem = TILE_SIZE_CONST;
|
||||
simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, false);
|
||||
matrix_load_loop_iter++;
|
||||
};
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (TILE_SIZE_CONST == 64) {
|
||||
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
|
||||
uint loop_iter = 0;
|
||||
threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; row += NSIMDGROUPS) {
|
||||
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
|
||||
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
|
||||
T2 v_local = *(smemV2 + simd_lane_id);
|
||||
|
||||
T val = dot(local_p_hat, v_local);
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
T row_sum = simd_sum(val);
|
||||
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = float(row_sum);
|
||||
loop_iter++;
|
||||
}
|
||||
}
|
||||
|
||||
if (TILE_SIZE_CONST > 64) {
|
||||
constexpr const size_t TILE_SIZE_CONST_DIV_128 = (TILE_SIZE_CONST + 1) / 128;
|
||||
threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
|
||||
uint loop_iter = 0;
|
||||
for(size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; row += NSIMDGROUPS) {
|
||||
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
|
||||
|
||||
T row_sum = 0.f;
|
||||
for(size_t i = 0; i < TILE_SIZE_CONST_DIV_128; i++) {
|
||||
threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
|
||||
T4 v_local = *(smemV2 + simd_lane_id + i * THREADS_PER_SIMDGROUP);
|
||||
T4 p_local = T4(pvals[i]);
|
||||
T val = dot(p_local, v_local);
|
||||
row_sum += val;
|
||||
}
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
row_sum = simd_sum(row_sum);
|
||||
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = float(row_sum);
|
||||
loop_iter++;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const int32_t START_ROW = tid.y * TILE_SIZE_CONST;
|
||||
const int32_t MAX_START_ROW = L - SIMDGROUP_MATRIX_LOAD_FACTOR + 1;
|
||||
const device T* baseVThisHead = V + v_batch_offset + v_head_offset;
|
||||
constexpr const int ROWS_PER_ITER = 8;
|
||||
#pragma clang loop unroll(full)
|
||||
for(size_t col = 0; col < MATRIX_COLS; col++) {
|
||||
uint smem_col_index = simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
|
||||
int32_t tile_start;
|
||||
for(tile_start = START_ROW + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR; tile_start < MAX_START_ROW; tile_start += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR) {
|
||||
simdgroup_matrix<T, 8, 8> tmp;
|
||||
ulong2 matrixOrigin = ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, tile_start);
|
||||
simdgroup_load(tmp, baseVThisHead, DK, matrixOrigin, /* transpose */ true);
|
||||
const ulong2 matrixOriginSmem = ulong2(smem_col_index, 0);
|
||||
constexpr const ulong elemsPerRowSmem = TILE_SIZE_CONST;
|
||||
simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, /* transpose */ false);
|
||||
smem_col_index += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR;
|
||||
};
|
||||
|
||||
tile_start = ((L / SIMDGROUP_MATRIX_LOAD_FACTOR) * SIMDGROUP_MATRIX_LOAD_FACTOR);
|
||||
|
||||
const int32_t INT_L = int32_t(L);
|
||||
for(int row_index = tile_start + simd_group_id ; row_index < INT_L; row_index += NSIMDGROUPS) {
|
||||
if(simd_lane_id < SIMDGROUP_MATRIX_LOAD_FACTOR) {
|
||||
const uint elems_per_row_gmem = DK;
|
||||
const uint col_index_v_gmem = col * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_lane_id;
|
||||
const uint row_index_v_gmem = row_index;
|
||||
|
||||
const uint elems_per_row_smem = TILE_SIZE_CONST;
|
||||
const uint col_index_v_smem = row_index % TILE_SIZE_CONST;
|
||||
const uint row_index_v_smem = simd_lane_id;
|
||||
|
||||
const uint scalar_offset_gmem = row_index_v_gmem * elems_per_row_gmem + col_index_v_gmem;
|
||||
const uint scalar_offset_smem = row_index_v_smem * elems_per_row_smem + col_index_v_smem;
|
||||
T vdata = T(*(baseVThisHead + scalar_offset_gmem));
|
||||
smemV[scalar_offset_smem] = vdata;
|
||||
smem_col_index += NSIMDGROUPS;
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (TILE_SIZE_CONST == 64) {
|
||||
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
|
||||
threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
|
||||
for(size_t smem_row_index = simd_group_id;
|
||||
smem_row_index < ROWS_PER_ITER; smem_row_index += NSIMDGROUPS) {
|
||||
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * smem_row_index);
|
||||
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
|
||||
T2 v_local = *(smemV2 + simd_lane_id);
|
||||
T val = dot(local_p_hat, v_local);
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
T row_sum = simd_sum(val);
|
||||
oPartialSmem[smem_row_index] = float(row_sum);
|
||||
}
|
||||
}
|
||||
|
||||
if (TILE_SIZE_CONST > 64) {
|
||||
threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
|
||||
uint loop_count = 0;
|
||||
for(size_t row_index = simd_group_id;
|
||||
row_index < ROWS_PER_ITER; row_index += NSIMDGROUPS) {
|
||||
T row_sum = 0.f;
|
||||
for(size_t tile_iters = 0; tile_iters < TILE_SIZE_ITERS_128; tile_iters++) {
|
||||
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row_index);
|
||||
threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
|
||||
T4 v_local = *(smemV2 + simd_lane_id + tile_iters * THREADS_PER_SIMDGROUP);
|
||||
T4 p_local = T4(pvals[tile_iters]);
|
||||
row_sum += dot(p_local, v_local);
|
||||
|
||||
}
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
row_sum = simd_sum(row_sum);
|
||||
oPartialSmem[simd_group_id + NSIMDGROUPS * loop_count] = float(row_sum);
|
||||
loop_count++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if(simd_group_id == 0) {
|
||||
threadgroup float4* oPartialVec4 = (threadgroup float4*)smemOpartial;
|
||||
float4 vals = *(oPartialVec4 + simd_lane_id);
|
||||
device float* oPartialGmem = O_partials + tid.x * DK * params.KV_TILES + tid.y * DK;
|
||||
device float4* oPartialGmemVec4 = (device float4*)oPartialGmem;
|
||||
oPartialGmemVec4[simd_lane_id] = vals;
|
||||
}
|
||||
|
||||
if(simd_group_id == 0 && simd_lane_id == 0) {
|
||||
const uint tileIndex = tid.y;
|
||||
const uint gmem_partial_scalar_offset = tid.z * params.N_Q_HEADS * params.KV_TILES + tid.x * params.KV_TILES + tileIndex;
|
||||
p_lse[gmem_partial_scalar_offset] = lse;
|
||||
p_maxes[gmem_partial_scalar_offset] = groupMax;
|
||||
}
|
||||
}
|
||||
|
||||
#define instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, nsimdgroups) \
|
||||
template [[host_name("fast_inference_sdpa_compute_partials_" #itype "_" #tile_size "_" #nsimdgroups )]] \
|
||||
[[kernel]] void fast_inference_sdpa_compute_partials_template<itype, itype2, itype4, tile_size, nsimdgroups>( \
|
||||
const device itype *Q [[buffer(0)]], \
|
||||
const device itype *K [[buffer(1)]], \
|
||||
const device itype *V [[buffer(2)]], \
|
||||
const device uint64_t& L [[buffer(3)]], \
|
||||
const device MLXScaledDotProductAttentionParams& params [[buffer(4)]], \
|
||||
device float* O_partials [[buffer(5)]], \
|
||||
device float* p_lse [[buffer(6)]], \
|
||||
device float* p_maxes [[buffer(7)]], \
|
||||
threadgroup itype *threadgroup_block [[threadgroup(0)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]]);
|
||||
|
||||
|
||||
#define instantiate_fast_inference_sdpa_to_partials_shapes_helper(itype, itype2, itype4, tile_size) \
|
||||
instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, 4) \
|
||||
instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, 8) \
|
||||
|
||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 64);
|
||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 128);
|
||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 256);
|
||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 512);
|
||||
|
||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 64);
|
||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 128);
|
||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 256);
|
||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 512);
|
||||
|
||||
|
||||
template <typename T>
|
||||
void fast_inference_sdpa_reduce_tiles_template(
|
||||
const device float *O_partials [[buffer(0)]],
|
||||
const device float *p_lse[[buffer(1)]],
|
||||
const device float *p_maxes [[buffer(2)]],
|
||||
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
|
||||
device T* O [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
constexpr const int DK = 128;
|
||||
const ulong offset_rows = tid.z * params.KV_TILES * params.N_Q_HEADS + tid.x * params.KV_TILES;
|
||||
const device float* p_lse_row = p_lse + offset_rows;
|
||||
const device float* p_rowmax_row = p_maxes + offset_rows;
|
||||
// reserve some number of registers. this constitutes an assumption on max value of KV TILES.
|
||||
constexpr const uint8_t reserve = 128;
|
||||
float p_lse_regs[reserve];
|
||||
float p_rowmax_regs[reserve];
|
||||
float weights[reserve];
|
||||
|
||||
float true_max = -INFINITY;
|
||||
for(size_t i = 0; i < params.KV_TILES; i++) {
|
||||
p_lse_regs[i] = float(*(p_lse_row + i));
|
||||
p_rowmax_regs[i] = float(*(p_rowmax_row + i));
|
||||
true_max = fmax(p_rowmax_regs[i], true_max);
|
||||
weights[i] = exp(p_lse_regs[i]);
|
||||
}
|
||||
|
||||
float denom = 0.f;
|
||||
for(size_t i = 0; i < params.KV_TILES; i++) {
|
||||
weights[i] *= exp(p_rowmax_regs[i]-true_max);
|
||||
denom += weights[i];
|
||||
}
|
||||
|
||||
const device float* O_partials_with_offset = O_partials + tid.z * params.N_Q_HEADS * DK * params.KV_TILES + tid.x * DK * params.KV_TILES;
|
||||
|
||||
float o_value = 0.f;
|
||||
for(size_t i = 0; i < params.KV_TILES; i++) {
|
||||
float val = *(O_partials_with_offset + i * DK + lid.x);
|
||||
o_value += val * weights[i] / denom;
|
||||
}
|
||||
device T* O_gmem = O + tid.z * params.N_Q_HEADS * DK + tid.x * DK;
|
||||
O_gmem[lid.x] = T(o_value);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
kernel void fast_inference_sdpa_reduce_tiles_float(
|
||||
const device float *O_partials [[buffer(0)]],
|
||||
const device float *p_lse[[buffer(1)]],
|
||||
const device float *p_maxes [[buffer(2)]],
|
||||
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
|
||||
device float* O [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]])
|
||||
{
|
||||
fast_inference_sdpa_reduce_tiles_template<float>(O_partials, p_lse, p_maxes, params,
|
||||
O, tid, lid);
|
||||
}
|
||||
|
||||
kernel void fast_inference_sdpa_reduce_tiles_half(
|
||||
const device float *O_partials [[buffer(0)]],
|
||||
const device float *p_lse[[buffer(1)]],
|
||||
const device float *p_maxes [[buffer(2)]],
|
||||
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
|
||||
device half* O [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]])
|
||||
{
|
||||
fast_inference_sdpa_reduce_tiles_template<half>(O_partials, p_lse, p_maxes, params,
|
||||
O, tid, lid);
|
||||
}
|
@ -0,0 +1,14 @@
|
||||
//
|
||||
// scaled_dot_product_attention_params.h
|
||||
// mlx
|
||||
|
||||
#pragma once
|
||||
|
||||
struct MLXScaledDotProductAttentionParams {
|
||||
// Associated dimensions & transposition information
|
||||
const uint QUERY_SEQUENCE_LENGTH = 1;
|
||||
const uint N_Q_HEADS = 32;
|
||||
const uint N_KV_HEADS = 32;
|
||||
const uint KV_TILES = 1;
|
||||
const float INV_ALPHA = 0.08838834764831843f;
|
||||
};
|
222
mlx/backend/metal/scaled_dot_product_attention.cpp
Normal file
222
mlx/backend/metal/scaled_dot_product_attention.cpp
Normal file
@ -0,0 +1,222 @@
|
||||
//
|
||||
// scaled_dot_product_attention.cpp
|
||||
// mlx
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::fast {
|
||||
|
||||
namespace {
|
||||
|
||||
void sdpa_metal(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
const array& p_lse,
|
||||
const array& p_rowmaxes,
|
||||
const array& o_partial,
|
||||
const uint heads,
|
||||
const uint tile_size,
|
||||
const uint n_tiles,
|
||||
const float alpha,
|
||||
array& out,
|
||||
std::vector<array>& temporaries) {
|
||||
std::ostringstream kname_partials;
|
||||
|
||||
kname_partials << "fast_inference_sdpa_compute_partials_";
|
||||
|
||||
std::ostringstream kname_reduce;
|
||||
std::string delimiter = "_";
|
||||
kname_reduce << "fast_inference_sdpa_reduce_tiles" + delimiter;
|
||||
|
||||
for (const auto& arr : {k, v, out}) {
|
||||
if (arr.dtype() != q.dtype()) {
|
||||
throw std::runtime_error(
|
||||
"[ScaledDotProductAttention::eval_gpu]: expected matching dtypes for q,k,v,o");
|
||||
}
|
||||
}
|
||||
|
||||
if (q.dtype() == float32) {
|
||||
kname_partials << "float" + delimiter;
|
||||
kname_reduce << "float";
|
||||
} else if (q.dtype() == float16) {
|
||||
kname_partials << "half" + delimiter;
|
||||
kname_reduce << "half";
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[ScaledDotProductAttention::eval_gpu]: unexpected dtype found for queries: expected either float32 or float16.");
|
||||
}
|
||||
|
||||
std::string kname_suffix_tile_size = std::to_string(tile_size) + delimiter;
|
||||
|
||||
uint nsimd = 8;
|
||||
std::string kname_suffix_nsimdgroups = std::to_string(nsimd);
|
||||
|
||||
// maximum number of splits == 128 at the moment (reserved tile registers in
|
||||
// reduction kernel). this is arbitrary and could be changed in the shader.
|
||||
|
||||
std::string kname_suffix = kname_suffix_tile_size + kname_suffix_nsimdgroups;
|
||||
kname_partials << kname_suffix;
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname_partials.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
constexpr const uint batch = 1;
|
||||
MTL::Size grid_dims = MTL::Size(heads, n_tiles, batch);
|
||||
MTL::Size group_dims = MTL::Size(32, nsimd, 1);
|
||||
|
||||
const uint64_t KV_sequence_length = k.shape(-2);
|
||||
const uint query_sequence_length = q.shape(-2);
|
||||
const uint n_q_heads = q.shape(1);
|
||||
const uint n_kv_heads = k.shape(1);
|
||||
|
||||
MLXScaledDotProductAttentionParams params{
|
||||
query_sequence_length, n_q_heads, n_kv_heads, n_tiles, alpha};
|
||||
|
||||
set_array_buffer(compute_encoder, q, 0);
|
||||
set_array_buffer(compute_encoder, k, 1);
|
||||
set_array_buffer(compute_encoder, v, 2);
|
||||
compute_encoder->setBytes(&KV_sequence_length, sizeof(KV_sequence_length), 3);
|
||||
compute_encoder->setBytes(
|
||||
¶ms, sizeof(MLXScaledDotProductAttentionParams), 4);
|
||||
set_array_buffer(compute_encoder, o_partial, 5);
|
||||
set_array_buffer(compute_encoder, p_lse, 6);
|
||||
set_array_buffer(compute_encoder, p_rowmaxes, 7);
|
||||
|
||||
constexpr const uint tgroupMemorySize = 32768;
|
||||
compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0);
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
{
|
||||
auto kernel_accum = d.get_kernel(kname_reduce.str());
|
||||
compute_encoder->setComputePipelineState(kernel_accum);
|
||||
set_array_buffer(compute_encoder, o_partial, 0);
|
||||
set_array_buffer(compute_encoder, p_lse, 1);
|
||||
set_array_buffer(compute_encoder, p_rowmaxes, 2);
|
||||
compute_encoder->setBytes(
|
||||
¶ms, sizeof(MLXScaledDotProductAttentionParams), 3);
|
||||
set_array_buffer(compute_encoder, out, 4);
|
||||
|
||||
MTL::Size grid_dims_reduce = MTL::Size(heads, 1, batch);
|
||||
MTL::Size group_dims_reduce = MTL::Size(128, 1, 1);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims_reduce, group_dims_reduce);
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[temporaries](MTL::CommandBuffer*) mutable { temporaries.clear(); });
|
||||
return;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void ScaledDotProductAttention::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out) {
|
||||
assert(inputs.size() >= 3);
|
||||
if (!is_floating_point(out.dtype())) {
|
||||
throw std::runtime_error(
|
||||
"[ScaledDotProductAttention] Does not yet support non-floating point types.");
|
||||
}
|
||||
|
||||
if (inputs.size() == 4) {
|
||||
out = fallback_(inputs)[0];
|
||||
return;
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto& q_pre = inputs[0];
|
||||
auto& k_pre = inputs[1];
|
||||
auto& v_pre = inputs[2];
|
||||
auto& o = out;
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Init checks and prep
|
||||
|
||||
// Keep a vector with copies to be cleared in the completed buffer to release
|
||||
// the arrays
|
||||
std::vector<array> temporaries;
|
||||
auto check_transpose = [&temporaries, &s](const array& arr) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
auto sty = arr.strides()[arr.ndim() - 1];
|
||||
if (stx == arr.shape(-1) && sty == 1) {
|
||||
return arr;
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_gpu(arr, arr_copy, CopyType::General, s);
|
||||
temporaries.push_back(arr_copy);
|
||||
size_t stx = arr.shape(-1);
|
||||
return arr_copy;
|
||||
}
|
||||
};
|
||||
|
||||
auto q = check_transpose(q_pre);
|
||||
auto k = check_transpose(k_pre);
|
||||
auto v = check_transpose(v_pre);
|
||||
|
||||
const int heads = q.shape(-3);
|
||||
int tile_size = 64;
|
||||
const int kv_seq_len = k.shape(-2);
|
||||
if (kv_seq_len > 8000) {
|
||||
tile_size = 128;
|
||||
}
|
||||
if (kv_seq_len > 16000) {
|
||||
tile_size = 256;
|
||||
}
|
||||
if (kv_seq_len > 32000) {
|
||||
tile_size = 512;
|
||||
}
|
||||
|
||||
const int n_tiles = (kv_seq_len + tile_size - 1) / tile_size;
|
||||
|
||||
array o_partials(
|
||||
{q.shape(-4), q.shape(-3), q.shape(-2), n_tiles * v.shape(-1)},
|
||||
float32,
|
||||
nullptr,
|
||||
{});
|
||||
o_partials.set_data(allocator::malloc_or_wait(o_partials.nbytes()));
|
||||
|
||||
array p_lse(
|
||||
{q.shape(-4), q.shape(-3), q.shape(-2), n_tiles}, float32, nullptr, {});
|
||||
array p_rowmaxes(
|
||||
{q.shape(-4), q.shape(-3), q.shape(-2), n_tiles}, float32, nullptr, {});
|
||||
p_lse.set_data(allocator::malloc_or_wait(p_lse.nbytes()));
|
||||
p_rowmaxes.set_data(allocator::malloc_or_wait(p_rowmaxes.nbytes()));
|
||||
|
||||
temporaries.push_back(p_lse);
|
||||
temporaries.push_back(p_rowmaxes);
|
||||
temporaries.push_back(o_partials);
|
||||
|
||||
return sdpa_metal(
|
||||
s,
|
||||
d,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
p_lse,
|
||||
p_rowmaxes,
|
||||
o_partials,
|
||||
heads,
|
||||
tile_size,
|
||||
n_tiles,
|
||||
scale_,
|
||||
out,
|
||||
temporaries);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fast
|
@ -99,6 +99,7 @@ NO_GPU(Transpose)
|
||||
|
||||
namespace fast {
|
||||
NO_GPU_MULTI(RoPE)
|
||||
NO_GPU(ScaledDotProductAttention)
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
||||
|
143
mlx/fast.cpp
143
mlx/fast.cpp
@ -127,4 +127,147 @@ bool RoPE::is_equivalent(const Primitive& other) const {
|
||||
offset_ == a_other.offset_);
|
||||
}
|
||||
|
||||
/** Computes: O = softmax(Q @ K.T) @ V **/
|
||||
array scaled_dot_product_attention(
|
||||
const array& queries,
|
||||
const array& keys,
|
||||
const array& values,
|
||||
const float scale,
|
||||
const std::optional<array>& mask,
|
||||
StreamOrDevice s) {
|
||||
for (const auto& tensor : {queries, keys, values}) {
|
||||
if (tensor.ndim() != 4) {
|
||||
std::ostringstream msg;
|
||||
msg << "[scaled_dot_product_attention] input with shape "
|
||||
<< tensor.shape() << " expected to be rank 4";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
const size_t batch_dim = queries.shape(0);
|
||||
for (const auto& tensor : {keys, values}) {
|
||||
if (tensor.shape(0) != batch_dim) {
|
||||
std::ostringstream msg;
|
||||
msg << "[scaled_dot_product_attention] mismatching batch dimension for input with shape "
|
||||
<< tensor.shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
// Q, K must have matching last dims (d_k aka 'head_dim');
|
||||
if (queries.shape(-1) != keys.shape(-1)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[scaled_dot_product_attention] query, keys expected to have matching last dimension; found query shape "
|
||||
<< queries.shape() << " for keys shape " << keys.shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
// K, V must have matching number of heads (n_kv_heads);
|
||||
size_t n_q_heads = queries.shape(-3);
|
||||
size_t n_kv_heads = keys.shape(-3);
|
||||
|
||||
if (keys.shape(-3) != values.shape(-3)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[scaled_dot_product_attention] keys, values expected to have matching n_kv_heads; found keys with n_heads "
|
||||
<< keys.shape(-3) << " for values with n_heads " << values.shape(-3)
|
||||
<< ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
// n_heads % n_kv_heads == 0; n_heads >= 1, n_kv_heads >= 1.
|
||||
if (n_q_heads % n_kv_heads != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[scaled_dot_product_attention] n_heads must be a multiple of n_kv_heads, found n_heads "
|
||||
<< n_q_heads << " for n_kv_heads " << n_kv_heads << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto final_type = result_type({queries, keys, values});
|
||||
|
||||
auto q = astype(queries, final_type, s);
|
||||
auto k = astype(keys, final_type, s);
|
||||
auto v = astype(values, final_type, s);
|
||||
|
||||
auto out_shape =
|
||||
std::vector<int>({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)});
|
||||
|
||||
/* generic implementation for use cases that Metal implementation does not
|
||||
* support. For non-supported cases listed below, use MLX primitives:
|
||||
* * CPU implementation
|
||||
* * batch size > 1
|
||||
* * query sequence length > 1
|
||||
* * non-null mask
|
||||
*/
|
||||
bool needs_mask = mask.has_value();
|
||||
auto fallback = [scale, needs_mask, final_type, n_q_heads, n_kv_heads, &s](
|
||||
const std::vector<array>& inputs) {
|
||||
auto& q_tensor = inputs[0];
|
||||
auto& k_tensor = inputs[1];
|
||||
auto& v_tensor = inputs[2];
|
||||
auto q_scaled = multiply(array(scale, q_tensor.dtype()), q_tensor, s);
|
||||
|
||||
auto tile_if_needs_repeat =
|
||||
[n_q_heads, n_kv_heads](const array& arr, StreamOrDevice& s) -> array {
|
||||
if (n_q_heads == n_kv_heads)
|
||||
return arr;
|
||||
int n_repeats = n_q_heads / n_kv_heads;
|
||||
constexpr const int heads_axis =
|
||||
1; // heads axis, assumes tensors arranged as [0, 1, 2, 3] ->
|
||||
// [Batch, Heads, Sequence, Hidden]
|
||||
auto ret = repeat(arr, n_repeats, heads_axis, s);
|
||||
return ret;
|
||||
};
|
||||
auto k_tensor_tiled = tile_if_needs_repeat(k_tensor, s);
|
||||
auto v_tensor_tiled = tile_if_needs_repeat(v_tensor, s);
|
||||
|
||||
// dim check on k, v; repeat if untiled, since naive matmul will have
|
||||
// dim mismatch for GQA (MQA could make use of broadcast)
|
||||
auto k_transposed = transpose(k_tensor_tiled, {0, 1, 3, 2}, s);
|
||||
auto s_tensor = matmul(q_scaled, k_transposed, s);
|
||||
if (needs_mask) {
|
||||
auto mask_tensor = inputs[3];
|
||||
s_tensor = add(s_tensor, mask_tensor, s);
|
||||
}
|
||||
auto p = astype(
|
||||
softmax(astype(s_tensor, float32, s), std::vector<int>{-1}, s),
|
||||
final_type,
|
||||
s);
|
||||
auto out_tensor = matmul(p, v_tensor_tiled, s);
|
||||
return std::vector<array>{out_tensor};
|
||||
};
|
||||
|
||||
auto stream = to_stream(s);
|
||||
|
||||
// current implementation use case: batch size 1, query sequence length 1, no
|
||||
// mask. Likewise, requires head_dim == 128
|
||||
constexpr const int supported_head_dim = 128;
|
||||
const size_t query_head_dim = q.shape(-1);
|
||||
const size_t query_sequence_length = q.shape(2);
|
||||
bool implementation_supports_use_case = batch_dim == 1 &&
|
||||
query_sequence_length == 1 && !mask.has_value() &&
|
||||
query_head_dim == supported_head_dim;
|
||||
|
||||
if (stream.device == Device::gpu && implementation_supports_use_case) {
|
||||
auto out = array(
|
||||
out_shape,
|
||||
final_type,
|
||||
std::make_unique<ScaledDotProductAttention>(
|
||||
stream, fallback, scale, false),
|
||||
{q, k, v});
|
||||
return out;
|
||||
}
|
||||
|
||||
if (mask.has_value()) {
|
||||
return fallback({q, k, v, mask.value()})[0];
|
||||
} else {
|
||||
return fallback({q, k, v})[0];
|
||||
}
|
||||
}
|
||||
|
||||
bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
|
||||
const ScaledDotProductAttention& a_other =
|
||||
static_cast<const ScaledDotProductAttention&>(other);
|
||||
return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fast
|
||||
|
11
mlx/fast.h
11
mlx/fast.h
@ -2,6 +2,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::fast {
|
||||
@ -15,4 +17,13 @@ array rope(
|
||||
int offset,
|
||||
StreamOrDevice s /* = {} */);
|
||||
|
||||
/** Computes: O = softmax(Q @ K.T) @ V **/
|
||||
array scaled_dot_product_attention(
|
||||
const array& queries,
|
||||
const array& keys,
|
||||
const array& values,
|
||||
const float scale,
|
||||
const std::optional<array>& mask = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
} // namespace mlx::core::fast
|
||||
|
@ -65,4 +65,34 @@ class RoPE : public Custom {
|
||||
int offset_;
|
||||
};
|
||||
|
||||
class ScaledDotProductAttention : public Custom {
|
||||
public:
|
||||
explicit ScaledDotProductAttention(
|
||||
Stream stream,
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||
const float scale,
|
||||
const bool needs_mask)
|
||||
: Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask){};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
outputs[0] = fallback_(inputs)[0];
|
||||
};
|
||||
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
eval_gpu(inputs, outputs[0]);
|
||||
};
|
||||
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out);
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
DEFINE_PRINT(ScaledDotProductAttention)
|
||||
|
||||
private:
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||
float scale_;
|
||||
bool needs_mask_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::fast
|
||||
|
@ -56,4 +56,44 @@ void init_extensions(py::module_& parent_module) {
|
||||
Returns:
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"scaled_dot_product_attention",
|
||||
[](const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
const float scale,
|
||||
const std::optional<array>& mask,
|
||||
const StreamOrDevice& s) {
|
||||
return fast::scaled_dot_product_attention(q, k, v, scale, mask, s);
|
||||
},
|
||||
"q"_a,
|
||||
"k"_a,
|
||||
"v"_a,
|
||||
py::kw_only(),
|
||||
"scale"_a,
|
||||
"mask"_a = none,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
A fast implementation of multi-head attention: O = softmax(Q @ K.T, dim=-1) @ V.
|
||||
Supports [Multi-Head Attention](https://arxiv.org/abs/1706.03762), [Grouped Query Attention](https://arxiv.org/abs/2305.13245), and [Multi-Query Attention](https://arxiv.org/abs/1911.02150).
|
||||
|
||||
This function will dispatch to an optimized Metal kernel when the query sequence length is 1. It handles other cases with regular MLX operations.
|
||||
|
||||
Note: The softmax operation is performed in float32 precision regardless of input precision (float16 or float32).
|
||||
Note: For Grouped Query Attention and Multi-Query Attention, the input arrays for `key` and `value` should not be pre-tiled to match the `query` array.
|
||||
|
||||
Args:
|
||||
q (array): Input query array.
|
||||
k (array): Input keys array.
|
||||
v (array): Input values array.
|
||||
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
|
||||
mask (array, optional): An additive mask to apply to the query-key scores.
|
||||
|
||||
Returns:
|
||||
array: The output array.
|
||||
|
||||
)pbdoc");
|
||||
}
|
||||
|
103
python/tests/test_fast_sdpa.py
Normal file
103
python/tests/test_fast_sdpa.py
Normal file
@ -0,0 +1,103 @@
|
||||
import math
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
import numpy as np
|
||||
|
||||
|
||||
# SDPA for MHA (n_heads == n_kv_heads)
|
||||
def mlx_primitives_sdpa(q, k, v, scale):
|
||||
p = (q * scale) @ k.transpose(0, 1, 3, 2)
|
||||
scores = mx.softmax(p.astype(mx.float32), axis=-1).astype(p.dtype)
|
||||
return scores @ v
|
||||
|
||||
|
||||
# SDPA for GQA (n_heads > n_kv_heads, n_kv_heads > 1, n_heads % n_kv_heads == 0)
|
||||
def mlx_primitives_sdpa_with_gqa(q, k, v, scale):
|
||||
|
||||
n_repeats = q.shape[1] // k.shape[1]
|
||||
|
||||
# borrowing kv cache tiling from mlx-examples/llms/mistral/mistral.py
|
||||
n_heads = q.shape[1]
|
||||
B = q.shape[0]
|
||||
L = k.shape[2]
|
||||
|
||||
def repeat(a):
|
||||
a = mx.concatenate([mx.expand_dims(a, 2)] * n_repeats, axis=2)
|
||||
return a.reshape([B, n_heads, L, -1])
|
||||
|
||||
k, v = map(repeat, (k, v))
|
||||
|
||||
return mlx_primitives_sdpa(q, k, v, scale)
|
||||
|
||||
|
||||
class TestFastInferenceSDPA(mlx_tests.MLXTestCase):
|
||||
@property
|
||||
def dtypes(self):
|
||||
return ["float32", "float16"] if mx.metal.is_available() else ["float32"]
|
||||
|
||||
def test_fast_inference_sdpa(self):
|
||||
|
||||
# Not yet supported:
|
||||
# * K pre-transposed in kernel, V pre-transposed in kernel
|
||||
np.random.seed(0)
|
||||
L = 43
|
||||
R = 1
|
||||
Dk = 128
|
||||
scale = float(1.0 / np.sqrt(128.0))
|
||||
q_npy = np.random.normal(0.0, 1.0, (1, 32, R, Dk)).astype(np.float32)
|
||||
k_npy = np.random.normal(0.0, 1.0, (1, 32, L, Dk)).astype(np.float32)
|
||||
v_npy = np.random.normal(0.0, 1.0, (1, 32, L, Dk)).astype(np.float32)
|
||||
|
||||
q_mlx = mx.array(q_npy)
|
||||
k_mlx = mx.array(k_npy)
|
||||
v_mlx = mx.array(v_npy)
|
||||
|
||||
reference = mlx_primitives_sdpa(q_mlx, k_mlx, v_mlx, scale)
|
||||
|
||||
o_mlx = mx.fast.scaled_dot_product_attention(
|
||||
q_mlx, k_mlx, v_mlx, scale=scale, mask=None
|
||||
)
|
||||
|
||||
self.assertListEqual(list(reference.shape), list(o_mlx.shape))
|
||||
self.assertTrue(mx.allclose(o_mlx, reference, atol=1e-4))
|
||||
|
||||
B = 1
|
||||
H = 32
|
||||
for SEQUENCE_LENGTH in [1, 7, 9, 32, 63, 67, 129, 400, 2000]:
|
||||
for DO_GQA in [0, 1]:
|
||||
for DTYPE in [np.float32, np.half]:
|
||||
n_kv_heads = 8 if DO_GQA else 32
|
||||
q_npy = np.random.normal(0.0, 1.0, (B, H, R, Dk)).astype(DTYPE)
|
||||
k_npy = np.random.normal(
|
||||
0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk)
|
||||
).astype(DTYPE)
|
||||
v_npy = np.random.normal(
|
||||
0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk)
|
||||
).astype(DTYPE)
|
||||
|
||||
q_mlx = mx.array(q_npy)
|
||||
k_mlx = mx.array(k_npy)
|
||||
v_mlx = mx.array(v_npy)
|
||||
|
||||
reference = mlx_primitives_sdpa_with_gqa(q_mlx, k_mlx, v_mlx, scale)
|
||||
o_mlx = mx.fast.scaled_dot_product_attention(
|
||||
q_mlx, k_mlx, v_mlx, scale=scale
|
||||
)
|
||||
|
||||
self.assertListEqual(list(reference.shape), list(o_mlx.shape))
|
||||
rtol = 1e-5
|
||||
atol = 1e-1
|
||||
|
||||
if SEQUENCE_LENGTH > 500:
|
||||
rtol = 1e-2
|
||||
|
||||
if DTYPE == np.half:
|
||||
rtol = 1e-2
|
||||
|
||||
self.assertTrue(mx.allclose(o_mlx, reference, rtol=rtol, atol=atol))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(failfast=True)
|
Loading…
Reference in New Issue
Block a user