mirror of https://github.com/ml-explore/mlx.git
Add the attention kernel

parent 190c72739b
commit 221edc4a65
mlx/backend/metal/kernels/CMakeLists.txt

@@ -109,6 +109,7 @@ if(NOT MLX_METAL_JIT)
     reduction/reduce_row.h)
   build_kernel(quantized quantized.h ${STEEL_HEADERS})
   build_kernel(scan scan.h)
+  build_kernel(paged_attention paged_attention.h)
   build_kernel(softmax softmax.h)
   build_kernel(logsumexp logsumexp.h)
   build_kernel(sort sort.h)
mlx/backend/metal/kernels/paged_attention.h (new file, 1195 lines)
File diff suppressed because it is too large

mlx/backend/metal/kernels/paged_attention.metal (new file, 131 lines)
@@ -0,0 +1,131 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/kernels/paged_attention.h"
#include "mlx/backend/metal/kernels/utils.h"

#define instantiate_paged_attention_inner(                                     \
    type, head_size, block_size, num_threads, num_simd_lanes, partition_size) \
  template                                                                     \
      [[host_name("paged_attention_" #type "_hs" #head_size "_bs" #block_size \
                  "_nt" #num_threads "_nsl" #num_simd_lanes                    \
                  "_ps" #partition_size)]] [[kernel]] void                     \
      paged_attention<                                                         \
          type,                                                                \
          head_size,                                                           \
          block_size,                                                          \
          num_threads,                                                         \
          num_simd_lanes,                                                      \
          partition_size>(                                                     \
          device float* exp_sums                                               \
          [[buffer(0), function_constant(use_partitioning)]],                  \
          device float* max_logits                                             \
          [[buffer(1), function_constant(use_partitioning)]],                  \
          device type* out [[buffer(2)]],                                      \
          device const type* q [[buffer(3)]],                                  \
          device const type* k_cache [[buffer(4)]],                            \
          device const type* v_cache [[buffer(5)]],                            \
          const constant int& num_kv_heads [[buffer(6)]],                      \
          const constant float& scale [[buffer(7)]],                           \
          const constant float& softcapping [[buffer(8)]],                     \
          device const uint32_t* block_tables [[buffer(9)]],                   \
          device const uint32_t* context_lens [[buffer(10)]],                  \
          const constant int& max_num_blocks_per_seq [[buffer(11)]],           \
          device const float* alibi_slopes                                     \
          [[buffer(12), function_constant(use_alibi)]],                        \
          const constant int& q_stride [[buffer(13)]],                         \
          const constant int& kv_block_stride [[buffer(14)]],                  \
          const constant int& kv_head_stride [[buffer(15)]],                   \
          threadgroup char* shared_mem [[threadgroup(0)]],                     \
          uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
          uint3 threadgroups_per_grid [[threadgroups_per_grid]],               \
          uint3 thread_position_in_threadgroup                                 \
          [[thread_position_in_threadgroup]],                                  \
          uint simd_tid [[simdgroup_index_in_threadgroup]],                    \
          uint simd_lid [[thread_index_in_simdgroup]]);
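Each expansion of instantiate_paged_attention_inner exports one specialization under a host_name that encodes all six template parameters. For reference, a host-side sketch of how a caller could reproduce that naming scheme when fetching the kernel (this helper is illustrative, not part of the commit):

// Illustrative helper (not in this commit): builds the host_name produced
// by instantiate_paged_attention_inner for a given parameter set.
#include <sstream>
#include <string>

std::string paged_attention_kernel_name(
    const std::string& type, // "float", "bfloat16_t", or "half"
    int head_size,           // 64, 80, 96, 112, 128, 192, or 256
    int block_size,          // 8, 16, or 32
    int num_threads,         // 256 in this commit
    int num_simd_lanes,      // 32 in this commit
    int partition_size) {    // 0 for v1, 512 for v2
  std::ostringstream name;
  name << "paged_attention_" << type << "_hs" << head_size << "_bs"
       << block_size << "_nt" << num_threads << "_nsl" << num_simd_lanes
       << "_ps" << partition_size;
  return name.str();
}

For example, the v1 float kernel with head size 128 and block size 16 comes out as paged_attention_float_hs128_bs16_nt256_nsl32_ps0.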
#define instantiate_paged_attention_v2_reduce_inner(                          \
    type, head_size, num_threads, num_simd_lanes, partition_size)             \
  template [[host_name("paged_attention_v2_reduce_" #type "_hs" #head_size    \
                       "_nt" #num_threads "_nsl" #num_simd_lanes              \
                       "_ps" #partition_size)]] [[kernel]] void               \
  paged_attention_v2_reduce<                                                  \
      type,                                                                   \
      head_size,                                                              \
      num_threads,                                                            \
      num_simd_lanes,                                                         \
      partition_size>(                                                        \
      device type* out [[buffer(0)]],                                         \
      const device float* exp_sums [[buffer(1)]],                             \
      const device float* max_logits [[buffer(2)]],                           \
      const device type* tmp_out [[buffer(3)]],                               \
      device uint32_t* context_lens [[buffer(4)]],                            \
      const constant int& max_num_partitions [[buffer(5)]],                   \
      threadgroup char* shared_mem [[threadgroup(0)]],                        \
      uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],    \
      uint3 threadgroups_per_grid [[threadgroups_per_grid]],                  \
      uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], \
      uint3 threads_per_threadgroup [[threads_per_threadgroup]],              \
      uint simd_tid [[simdgroup_index_in_threadgroup]],                       \
      uint simd_lid [[thread_index_in_simdgroup]]);
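paged_attention_v2_reduce stitches the per-partition results back together: each partition contributes a local max logit (max_logits), a local softmax denominator (exp_sums), and a locally normalized partial output (tmp_out). The combination is standard streaming-softmax rescaling; a scalar sketch under that assumption (not this kernel's actual code):

// Illustrative scalar version of the v2 reduce combination; the real
// kernel does this per output channel across max_num_partitions.
#include <algorithm>
#include <cmath>
#include <vector>

float reduce_partitions(
    const std::vector<float>& max_logits, // local max logit per partition
    const std::vector<float>& exp_sums,   // local softmax denominator
    const std::vector<float>& tmp_out) {  // locally normalized partial output
  float m = max_logits[0];
  for (float mp : max_logits)
    m = std::max(m, mp); // global max logit
  float global_sum = 0.f;
  float acc = 0.f;
  for (size_t p = 0; p < exp_sums.size(); ++p) {
    // Rescale each partition's denominator to the global max.
    float w = exp_sums[p] * std::exp(max_logits[p] - m);
    acc += tmp_out[p] * w;
    global_sum += w;
  }
  return acc / global_sum; // fully normalized attention output
}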
#define instantiate_paged_attention_heads(                                  \
    type, block_size, num_threads, num_simd_lanes, partition_size)          \
  instantiate_paged_attention_inner(                                        \
      type, 64, block_size, num_threads, num_simd_lanes, partition_size);   \
  instantiate_paged_attention_inner(                                        \
      type, 80, block_size, num_threads, num_simd_lanes, partition_size);   \
  instantiate_paged_attention_inner(                                        \
      type, 96, block_size, num_threads, num_simd_lanes, partition_size);   \
  instantiate_paged_attention_inner(                                        \
      type, 112, block_size, num_threads, num_simd_lanes, partition_size);  \
  instantiate_paged_attention_inner(                                        \
      type, 128, block_size, num_threads, num_simd_lanes, partition_size);  \
  instantiate_paged_attention_inner(                                        \
      type, 192, block_size, num_threads, num_simd_lanes, partition_size);  \
  instantiate_paged_attention_inner(                                        \
      type, 256, block_size, num_threads, num_simd_lanes, partition_size);
#define instantiate_paged_attention_v2_reduce_heads(             \
    type, num_threads, num_simd_lanes, partition_size)           \
  instantiate_paged_attention_v2_reduce_inner(                   \
      type, 64, num_threads, num_simd_lanes, partition_size);    \
  instantiate_paged_attention_v2_reduce_inner(                   \
      type, 80, num_threads, num_simd_lanes, partition_size);    \
  instantiate_paged_attention_v2_reduce_inner(                   \
      type, 96, num_threads, num_simd_lanes, partition_size);    \
  instantiate_paged_attention_v2_reduce_inner(                   \
      type, 112, num_threads, num_simd_lanes, partition_size);   \
  instantiate_paged_attention_v2_reduce_inner(                   \
      type, 128, num_threads, num_simd_lanes, partition_size);   \
  instantiate_paged_attention_v2_reduce_inner(                   \
      type, 192, num_threads, num_simd_lanes, partition_size);   \
  instantiate_paged_attention_v2_reduce_inner(                   \
      type, 256, num_threads, num_simd_lanes, partition_size);
#define instantiate_paged_attention_block_size(                \
    type, num_threads, num_simd_lanes, partition_size)         \
  instantiate_paged_attention_heads(                           \
      type, 8, num_threads, num_simd_lanes, partition_size);   \
  instantiate_paged_attention_heads(                           \
      type, 16, num_threads, num_simd_lanes, partition_size);  \
  instantiate_paged_attention_heads(                           \
      type, 32, num_threads, num_simd_lanes, partition_size);

// TODO: tune num_threads = 256
// NOTE: partition_size = 0
#define instantiate_paged_attention_v1(type, num_simd_lanes) \
  instantiate_paged_attention_block_size(type, 256, num_simd_lanes, 0);

// TODO: tune num_threads = 256
// NOTE: partition_size = 512
#define instantiate_paged_attention_v2(type, num_simd_lanes)                \
  instantiate_paged_attention_block_size(type, 256, num_simd_lanes, 512);   \
  instantiate_paged_attention_v2_reduce_heads(type, 256, num_simd_lanes, 512);
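The v1 variants (partition_size 0) process a sequence's whole context in one pass, while v2 splits it into 512-token partitions and finishes with the reduce kernel above. A hedged sketch of the dispatch rule this split usually implies, modeled on the vLLM heuristic this kernel family derives from (the threshold values are assumptions, not taken from this diff):

// Illustrative dispatch rule (assumption, modeled on vLLM): use the
// single-pass v1 kernel unless the context is long enough that
// partitioning (v2 + reduce) wins by adding parallelism.
constexpr int PARTITION_SIZE = 512; // matches the v2 instantiations above

bool use_paged_attention_v1(int max_context_len, int num_seqs, int num_heads) {
  int max_num_partitions =
      (max_context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; // ceil divide
  // One partition means v2 would only add a redundant reduce pass;
  // many seq*head workgroups already saturate the GPU, so v1 also wins there.
  return max_num_partitions == 1 || num_seqs * num_heads > 512;
}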
instantiate_paged_attention_v1(float, 32);
instantiate_paged_attention_v1(bfloat16_t, 32);
instantiate_paged_attention_v1(half, 32);

instantiate_paged_attention_v2(float, 32);
instantiate_paged_attention_v2(bfloat16_t, 32);
instantiate_paged_attention_v2(half, 32);