MLX
 
Loading...
Searching...
No Matches
params.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#pragma once
4
6// GEMM param classes
8
9namespace mlx {
10namespace steel {
11
// Parameter block for the (batched) GEMM kernels.
//
// NOTE(review): the field meanings below assume the usual BLAS naming
// (M x N output, K reduction dimension, ld* = leading dimensions). Confirm
// against the kernels that read this struct.
// NOTE(review): this struct is presumably filled on the host and read by the
// kernel as raw bytes. If so, member order and types must stay in sync with
// the host-side definition, so do not reorder fields.
struct GEMMParams {
  // Problem sizes.
  const int M;
  const int N;
  const int K;

  // Leading dimensions of A, B and the output D.
  const int lda;
  const int ldb;
  const int ldd;

  // Tile counts along N and M.
  const int tiles_n;
  const int tiles_m;

  // Per-batch strides. These are 64-bit so large batched problems do not
  // overflow.
  const int64_t batch_stride_a;
  const int64_t batch_stride_b;
  const int64_t batch_stride_d;

  const int swizzle_log;
  // Restored: this member was dropped from the listing. The original header
  // declares it here, between swizzle_log and batch_ndim.
  const int gemm_k_iterations_aligned;

  const int batch_ndim;
};
33
// NOTE(review): this listing is incomplete. The struct header (params.h:34)
// is missing, so this struct's name cannot be recovered from this page.
// Before relying on this as a reference, restore the header and the missing
// members from the original header file.
35 const int M;
36 const int N;
37 const int K;
38
39 const int lda;
40 const int ldb;
41 const int ldc;
42
43 const int tiles_n;
44 const int tiles_m;
45
// NOTE(review): per the cross-reference index, the following members belong
// here but were dropped from this listing:
//   const int split_k_partitions;        (params.h:46)
//   const int split_k_partition_stride;  (params.h:47)
//   const int split_k_partition_size;    (params.h:48)
//   const int gemm_k_iterations_aligned; (params.h:50, after the blank line 49)
49
51};
52
// NOTE(review): the struct header at params.h:53 is missing from this
// listing. Only the members below survive (ldc, fdc, batch_stride_c, alpha,
// beta). The alpha/beta pair suggests an epilogue of the form
// alpha * (A*B) + beta * C; confirm against the kernels that read this struct.
54 const int ldc;
55 const int fdc;
56
57 const int64_t batch_stride_c;
58
59 const float alpha;
60 const float beta;
61};
62
63} // namespace steel
64} // namespace mlx
Definition attn.h:19
Definition allocator.h:7
Definition params.h:53
const int fdc
Definition params.h:55
const int ldc
Definition params.h:54
const float beta
Definition params.h:60
const int64_t batch_stride_c
Definition params.h:57
const float alpha
Definition params.h:59
Definition params.h:12
const int gemm_k_iterations_aligned
Definition params.h:29
const int tiles_n
Definition params.h:21
const int N
Definition params.h:14
const int64_t batch_stride_d
Definition params.h:26
const int64_t batch_stride_b
Definition params.h:25
const int ldb
Definition params.h:18
const int batch_ndim
Definition params.h:31
const int ldd
Definition params.h:19
const int M
Definition params.h:13
const int K
Definition params.h:15
const int64_t batch_stride_a
Definition params.h:24
const int tiles_m
Definition params.h:22
const int swizzle_log
Definition params.h:28
const int lda
Definition params.h:17
Definition params.h:34
const int tiles_m
Definition params.h:44
const int N
Definition params.h:36
const int split_k_partition_stride
Definition params.h:47
const int K
Definition params.h:37
const int tiles_n
Definition params.h:43
const int lda
Definition params.h:39
const int ldb
Definition params.h:40
const int ldc
Definition params.h:41
const int M
Definition params.h:35
const int split_k_partition_size
Definition params.h:48
const int gemm_k_iterations_aligned
Definition params.h:50
const int split_k_partitions
Definition params.h:46