This commit is contained in:
Angelos Katharopoulos
2024-11-02 21:25:46 -07:00
committed by GitHub
parent 09bc32f62f
commit 62f297b51d
4 changed files with 17 additions and 4 deletions

View File

@@ -936,6 +936,7 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2);
const constant int& gqa_factor, \
const constant int& N, \
const constant size_t& k_stride, \
const constant size_t& v_stride, \
const constant float& scale, \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \

View File

@@ -13,6 +13,7 @@ template <typename T, int D>
const constant int& gqa_factor,
const constant int& N,
const constant size_t& k_stride,
const constant size_t& v_stride,
const constant float& scale,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -38,7 +39,7 @@ template <typename T, int D>
const int kv_head_idx = head_idx / gqa_factor;
queries += head_idx * D + simd_lid * elem_per_thread;
keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
values += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
out += head_idx * D + simd_gid * elem_per_thread;
// Read the query and 0 the output accumulator

View File

@@ -162,7 +162,8 @@ void sdpa_vector(
int gqa_factor = q.shape(1) / k.shape(1);
int N = k.shape(2);
int B = q.shape(0) * q.shape(1);
size_t stride = k.strides()[1];
size_t k_stride = k.strides()[1];
size_t v_stride = v.strides()[1];
MTL::Size group_dims(1024, 1, 1);
MTL::Size grid_dims(1, B, 1);
@@ -178,8 +179,9 @@ void sdpa_vector(
compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&gqa_factor, sizeof(int), 4);
compute_encoder->setBytes(&N, sizeof(int), 5);
compute_encoder->setBytes(&stride, sizeof(size_t), 6);
compute_encoder->setBytes(&scale, sizeof(float), 7);
compute_encoder->setBytes(&k_stride, sizeof(size_t), 6);
compute_encoder->setBytes(&v_stride, sizeof(size_t), 7);
compute_encoder->setBytes(&scale, sizeof(float), 8);
// Launch
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);