16template <
short n_channels_>
 
   58    short tgp_padding = 0>
 
  101      uint simd_group_id [[simdgroup_index_in_threadgroup]],
 
  102      uint simd_lane_id [[thread_index_in_simdgroup]])
 
  103      : 
thread_idx(simd_group_id * 32 + simd_lane_id),
 
  113    for (
short i = 0; i < 
n_rows; ++i) {
 
  114      int offset_nhw = offsets.y + 
bi + i * 
TROWS;
 
  115      int n = offset_nhw / out_n_pixels;
 
  116      int hw = offset_nhw % out_n_pixels;
 
 
  139        for (
short j = 0; j < 
vec_size; j++) {
 
  140          dst[i * 
dst_ld + j] = T(0);
 
  156    for (
short i = 0, is = 0; i < 
n_rows; ++i, is += 
TROWS) {
 
  159      int ih = 
read_ih[i] + weight_h;
 
  160      int iw = 
read_iw[i] + weight_w;
 
  163      if ((n < params->N) && (ih >= 0 && ih < 
params->
iS[0]) &&
 
  164          (iw >= 0 && iw < params->iS[1])) {
 
  169        for (
short j = 0; j < n_channels; ++j) {
 
  170          dst[is * 
dst_ld + j] = curr_src[j];
 
  174        for (
short j = n_channels; j < 
vec_size; ++j) {
 
  175          dst[is * 
dst_ld + j] = T(0);
 
  182        for (
short j = 0; j < 
vec_size; ++j) {
 
  183          dst[is * 
dst_ld + j] = T(0);
 
 
 
  202    short tgp_padding = 0>
 
  240      const device T* src_,
 
  245      uint simd_group_id [[simdgroup_index_in_threadgroup]],
 
  246      uint simd_lane_id [[thread_index_in_simdgroup]])
 
  247      : 
src_ld(params_ -> wt_strides[0]),
 
  248        thread_idx(simd_group_id * 32 + simd_lane_id),
 
 
  267        for (
short j = 0; j < 
vec_size; j++) {
 
  268          dst[i * 
dst_ld + j] = T(0);
 
  281        for (
short j = 0; j < n_channels; j++) {
 
  286        for (
short j = n_channels; j < 
vec_size; j++) {
 
  287          dst[i * 
dst_ld + j] = T(0);
 
  294          for (
short j = 0; j < n_channels; j++) {
 
  299          for (
short j = n_channels; j < 
vec_size; j++) {
 
  300            dst[i * 
dst_ld + j] = T(0);
 
  304          for (
short j = 0; j < 
vec_size; j++) {
 
  305            dst[i * 
dst_ld + j] = T(0);
 
 
 
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
 
#define STEEL_CONST
Definition defines.h:3
 
const int oS[NDIM]
Definition params.h:12
 
const int iS[NDIM]
Definition params.h:10
 
const int kdil[NDIM]
Definition params.h:15
 
const int str[NDIM]
Definition params.h:13
 
const size_t wt_strides[NDIM+2]
Definition params.h:18
 
const bool flip
Definition params.h:21
 
const size_t in_strides[NDIM+2]
Definition params.h:17
 
const int wS[NDIM]
Definition params.h:11
 
const int O
Definition params.h:9
 
const int pad[NDIM]
Definition params.h:14
 
Definition loader_channel_n.h:17
 
STEEL_CONST short vec_size
Definition loader_channel_n.h:19
 
STEEL_CONST short n_channels
Definition loader_channel_n.h:18
 
STEEL_CONST short excess
Definition loader_channel_n.h:20
 
Definition loader_channel_n.h:203
 
STEEL_CONST short vec_size
Definition loader_channel_n.h:210
 
METAL_FUNC void load_unsafe() const
Definition loader_channel_n.h:259
 
threadgroup T * dst
Definition loader_channel_n.h:228
 
METAL_FUNC void next()
Definition loader_channel_n.h:313
 
int weight_hw
Definition loader_channel_n.h:233
 
STEEL_CONST short TROWS
Definition loader_channel_n.h:214
 
const bool do_read
Definition loader_channel_n.h:236
 
const device T * src
Definition loader_channel_n.h:229
 
STEEL_CONST short BCOLS
Definition loader_channel_n.h:206
 
const int read_n
Definition loader_channel_n.h:235
 
const int src_ld
Definition loader_channel_n.h:220
 
STEEL_CONST short dst_ld
Definition loader_channel_n.h:209
 
const short thread_idx
Definition loader_channel_n.h:223
 
STEEL_CONST short BROWS
Definition loader_channel_n.h:205
 
STEEL_CONST short TCOLS
Definition loader_channel_n.h:213
 
const short bj
Definition loader_channel_n.h:225
 
METAL_FUNC Conv2DWeightBlockLoaderSmallChannels(const device T *src_, threadgroup T *dst_, const int2 offsets, const constant MLXConvParams< 2 > *params_, const constant ImplicitGemmConv2DParams *gemm_params_, uint simd_group_id, uint simd_lane_id)
Definition loader_channel_n.h:239
 
const short bi
Definition loader_channel_n.h:224
 
STEEL_CONST short n_rows
Definition loader_channel_n.h:217
 
const constant MLXConvParams< 2 > * params
Definition loader_channel_n.h:231