22 short n_reads = (BCOLS * BROWS) / (tgp_size),
23 short TCOLS = BCOLS / n_reads,
24 short TROWS = tgp_size / TCOLS>
42 struct alignas(alignment * sizeof(T)) ReadVector {
51 ushort simd_group_id [[simdgroup_index_in_threadgroup]],
52 ushort simd_lane_id [[thread_index_in_simdgroup]])
62 template <
typename UnaryOp>
65 for (
short i = 0; i < BROWS; i += TROWS) {
67 for (
short j = 0; j <
vec_size; j++) {
68 dst[i * dst_ld + j] = op.apply(
dst[i * dst_ld + j]);
76 for (
short i = 0; i < BROWS; i += TROWS) {
77 *((threadgroup ReadVector*)(&
dst[i * dst_ld])) =
78 *((
const device ReadVector*)(&
src[i *
src_ld]));
83 METAL_FUNC
void load_safe(short2 src_tile_dim)
const {
84 src_tile_dim = src_tile_dim - short2(
bj,
bi);
87 if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
89 for (
short i = 0; i < BROWS; i += TROWS) {
91 for (
short j = 0; j <
vec_size; j++) {
92 dst[i * dst_ld + j] = T(0);
103 for (
short i = 0; i < BROWS; i += TROWS) {
106 for (
short j = 0; j <
vec_size; j++) {
107 tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
112 for (
short j = 0; j <
vec_size; j++) {
113 tmp_val[j] =
src[(tmp_idx[j] ? i *
src_ld + j : 0)];
118 for (
short j = 0; j <
vec_size; j++) {
119 tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
124 for (
short j = 0; j <
vec_size; j++) {
125 dst[i * dst_ld + j] = tmp_val[j];
136template <
int R,
int C>
150 short n_reads = (BCOLS * BROWS) / (tgp_size),
151 short TCOLS = BCOLS / n_reads,
152 short TROWS = tgp_size / TCOLS>
172 const device T* src_,
175 ushort simd_group_id [[simdgroup_index_in_threadgroup]],
176 ushort simd_lane_id [[thread_index_in_simdgroup]])
179 thread_idx(simd_group_id * 32 + simd_lane_id),
182 dst(dst_ +
bi * kDstStrRow +
bj * kDstStrCol),
186 template <
typename UnaryOp>
189 for (
short i = 0; i < BROWS; i += TROWS) {
191 for (
short j = 0; j <
vec_size; j++) {
192 dst[i * kDstStrRow + j * kDstStrCol] =
193 op.apply(
dst[i * kDstStrRow + j * kDstStrCol]);
201 for (
short i = 0; i < BROWS; i += TROWS) {
203 for (
short j = 0; j <
vec_size; j++) {
204 dst[i * kDstStrRow + j * kDstStrCol] =
src[i *
src_ld + j];
211 src_tile_dim = src_tile_dim - short2(
bj,
bi);
214 if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
216 for (
short i = 0; i < BROWS; i += TROWS) {
218 for (
short j = 0; j <
vec_size; j++) {
219 dst[i * kDstStrRow + j * kDstStrCol] = T(0);
230 for (
short i = 0; i < BROWS; i += TROWS) {
233 for (
short j = 0; j <
vec_size; j++) {
234 tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
239 for (
short j = 0; j <
vec_size; j++) {
240 tmp_val[j] =
src[(tmp_idx[j] ? i *
src_ld + j : 0)];
245 for (
short j = 0; j <
vec_size; j++) {
246 tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
251 for (
short j = 0; j <
vec_size; j++) {
252 dst[i * kDstStrRow + j * kDstStrCol] = tmp_val[j];
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
#define STEEL_CONST
Definition defines.h:3
uint8_t v[sizeof(T) *vec_size]
Definition loader.h:43
const short thread_idx
Definition loader.h:34
METAL_FUNC BlockLoader(const device T *src_, const int src_ld_, threadgroup T *dst_, ushort simd_group_id, ushort simd_lane_id)
Definition loader.h:47
STEEL_CONST short vec_size
Definition loader.h:27
METAL_FUNC void next()
Definition loader.h:131
METAL_FUNC void load_unsafe() const
Definition loader.h:74
const short bj
Definition loader.h:36
STEEL_CONST short n_rows
Definition loader.h:26
const short bi
Definition loader.h:35
const int src_ld
Definition loader.h:30
const int tile_stride
Definition loader.h:31
METAL_FUNC void load_safe(short2 src_tile_dim) const
Definition loader.h:83
const device T * src
Definition loader.h:40
METAL_FUNC void apply_inplace_op(thread const UnaryOp &op) const
Definition loader.h:63
threadgroup T * dst
Definition loader.h:39
METAL_FUNC BlockLoaderT(const device T *src_, const int src_ld_, threadgroup T *dst_, ushort simd_group_id, ushort simd_lane_id)
Definition loader.h:171
STEEL_CONST short n_rows
Definition loader.h:154
METAL_FUNC void apply_inplace_op(thread const UnaryOp &op) const
Definition loader.h:187
const int tile_stride
Definition loader.h:159
METAL_FUNC void next()
Definition loader.h:258
const short bi
Definition loader.h:163
threadgroup T * dst
Definition loader.h:167
const device T * src
Definition loader.h:168
STEEL_CONST short vec_size
Definition loader.h:155
METAL_FUNC void load_safe(short2 src_tile_dim) const
Definition loader.h:210
const short bj
Definition loader.h:164
METAL_FUNC void load_unsafe() const
Definition loader.h:199
const int src_ld
Definition loader.h:158
const short thread_idx
Definition loader.h:162
STEEL_CONST int kCols
Definition loader.h:139
STEEL_CONST int kRows
Definition loader.h:138