int64_t M_strides[3]
Mask strides (B, H, qL, kL = 1)
Definition params.h:40
int D
Head dimension.
Definition params.h:15
int B
Batch Size.
Definition params.h:13
int qL_off
Offset of the query sequence start.
Definition params.h:31
int gqa_factor
Grouped-query attention (GQA) factor.
Definition params.h:20
int H
Heads.
Definition params.h:14
int NQ
Number of query blocks.
Definition params.h:23
int kL
Key Sequence Length.
Definition params.h:18
int NQ_aligned
Number of full query blocks.
Definition params.h:26
int qL
Query Sequence Length.
Definition params.h:17
int NK
Number of key/value blocks.
Definition params.h:24
int kL_rem
Remainder in last key/value block.
Definition params.h:30
int64_t Q_strides[3]
Query strides (B, H, L, D = 1)
Definition params.h:33
int NK_aligned
Number of full key/value blocks.
Definition params.h:27
int64_t O_strides[3]
Output strides (B, H, L, D = 1)
Definition params.h:36
int64_t V_strides[3]
Value strides (B, H, L, D = 1)
Definition params.h:35
float scale
Attention scale.
Definition params.h:21
int qL_rem
Remainder in last query block.
Definition params.h:29
int64_t K_strides[3]
Key strides (B, H, L, D = 1)
Definition params.h:34