MLX
Loading...
Searching...
No Matches
params.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#pragma once
4
5template <int NDIM>
7 const int N; // Batch size
8 const int C; // In channels
9 const int O; // Out channels
10 const int iS[NDIM]; // Input spatial dim
11 const int wS[NDIM]; // Weight spatial dim
12 const int oS[NDIM]; // Output spatial dim
13 const int str[NDIM]; // Kernel strides
14 const int pad[NDIM]; // Input padding
15 const int kdil[NDIM]; // Kernel dilation
16 const int idil[NDIM]; // Input dilation
17 const size_t in_strides[NDIM + 2]; // In strides
18 const size_t wt_strides[NDIM + 2]; // Wt strides
19 const size_t out_strides[NDIM + 2]; // Out strides
20 const int groups; // Input channel groups
21 const bool flip;
22};
23
24namespace mlx {
25namespace steel {
26
28 const int M;
29 const int N;
30 const int K;
31
33
34 const int inp_jump_w;
35 const int inp_jump_h;
36 const int inp_jump_c;
37
38 const int tiles_n;
39 const int tiles_m;
40 const int swizzle_log;
41};
42
44 const int f_wgt_jump_h;
45 const int f_wgt_jump_w;
46
47 const int f_out_jump_h;
48 const int f_out_jump_w;
49
50 const int adj_out_h;
51 const int adj_out_w;
52 const int adj_out_hw;
53 const int adj_implicit_m;
54};
55
60
61} // namespace steel
62} // namespace mlx
Definition allocator.h:7
Definition params.h:6
const int C
Definition params.h:8
const size_t out_strides[NDIM+2]
Definition params.h:19
const int oS[NDIM]
Definition params.h:12
const int iS[NDIM]
Definition params.h:10
const int kdil[NDIM]
Definition params.h:15
const int str[NDIM]
Definition params.h:13
const size_t wt_strides[NDIM+2]
Definition params.h:18
const bool flip
Definition params.h:21
const size_t in_strides[NDIM+2]
Definition params.h:17
const int wS[NDIM]
Definition params.h:11
const int O
Definition params.h:9
const int N
Definition params.h:7
const int pad[NDIM]
Definition params.h:14
const int groups
Definition params.h:20
const int idil[NDIM]
Definition params.h:16
Definition params.h:56
int weight_base
Definition params.h:57
int weight_size
Definition params.h:58
const int f_out_jump_w
Definition params.h:48
const int f_wgt_jump_h
Definition params.h:44
const int f_wgt_jump_w
Definition params.h:45
const int adj_implicit_m
Definition params.h:53
const int f_out_jump_h
Definition params.h:47
const int adj_out_h
Definition params.h:50
const int adj_out_w
Definition params.h:51
const int adj_out_hw
Definition params.h:52
const int inp_jump_h
Definition params.h:35
const int M
Definition params.h:28
const int N
Definition params.h:29
const int tiles_m
Definition params.h:39
const int tiles_n
Definition params.h:38
const int inp_jump_c
Definition params.h:36
const int gemm_k_iterations
Definition params.h:32
const int inp_jump_w
Definition params.h:34
const int swizzle_log
Definition params.h:40
const int K
Definition params.h:30