22    short n_reads = (BCOLS * BROWS) / (tgp_size),
 
   23    short TCOLS = BCOLS / n_reads,
 
   24    short TROWS = tgp_size / TCOLS>
 
   51      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
 
   52      ushort simd_lane_id [[thread_index_in_simdgroup]])
 
   58        dst(dst_ + 
bi * dst_ld + 
bj),
 
 
   62  template <
typename UnaryOp>
 
   65    for (
short i = 0; i < BROWS; i += TROWS) {
 
   67      for (
short j = 0; j < 
vec_size; j++) {
 
   68        dst[i * dst_ld + j] = 
op.apply(dst[i * dst_ld + j]);
 
 
   76    for (
short i = 0; i < BROWS; i += TROWS) {
 
   77      *((threadgroup 
ReadVector*)(&dst[i * dst_ld])) =
 
 
   83  METAL_FUNC 
void load_safe(short2 src_tile_dim)
 const {
 
   84    src_tile_dim = src_tile_dim - short2(
bj, 
bi);
 
   87    if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
 
   89      for (
short i = 0; i < BROWS; i += TROWS) {
 
   91        for (
short j = 0; j < 
vec_size; j++) {
 
   92          dst[i * dst_ld + j] = T(0);
 
  103    for (
short i = 0; i < BROWS; i += TROWS) {
 
  106      for (
short j = 0; j < 
vec_size; j++) {
 
  107        tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
 
  112      for (
short j = 0; j < 
vec_size; j++) {
 
  113        tmp_val[j] = 
src[(tmp_idx[j] ? i * 
src_ld + j : 0)];
 
  118      for (
short j = 0; j < 
vec_size; j++) {
 
  119        tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
 
  124      for (
short j = 0; j < 
vec_size; j++) {
 
  125        dst[i * dst_ld + j] = tmp_val[j];
 
 
 
Op op
Definition binary.h:129
 
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
 
#define STEEL_CONST
Definition defines.h:3
 
uint8_t v[sizeof(T) *vec_size]
Definition loader.h:43
 
const short thread_idx
Definition loader.h:34
 
const device T * src
Definition loader.h:40
 
METAL_FUNC BlockLoader(const device T *src_, const int src_ld_, threadgroup T *dst_, ushort simd_group_id, ushort simd_lane_id)
Definition loader.h:47
 
STEEL_CONST short vec_size
Definition loader.h:27
 
METAL_FUNC void next()
Definition loader.h:131
 
METAL_FUNC void load_unsafe() const
Definition loader.h:74
 
const short bj
Definition loader.h:36
 
STEEL_CONST short n_rows
Definition loader.h:26
 
const short bi
Definition loader.h:35
 
const int src_ld
Definition loader.h:30
 
const int tile_stride
Definition loader.h:31
 
METAL_FUNC void load_safe(short2 src_tile_dim) const
Definition loader.h:83
 
METAL_FUNC void apply_inplace_op(thread const UnaryOp &op) const
Definition loader.h:63
 
threadgroup T * dst
Definition loader.h:39