22    short tgp_padding = 0>
 
   66      uint simd_group_id [[simdgroup_index_in_threadgroup]],
 
   67      uint simd_lane_id [[thread_index_in_simdgroup]])
 
   68      : 
thread_idx(simd_group_id * 32 + simd_lane_id),
 
   79    for (
short i = 0; i < 
n_rows; ++i) {
 
   80      int offset_nhw = offsets.y + 
bi + i * 
TROWS;
 
   81      int n = offset_nhw / out_n_pixels;
 
   82      int hw = offset_nhw % out_n_pixels;
 
 
  108    for (
short i = 0, is = 0; i < 
n_rows; ++i, is += 
TROWS) {
 
  115      if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) &&
 
  116          (iw >= 0 && iw < params->iS[1])) {
 
  118        for (
short j = 0; j < 
vec_size; ++j) {
 
  126        for (
short j = 0; j < 
vec_size; ++j) {
 
  127          dst[is * 
dst_ld + j] = T(0);
 
 
  135    if (++weight_w < params->wS[1]) {
 
  137      for (
short i = 0; i < 
n_rows; i++) {
 
  146    if (++weight_h < params->wS[0]) {
 
  148      for (
short i = 0; i < 
n_rows; i++) {
 
  158    for (
short i = 0; i < 
n_rows; i++) {
 
 
 
  170    short tgp_padding = 0>
 
  210      const device T* src_,
 
  215      uint simd_group_id [[simdgroup_index_in_threadgroup]],
 
  216      uint simd_lane_id [[thread_index_in_simdgroup]])
 
  217      : 
thread_idx(simd_group_id * 32 + simd_lane_id),
 
  232    for (
short i = 0; i < 
n_rows; ++i) {
 
  233      int offset_nhw = offsets.y + 
bi + i * 
TROWS;
 
  234      int n = offset_nhw / out_n_pixels;
 
  235      int hw = offset_nhw % out_n_pixels;
 
  258    for (
short i = 0; i < 
n_rows; ++i) {
 
  263    for (
short kh = 0; kh < 
params->
wS[0]; kh++) {
 
  266      for (
short i = 0; i < 
n_rows; ++i) {
 
  268        int ih = read_ih[i] + flip_h * 
params->
kdil[0];
 
  272        mask_h[i] |= (in_bounds << kh);
 
  276    for (
short kw = 0; kw < 
params->
wS[1]; kw++) {
 
  279      for (
short i = 0; i < 
n_rows; ++i) {
 
  280        int iw = read_iw[i] + flip_w * 
params->
kdil[1];
 
  282        bool in_bounds = iw >= 0 && iw < 
params->
iS[1];
 
  284        mask_w[i] |= (in_bounds << kw);
 
 
  295    for (
short i = 0, is = 0; i < 
n_rows; ++i, is += 
TROWS) {
 
  299        for (
short j = 0; j < 
vec_size; ++j) {
 
  307        for (
short j = 0; j < 
vec_size; ++j) {
 
  308          dst[is * 
dst_ld + j] = T(0);
 
 
  316    if (++weight_w < params->wS[1]) {
 
  318      for (
short i = 0; i < 
n_rows; i++) {
 
  327    if (++weight_h < params->wS[0]) {
 
  329      for (
short i = 0; i < 
n_rows; i++) {
 
  339    for (
short i = 0; i < 
n_rows; i++) {
 
 
 
  351    short tgp_padding = 0>
 
  360      (BN == 8) ? 1 : (tgp_size / (
BROWS * 
BCOLS) >= 8 ? 8 : 4);
 
  390      const device T* src_,
 
  395      uint simd_group_id [[simdgroup_index_in_threadgroup]],
 
  396      uint simd_lane_id [[thread_index_in_simdgroup]])
 
  397      : 
src_ld(params_ -> wt_strides[0]),
 
  398        thread_idx(simd_group_id * 32 + simd_lane_id),
 
 
  412      for (
short i = 0; i < BN; i += 
TROWS) {
 
  414        for (
short j = 0; j < 
vec_size; j++) {
 
  419      for (
short i = 0; i < BN; i += 
TROWS) {
 
  422          for (
short j = 0; j < 
vec_size; j++) {
 
  427          for (
short j = 0; j < 
vec_size; j++) {
 
  428            dst[i * 
dst_ld + j] = T(0);
 
 
 
 
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
 
#define STEEL_CONST
Definition defines.h:3
 
const int oS[NDIM]
Definition params.h:12
 
const int iS[NDIM]
Definition params.h:10
 
const int kdil[NDIM]
Definition params.h:15
 
const int str[NDIM]
Definition params.h:13
 
const size_t wt_strides[NDIM+2]
Definition params.h:18
 
const bool flip
Definition params.h:21
 
const size_t in_strides[NDIM+2]
Definition params.h:17
 
const int wS[NDIM]
Definition params.h:11
 
const int O
Definition params.h:9
 
const int N
Definition params.h:7
 
const int pad[NDIM]
Definition params.h:14
 
Definition loader_channel_l.h:352
 
STEEL_CONST short dst_ld
Definition loader_channel_l.h:358
 
STEEL_CONST short vec_size
Definition loader_channel_l.h:359
 
const bool do_read
Definition loader_channel_l.h:386
 
const constant MLXConvParams< 2 > * params
Definition loader_channel_l.h:381
 
STEEL_CONST short n_rows
Definition loader_channel_l.h:367
 
const int read_n
Definition loader_channel_l.h:385
 
METAL_FUNC void load_unsafe() const
Definition loader_channel_l.h:409
 
const short bj
Definition loader_channel_l.h:375
 
const int src_ld
Definition loader_channel_l.h:370
 
const device T * src
Definition loader_channel_l.h:379
 
STEEL_CONST short TCOLS
Definition loader_channel_l.h:363
 
STEEL_CONST short BCOLS
Definition loader_channel_l.h:355
 
const short bi
Definition loader_channel_l.h:374
 
STEEL_CONST short TROWS
Definition loader_channel_l.h:364
 
METAL_FUNC Conv2DWeightBlockLoader(const device T *src_, threadgroup T *dst_, const int2 offsets, const constant MLXConvParams< 2 > *params_, const constant ImplicitGemmConv2DParams *gemm_params_, uint simd_group_id, uint simd_lane_id)
Definition loader_channel_l.h:389
 
METAL_FUNC void next()
Definition loader_channel_l.h:436
 
const short thread_idx
Definition loader_channel_l.h:373
 
int weight_hw
Definition loader_channel_l.h:383
 
STEEL_CONST short BROWS
Definition loader_channel_l.h:354
 
threadgroup T * dst
Definition loader_channel_l.h:378
 
const int inp_jump_h
Definition params.h:35
 
const int inp_jump_c
Definition params.h:36
 
const int inp_jump_w
Definition params.h:34