MLX
params.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#pragma once
4
6// Attn param classes
8
9namespace mlx {
10namespace steel {
11
12struct AttnParams {
13 int B;
14 int H;
15 int D;
16
17 int qL;
18 int kL;
19
20 int gqa_factor;
21 float scale;
22
23 int NQ;
24 int NK;
25
26 int NQ_aligned;
27 int NK_aligned;
28
29 int64_t Q_strides[3];
30 int64_t K_strides[3];
31 int64_t V_strides[3];
32 int64_t O_strides[3];
33};
34
35} // namespace steel
36} // namespace mlx
Definition attn.h:19
Definition allocator.h:7
Definition params.h:12
int D
Head dimension.
Definition params.h:15
int B
Batch Size.
Definition params.h:13
int gqa_factor
Grouped-query attention (GQA) factor.
Definition params.h:20
int H
Number of heads.
Definition params.h:14
int NQ
Number of query blocks.
Definition params.h:23
int kL
Key Sequence Length.
Definition params.h:18
int NQ_aligned
Number of full query blocks.
Definition params.h:26
int qL
Query Sequence Length.
Definition params.h:17
int NK
Number of key/value blocks.
Definition params.h:24
int64_t Q_strides[3]
Query strides (B, H, L, D = 1)
Definition params.h:29
int NK_aligned
Number of full key/value blocks.
Definition params.h:27
int64_t O_strides[3]
Output strides (B, H, L, D = 1)
Definition params.h:32
int64_t V_strides[3]
Value strides (B, H, L, D = 1)
Definition params.h:31
float scale
Attention scale.
Definition params.h:21
int64_t K_strides[3]
Key strides (B, H, L, D = 1)
Definition params.h:30