Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
Matrix Attention kernel (#1610)
* Rough INIT
* [WIP]: Loading and Matmuls added
* [WIP]: Reductions and min working aligned kernel at headdim = 64
* [WIP] Added headdim 80 for testing
* [WIP] Update dispatch params for testing
* [WIP] Add support for unaligned seq lengths - still looks messy
* Update sdpa_benchmarks
* Update sdpa_benchmarks
* Update sdpa_benchmarks
* Enable gqa support
* Update benchmark and switch off 128 headdim
* Update headdim 128 tuning
* Remove older fast attention code. Write out O strided
* Disable hd=128 until further optimizations
* Enable bf16
* Fix data size bug
* Enable attn build outside of jit
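For context, a minimal usage sketch (not part of this commit) of the fused op this kernel backs. It uses MLX's public mx.fast.scaled_dot_product_attention entry point; the shapes and dtype are illustrative, chosen to exercise the GQA (4 query heads per KV head), bf16, and unaligned-sequence-length cases mentioned above:

    # Illustrative only: exercises the fused SDPA path with GQA heads,
    # bfloat16 inputs, and a sequence length that is not tile-aligned.
    import mlx.core as mx

    B, n_q_heads, n_kv_heads, L, D = 1, 4, 1, 257, 64  # 257: unaligned length
    q = mx.random.normal((B, n_q_heads, L, D), dtype=mx.bfloat16)
    k = mx.random.normal((B, n_kv_heads, L, D), dtype=mx.bfloat16)
    v = mx.random.normal((B, n_kv_heads, L, D), dtype=mx.bfloat16)

    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=D**-0.5)
    print(out.shape)  # (1, 4, 257, 64)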
@@ -44,9 +44,7 @@ build_kernel(layer_norm)
 build_kernel(random)
 build_kernel(rms_norm)
 build_kernel(rope)
-build_kernel(
-  scaled_dot_product_attention scaled_dot_product_attention_params.h
-  sdpa_vector.h steel/defines.h steel/gemm/transforms.h steel/utils.h)
+build_kernel(scaled_dot_product_attention sdpa_vector.h)
 
 set(STEEL_HEADERS
     steel/defines.h
@@ -68,6 +66,24 @@ set(STEEL_HEADERS
     steel/utils/type_traits.h
     steel/utils/integral_constant.h)
 
+set(STEEL_ATTN_HEADERS
+    steel/defines.h
+    steel/utils.h
+    steel/gemm/gemm.h
+    steel/gemm/mma.h
+    steel/gemm/loader.h
+    steel/gemm/transforms.h
+    steel/utils/type_traits.h
+    steel/utils/integral_constant.h
+    steel/attn/attn.h
+    steel/attn/loader.h
+    steel/attn/mma.h
+    steel/attn/params.h
+    steel/attn/transforms.h
+    steel/attn/kernels/steel_attention.h)
+
+build_kernel(steel/attn/kernels/steel_attention ${STEEL_ATTN_HEADERS})
+
 if(NOT MLX_METAL_JIT)
   build_kernel(arange arange.h)
   build_kernel(binary binary.h binary_ops.h)
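As a sanity check on a kernel change like this, one can compare the fused path against an unfused reference built from plain MLX ops. This is a hedged sketch, not taken from the MLX test suite; the shapes and tolerance are illustrative assumptions:

    # Illustrative only: fused SDPA vs. an unfused float32 reference.
    import mlx.core as mx

    def sdpa_reference(q, k, v, scale):
        # softmax(q @ k^T * scale) @ v, computed with plain MLX ops
        scores = (q * scale) @ k.transpose(0, 1, 3, 2)
        return mx.softmax(scores, axis=-1) @ v

    B, H, L, D = 1, 8, 128, 64
    q = mx.random.normal((B, H, L, D))
    k = mx.random.normal((B, H, L, D))
    v = mx.random.normal((B, H, L, D))
    scale = D**-0.5

    fused = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
    ref = sdpa_reference(q, k, v, scale)
    print(mx.allclose(fused, ref, atol=1e-4))  # expect True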