MLX
 
Loading...
Searching...
No Matches
params.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#pragma once
4
6// Attn param classes
8
9namespace mlx {
10namespace steel {
11
/// Attention parameters: problem sizes, softmax scale, block counts and the
/// strides of the query / key / value / output tensors.
///
/// Member order is significant (the struct is a plain aggregate); it follows
/// the declaration line numbers reported for each member in the generated
/// documentation.
struct AttnParams {
  int B; ///< Batch Size
  int H; ///< Heads
  int D; ///< Head Dim

  int qL; ///< Query Sequence Length
  int kL; ///< Key Sequence Length

  int gqa_factor; ///< Group Query factor
  float scale; ///< Attention scale

  int NQ; ///< Number of query blocks
  int NK; ///< Number of key/value blocks

  int NQ_aligned; ///< Number of full query blocks
  int NK_aligned; ///< Number of full key/value blocks

  int qL_rem; ///< Remainder in last query block
  int kL_rem; ///< Remainder in last key/value block
  int qL_off; ///< Offset in query sequence start

  int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
  int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
  int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
  int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1)
};
38
40 int64_t M_strides[3];
41};
42
43} // namespace steel
44} // namespace mlx
Definition attn.h:19
Definition allocator.h:7
Definition params.h:39
int64_t M_strides[3]
Mask strides (B, H, qL, kL = 1)
Definition params.h:40
Definition params.h:12
int D
Head Dim.
Definition params.h:15
int B
Batch Size.
Definition params.h:13
int qL_off
Offset in query sequence start.
Definition params.h:31
int gqa_factor
Group Query factor.
Definition params.h:20
int H
Heads.
Definition params.h:14
int NQ
Number of query blocks.
Definition params.h:23
int kL
Key Sequence Length.
Definition params.h:18
int NQ_aligned
Number of full query blocks.
Definition params.h:26
int qL
Query Sequence Length.
Definition params.h:17
int NK
Number of key/value blocks.
Definition params.h:24
int kL_rem
Remainder in last key/value block.
Definition params.h:30
int64_t Q_strides[3]
Query strides (B, H, L, D = 1)
Definition params.h:33
int NK_aligned
Number of full key/value blocks.
Definition params.h:27
int64_t O_strides[3]
Output strides (B, H, L, D = 1)
Definition params.h:36
int64_t V_strides[3]
Value strides (B, H, L, D = 1)
Definition params.h:35
float scale
Attention scale.
Definition params.h:21
int qL_rem
Remainder in last query block.
Definition params.h:29
int64_t K_strides[3]
Key strides (B, H, L, D = 1)
Definition params.h:34