20    short tgp_padding = 0>
 
   69      uint simd_group_id [[simdgroup_index_in_threadgroup]],
 
   70      uint simd_lane_id [[thread_index_in_simdgroup]])
 
   71      : 
thread_idx(simd_group_id * 32 + simd_lane_id),
 
   82    for (
short i = 0; i < 
n_rows; ++i) {
 
   83      int offset_nhw = offsets.y + 
bi + i * 
TROWS;
 
 
  106    for (
short i = 0, is = 0; i < 
n_rows; ++i, is += 
TROWS) {
 
  122      if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) &&
 
  123          (iw_dil >= 0 && iw < params->iS[1])) {
 
  125        for (
short j = 0; j < 
vec_size; ++j) {
 
  126          dst[is * 
dst_ld + j] = (
src[i])[offset + j];
 
  133        for (
short j = 0; j < 
vec_size; ++j) {
 
  134          dst[is * 
dst_ld + j] = T(0);
 
 
  157    for (
short i = 0; i < 
n_rows; i++) {
 
 
 
  169    short tgp_padding = 0>
 
  178      (BN == 8) ? 1 : (tgp_size / (
BROWS * 
BCOLS) >= 8 ? 8 : 4);
 
  212      const device T* src_,
 
  217      const short base_wh_,
 
  218      const short base_ww_,
 
  219      uint simd_group_id [[simdgroup_index_in_threadgroup]],
 
  220      uint simd_lane_id [[thread_index_in_simdgroup]])
 
  221      : 
src_ld(params_->wt_strides[0]),
 
  222        thread_idx(simd_group_id * 32 + simd_lane_id),
 
 
  242      for (
short i = 0; i < BN; i += 
TROWS) {
 
  244        for (
short j = 0; j < 
vec_size; j++) {
 
  249      for (
short i = 0; i < BN; i += 
TROWS) {
 
  252          for (
short j = 0; j < 
vec_size; j++) {
 
  257          for (
short j = 0; j < 
vec_size; j++) {
 
  258            dst[i * 
dst_ld + j] = T(0);
 
 
 
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
 
#define STEEL_CONST
Definition defines.h:3
 
const int kdil[NDIM]
Definition params.h:15
 
const int str[NDIM]
Definition params.h:13
 
const size_t wt_strides[NDIM+2]
Definition params.h:18
 
const bool flip
Definition params.h:21
 
const size_t in_strides[NDIM+2]
Definition params.h:17
 
const int wS[NDIM]
Definition params.h:11
 
const int O
Definition params.h:9
 
const int pad[NDIM]
Definition params.h:14
 
const int idil[NDIM]
Definition params.h:16
 
const int f_out_jump_w
Definition params.h:48
 
const int f_wgt_jump_h
Definition params.h:44
 
const int f_wgt_jump_w
Definition params.h:45
 
const int f_out_jump_h
Definition params.h:47
 
const int adj_out_w
Definition params.h:51
 
const int adj_out_hw
Definition params.h:52
 
Definition loader_general.h:170
 
STEEL_CONST short BROWS
Definition loader_general.h:172
 
const short thread_idx
Definition loader_general.h:191
 
STEEL_CONST short vec_size
Definition loader_general.h:177
 
METAL_FUNC void next()
Definition loader_general.h:266
 
STEEL_CONST short BCOLS
Definition loader_general.h:173
 
const int start_row
Definition loader_general.h:208
 
const short base_ww
Definition loader_general.h:203
 
const short bi
Definition loader_general.h:192
 
const device T * src
Definition loader_general.h:197
 
short weight_h
Definition loader_general.h:205
 
const int src_ld
Definition loader_general.h:188
 
const short base_wh
Definition loader_general.h:202
 
short weight_w
Definition loader_general.h:206
 
threadgroup T * dst
Definition loader_general.h:196
 
METAL_FUNC void load_unsafe() const
Definition loader_general.h:236
 
const constant Conv2DGeneralJumpParams * jump_params
Definition loader_general.h:200
 
STEEL_CONST short dst_ld
Definition loader_general.h:176
 
STEEL_CONST short n_rows
Definition loader_general.h:185
 
STEEL_CONST short TROWS
Definition loader_general.h:182
 
const short bj
Definition loader_general.h:193
 
METAL_FUNC Conv2DWeightBlockLoaderGeneral(const device T *src_, threadgroup T *dst_, const int2 offsets, const constant MLXConvParams< 2 > *params_, const constant Conv2DGeneralJumpParams *jump_params_, const short base_wh_, const short base_ww_, uint simd_group_id, uint simd_lane_id)
Definition loader_general.h:211
 
const constant MLXConvParams< 2 > * params
Definition loader_general.h:199
 
STEEL_CONST short TCOLS
Definition loader_general.h:181