MLX
 
Loading...
Searching...
No Matches
loader.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#pragma once
4
6
8// Loading helper
10
11namespace mlx {
12namespace steel {
13
// BlockLoader: per-thread helper that copies elements from a `device` source
// pointer into a `threadgroup` destination pointer, one tile at a time.
//
// Each thread handles `vec_size` consecutive elements per row, starting at
// row `bi` / column `bj`, and steps TROWS rows per loop iteration until
// BROWS is reached.
//
// Template parameters:
//   T             element type.
//   BROWS, BCOLS  extents used for the row loop bound and for deriving the
//                 per-thread layout (presumably tile rows/columns — confirm
//                 against callers).
//   dst_ld        row stride (in elements) used when indexing `dst`.
//   reduction_dim selects the stride used by next(): non-zero -> BCOLS,
//                 zero -> BROWS * src_ld.
//   tgp_size      divisor used to derive n_reads / TROWS (presumably the
//                 number of threads in the threadgroup — confirm).
//   alignment     multiplier for the alignas() of ReadVector, in units of
//                 sizeof(T).
//   n_reads       elements handled per thread per row (becomes vec_size).
//   TCOLS         number of threads laid out along a row.
//   TROWS         row step between successive iterations of a thread.
//
// NOTE(review): with the default n_reads/TCOLS/TROWS and all integer
// divisions exact, TROWS works out to BROWS, so the row loops run once.
14template <
15 typename T,
16 short BROWS,
17 short BCOLS,
18 short dst_ld,
19 short reduction_dim,
20 short tgp_size,
21 short alignment = 1,
22 short n_reads = (BCOLS * BROWS) / (tgp_size),
23 short TCOLS = BCOLS / n_reads,
24 short TROWS = tgp_size / TCOLS>
25struct BlockLoader {
// Ceil-divided count of TROWS-sized steps needed to cover BROWS.
26 STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
// Number of consecutive elements each thread handles per row.
27 STEEL_CONST short vec_size = n_reads;
28
29 // Leading dimension (row stride, in elements) for src, and the element
// offset applied to src by each call to next()
30 const int src_ld;
31 const int tile_stride;
32
33 // Thread location indices: flat thread index, then the starting row (bi)
// and starting column (bj) assigned to this thread
34 const short thread_idx;
35 const short bi;
36 const short bj;
37
38 // threadgroup (destination) and device (source) memory, both already
// offset to this thread's first element by the constructor
39 threadgroup T* dst;
40 const device T* src;
41
// Byte blob of vec_size elements of T, used by load_unsafe() to copy a
// thread's whole row segment in a single assignment.
42 struct alignas(alignment * sizeof(T)) ReadVector {
43 uint8_t v[sizeof(T) * vec_size];
44 };
45
46 /* Constructor
 *
 * src_     base pointer of the source data in device memory
 * src_ld_  row stride (in elements) of the source
 * dst_     base pointer of the destination in threadgroup memory
 *
 * The flat thread index is built as simd_group_id * 32 + simd_lane_id,
 * i.e. the literal 32 is used as the simdgroup width. From it the thread's
 * starting row (thread_idx / TCOLS) and starting column
 * (vec_size * (thread_idx % TCOLS)) are derived, and both pointers are
 * advanced to that position.
 */
47 METAL_FUNC BlockLoader(
48 const device T* src_,
49 const int src_ld_,
50 threadgroup T* dst_,
51 ushort simd_group_id [[simdgroup_index_in_threadgroup]],
52 ushort simd_lane_id [[thread_index_in_simdgroup]])
53 : src_ld(src_ld_),
54 tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
55 thread_idx(simd_group_id * 32 + simd_lane_id),
56 bi(thread_idx / TCOLS),
57 bj(vec_size * (thread_idx % TCOLS)),
58 dst(dst_ + bi * dst_ld + bj),
59 src(src_ + bi * src_ld + bj) {}
60
61 /* Apply operation to threadgroup without bound checking
 *
 * Replaces every destination element owned by this thread with
 * op.apply(element). UnaryOp only needs to expose an apply() callable on
 * a T value.
 */
62 template <typename UnaryOp>
63 METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
65 for (short i = 0; i < BROWS; i += TROWS) {
67 for (short j = 0; j < vec_size; j++) {
68 dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]);
69 }
70 }
71 }
72
73 /* Load from device memory into threadgroup memory - without bound checking
 *
 * Copies vec_size elements per row as one ReadVector assignment.
 * NOTE(review): the casts assume both addresses satisfy ReadVector's
 * alignas(alignment * sizeof(T)) — confirm callers choose `alignment`
 * accordingly.
 */
