16template <
short n_channels_>
58 short tgp_padding = 0>
101 uint simd_group_id [[simdgroup_index_in_threadgroup]],
102 uint simd_lane_id [[thread_index_in_simdgroup]])
103 :
thread_idx(simd_group_id * 32 + simd_lane_id),
113 for (
short i = 0; i <
n_rows; ++i) {
114 int offset_nhw = offsets.y +
bi + i *
TROWS;
115 int n = offset_nhw / out_n_pixels;
116 int hw = offset_nhw % out_n_pixels;
117 int oh = hw /
params->oS[1];
118 int ow = hw %
params->oS[1];
124 src[i] = src_ + n *
params->in_strides[0] + ih *
params->in_strides[1] +
125 iw *
params->in_strides[2];
139 for (
short j = 0; j <
vec_size; j++) {
149 int flip_h =
params->flip ?
params->wS[0] - wh - 1 : wh;
150 int flip_w =
params->flip ?
params->wS[1] - ww - 1 : ww;
152 int weight_h = flip_h *
params->kdil[0];
153 int weight_w = flip_w *
params->kdil[1];
156 for (
short i = 0, is = 0; i <
n_rows; ++i, is +=
TROWS) {
159 int ih =
read_ih[i] + weight_h;
160 int iw =
read_iw[i] + weight_w;
163 if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) &&
164 (iw >= 0 && iw < params->iS[1])) {
165 const device T* curr_src =
src[i] + weight_h *
params->in_strides[1] +
166 weight_w *
params->in_strides[2];
169 for (
short j = 0; j < n_channels; ++j) {
174 for (
short j = n_channels; j <
vec_size; ++j) {
182 for (
short j = 0; j <
vec_size; ++j) {
202 short tgp_padding = 0>
240 const device T* src_,
245 uint simd_group_id [[simdgroup_index_in_threadgroup]],
246 uint simd_lane_id [[thread_index_in_simdgroup]])
247 :
src_ld(params_->wt_strides[0]),
248 thread_idx(simd_group_id * 32 + simd_lane_id),
267 for (
short j = 0; j <
vec_size; j++) {
281 for (
short j = 0; j < n_channels; j++) {
286 for (
short j = n_channels; j <
vec_size; j++) {
294 for (
short j = 0; j < n_channels; j++) {
299 for (
short j = n_channels; j <
vec_size; j++) {
304 for (
short j = 0; j <
vec_size; j++) {
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
#define STEEL_CONST
Definition defines.h:3
STEEL_CONST short n_channels
Definition loader_channel_n.h:25
STEEL_CONST short vec_size
Definition loader_channel_n.h:26
STEEL_CONST short excess
Definition loader_channel_n.h:27
STEEL_CONST short n_channels
Definition loader_channel_n.h:32
STEEL_CONST short excess
Definition loader_channel_n.h:34
STEEL_CONST short vec_size
Definition loader_channel_n.h:33
STEEL_CONST short n_channels
Definition loader_channel_n.h:39
STEEL_CONST short vec_size
Definition loader_channel_n.h:40
STEEL_CONST short excess
Definition loader_channel_n.h:41
STEEL_CONST short n_channels
Definition loader_channel_n.h:46
STEEL_CONST short excess
Definition loader_channel_n.h:48
STEEL_CONST short vec_size
Definition loader_channel_n.h:47
Definition loader_channel_n.h:17
STEEL_CONST short vec_size
Definition loader_channel_n.h:19
STEEL_CONST short n_channels
Definition loader_channel_n.h:18
STEEL_CONST short excess
Definition loader_channel_n.h:20
STEEL_CONST short vec_size
Definition loader_channel_n.h:210
METAL_FUNC void load_unsafe() const
Definition loader_channel_n.h:259
threadgroup T * dst
Definition loader_channel_n.h:228
METAL_FUNC void next()
Definition loader_channel_n.h:313
int weight_hw
Definition loader_channel_n.h:233
STEEL_CONST short TROWS
Definition loader_channel_n.h:214
const bool do_read
Definition loader_channel_n.h:236
const device T * src
Definition loader_channel_n.h:229
STEEL_CONST short BCOLS
Definition loader_channel_n.h:206
const int read_n
Definition loader_channel_n.h:235
const int src_ld
Definition loader_channel_n.h:220
STEEL_CONST short dst_ld
Definition loader_channel_n.h:209
const short thread_idx
Definition loader_channel_n.h:223
STEEL_CONST short BROWS
Definition loader_channel_n.h:205
STEEL_CONST short TCOLS
Definition loader_channel_n.h:213
const short bj
Definition loader_channel_n.h:225
METAL_FUNC Conv2DWeightBlockLoaderSmallChannels(const device T *src_, threadgroup T *dst_, const int2 offsets, const constant MLXConvParams< 2 > *params_, const constant ImplicitGemmConv2DParams *gemm_params_, uint simd_group_id, uint simd_lane_id)
Definition loader_channel_n.h:239
const short bi
Definition loader_channel_n.h:224
STEEL_CONST short n_rows
Definition loader_channel_n.h:217
const constant MLXConvParams< 2 > * params
Definition loader_channel_n.h:231