loader.h
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/steel/utils.h"

///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

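// BlockLoader cooperatively copies a BROWS x BCOLS tile from device memory
// into threadgroup memory. The tgp_size threads are arranged as a TROWS x
// TCOLS grid; each thread handles vec_size (= n_reads) contiguous elements
// per row and steps over the tile in strides of TROWS rows.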
template <
    typename T,
    short BROWS,
    short BCOLS,
    short dst_ld,
    short reduction_dim,
    short tgp_size,
    short alignment = 1,
    short n_reads = (BCOLS * BROWS) / (tgp_size),
    short TCOLS = BCOLS / n_reads,
    short TROWS = tgp_size / TCOLS>
struct BlockLoader {
  STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
  STEEL_CONST short vec_size = n_reads;

  // Leading dimension for src
  const int src_ld;
  const int tile_stride;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

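  // Byte blob sized for one vectorized read of vec_size elements; its
  // alignment is set by the `alignment` template parameter so loads can be
  // issued as a single aligned copy in load_unsafe().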
  struct alignas(alignment * sizeof(T)) ReadVector {
    uint8_t v[sizeof(T) * vec_size];
  };

  /* Constructor */
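  // Linearize the thread as simd_group_id * 32 + simd_lane_id, then map it to
  // tile coordinates: bi selects the row and bj the first of the vec_size
  // contiguous columns this thread reads; dst/src are offset accordingly.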
  METAL_FUNC BlockLoader(
      const device T* src_,
      const int src_ld_,
      threadgroup T* dst_,
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld + bj) {}

  /* Load from device memory into threadgroup memory - without bound checking */
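  // Each step copies vec_size elements as one ReadVector and advances by
  // TROWS rows, so the whole BROWS x BCOLS tile is covered by the threadgroup.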
  METAL_FUNC void load_unsafe() const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      *((threadgroup ReadVector*)(&dst[i * dst_ld])) =
          *((const device ReadVector*)(&src[i * src_ld]));
    }
  }

  /* Load from device memory into threadgroup memory - with bound checking */
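  // src_tile_dim holds the remaining valid (columns, rows) of the source tile;
  // reads are predicated element by element and out-of-bounds slots in the
  // threadgroup tile are filled with zeros.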
  METAL_FUNC void load_safe(short2 src_tile_dim) const {
    src_tile_dim = src_tile_dim - short2(bj, bi);

    // Skip loading if thread has no valid reads
    if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BROWS; i += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          dst[i * dst_ld + j] = T(0);
        }
      }
      return;
    }

    // Use fast thread memory for bound checks
    bool tmp_idx[vec_size];
    T tmp_val[vec_size];

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      // Make sure tmp_idx only contains valid indices
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
      }

      // Read valid indices into tmp_val
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
      }

      // Zero out unneeded values
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
      }

      // Copy values to threadgroup memory
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        dst[i * dst_ld + j] = tmp_val[j];
      }
    }
  }

  /* Iteration helper */
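  // Advance src by one tile along the reduction dimension: BCOLS columns when
  // reduction_dim is nonzero, otherwise BROWS full rows (BROWS * src_ld).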
  METAL_FUNC void next() {
    src += tile_stride;
  }
};
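
// A minimal usage sketch (not part of this header): names such as A, lda, K,
// and the surrounding kernel are assumptions, but this shows the typical
// pattern of pairing load_*() with a threadgroup barrier and next() along the
// reduction dimension.
//
//   threadgroup T As[BROWS * BCOLS];
//   BlockLoader<T, BROWS, BCOLS, BCOLS, 1, tgp_size> loader_a(
//       A, lda, As, simd_group_id, simd_lane_id);
//   for (int k = 0; k < K; k += BCOLS) {
//     loader_a.load_unsafe();
//     threadgroup_barrier(mem_flags::mem_threadgroup);
//     // ... consume As in the compute stage ...
//     loader_a.next();
//   }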

} // namespace steel
} // namespace mlx