74 METAL_FUNC void load_unsafe() const {
76 for (short i = 0; i < BROWS; i += TROWS) {
77 *((threadgroup ReadVector*)(&dst[i * dst_ld])) =
78 *((const device ReadVector*)(&src[i * src_ld]));
79 }
80 }
81
82 /* Load from device memory into threadgroup memory - with bound checking
 *
 * src_tile_dim: .x is compared against the column offset and .y against
 * the row offset, so it carries the valid (columns, rows) extent of the
 * source tile. Elements outside that extent are written as T(0).
 */
83 METAL_FUNC void load_safe(short2 src_tile_dim) const {
// Re-express the valid extent relative to this thread's start position
84 src_tile_dim = src_tile_dim - short2(bj, bi);
85
86 // Skip loading if thread has no valid reads; its destination elements
// are still zero-filled before returning
87 if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
89 for (short i = 0; i < BROWS; i += TROWS) {
91 for (short j = 0; j < vec_size; j++) {
92 dst[i * dst_ld + j] = T(0);
93 }
94 }
95 return;
96 }
97
98 // Use fast thread memory for bound checks
99 bool tmp_idx[vec_size];
100 T tmp_val[vec_size];
101
103 for (short i = 0; i < BROWS; i += TROWS) {
104 // Make sure tmp_idx only contains valid indices
106 for (short j = 0; j < vec_size; j++) {
107 tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
108 }
109
110 // Read valid indices into tmp_val; masked-out lanes read src[0]
// instead, which is in range because the guard above guarantees this
// thread's first element is valid
112 for (short j = 0; j < vec_size; j++) {
113 tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
114 }
115
116 // Zero out unneeded values
118 for (short j = 0; j < vec_size; j++) {
119 tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
120 }
121
122 // Copy values to threadgroup memory
124 for (short j = 0; j < vec_size; j++) {
125 dst[i * dst_ld + j] = tmp_val[j];
126 }
127 }
128 }
129
130 /* Iteration helper: advance the source pointer by tile_stride elements;
 * the destination pointer is left unchanged */
131 METAL_FUNC void next() {
132 src += tile_stride;
133 }
134};
135
136} // namespace steel
137} // namespace mlx
Device & device(mlx::core::Device)
Definition attn.h:19
Definition allocator.h:7
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
#define STEEL_CONST
Definition defines.h:3
uint8_t v[sizeof(T) *vec_size]
Definition loader.h:43
Definition loader.h:25
const short thread_idx
Definition loader.h:34
METAL_FUNC BlockLoader(const device T *src_, const int src_ld_, threadgroup T *dst_, ushort simd_group_id, ushort simd_lane_id)
Definition loader.h:47
STEEL_CONST short vec_size
Definition loader.h:27
METAL_FUNC void next()
Definition loader.h:131
METAL_FUNC void load_unsafe() const
Definition loader.h:74
const short bj
Definition loader.h:36
STEEL_CONST short n_rows
Definition loader.h:26
const short bi
Definition loader.h:35
const int src_ld
Definition loader.h:30
const int tile_stride
Definition loader.h:31
METAL_FUNC void load_safe(short2 src_tile_dim) const
Definition loader.h:83
const device T * src
Definition loader.h:40
METAL_FUNC void apply_inplace_op(thread const UnaryOp &op) const
Definition loader.h:63
threadgroup T * dst
Definition loader.h:39