22 short n_reads = (BCOLS * BROWS) / (tgp_size),
23 short TCOLS = BCOLS / n_reads,
24 short TROWS = tgp_size / TCOLS>
51 ushort simd_group_id [[simdgroup_index_in_threadgroup]],
52 ushort simd_lane_id [[thread_index_in_simdgroup]])
58 dst(dst_ +
bi * dst_ld +
bj),
62 template <
typename UnaryOp>
65 for (
short i = 0; i < BROWS; i += TROWS) {
67 for (
short j = 0; j <
vec_size; j++) {
68 dst[i * dst_ld + j] =
op.apply(dst[i * dst_ld + j]);
76 for (
short i = 0; i < BROWS; i += TROWS) {
77 *((threadgroup
ReadVector*)(&dst[i * dst_ld])) =
83 METAL_FUNC
void load_safe(short2 src_tile_dim)
const {
84 src_tile_dim = src_tile_dim - short2(
bj,
bi);
87 if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
89 for (
short i = 0; i < BROWS; i += TROWS) {
91 for (
short j = 0; j <
vec_size; j++) {
92 dst[i * dst_ld + j] = T(0);
103 for (
short i = 0; i < BROWS; i += TROWS) {
106 for (
short j = 0; j <
vec_size; j++) {
107 tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
112 for (
short j = 0; j <
vec_size; j++) {
113 tmp_val[j] =
src[(tmp_idx[j] ? i *
src_ld + j : 0)];
118 for (
short j = 0; j <
vec_size; j++) {
119 tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
124 for (
short j = 0; j <
vec_size; j++) {
125 dst[i * dst_ld + j] = tmp_val[j];
Op op
Definition binary.h:141
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
#define STEEL_CONST
Definition defines.h:3
uint8_t v[sizeof(T) *vec_size]
Definition loader.h:43
const short thread_idx
Definition loader.h:34
const device T * src
Definition loader.h:40
METAL_FUNC BlockLoader(const device T *src_, const int src_ld_, threadgroup T *dst_, ushort simd_group_id, ushort simd_lane_id)
Definition loader.h:47
STEEL_CONST short vec_size
Definition loader.h:27
METAL_FUNC void next()
Definition loader.h:131
METAL_FUNC void load_unsafe() const
Definition loader.h:74
const short bj
Definition loader.h:36
STEEL_CONST short n_rows
Definition loader.h:26
const short bi
Definition loader.h:35
const int src_ld
Definition loader.h:30
const int tile_stride
Definition loader.h:31
METAL_FUNC void load_safe(short2 src_tile_dim) const
Definition loader.h:83
METAL_FUNC void apply_inplace_op(thread const UnaryOp &op) const
Definition loader.h:63
threadgroup T * dst
Definition loader.h:39