Fix SDPA kernel bug on Mac OS 13.3 SDK (#805)

* Move sdpa kernel to allocate tgp mem statically and allow macOS 13.3 SDK builds

* Style
This commit is contained in:
Jagrit Digani
2024-03-07 10:18:09 -08:00
committed by GitHub
parent b7588fd5d7
commit ec8a4864fa
3 changed files with 6 additions and 5 deletions

View File

@@ -13,10 +13,12 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
device float* O_partials [[buffer(5)]],
device float* p_lse [[buffer(6)]],
device float* p_maxes [[buffer(7)]],
threadgroup T* threadgroup_block [[threadgroup(0)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
threadgroup T threadgroup_block[32768 / sizeof(T)];
constexpr const size_t DK = 128;
constexpr const ulong SIMDGROUP_MATRIX_LOAD_FACTOR = 8;
constexpr const size_t THREADS_PER_SIMDGROUP = 32;
@@ -356,7 +358,6 @@ template [[host_name("fast_inference_sdpa_compute_partials_" #itype "_" #tile_si
device float* O_partials [[buffer(5)]], \
device float* p_lse [[buffer(6)]], \
device float* p_maxes [[buffer(7)]], \
threadgroup itype *threadgroup_block [[threadgroup(0)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]]);

View File

@@ -97,8 +97,6 @@ void sdpa_metal(
set_array_buffer(compute_encoder, p_lse, 6);
set_array_buffer(compute_encoder, p_rowmaxes, 7);
constexpr const uint tgroupMemorySize = 32768;
compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
{