16template <
short n_channels_>
58 short tgp_padding = 0>
101 uint simd_group_id [[simdgroup_index_in_threadgroup]],
102 uint simd_lane_id [[thread_index_in_simdgroup]])
103 :
thread_idx(simd_group_id * 32 + simd_lane_id),
113 for (
short i = 0; i <
n_rows; ++i) {
114 int offset_nhw = offsets.y +
bi + i *
TROWS;
115 int n = offset_nhw / out_n_pixels;
116 int hw = offset_nhw % out_n_pixels;
139 for (
short j = 0; j <
vec_size; j++) {
140 dst[i *
dst_ld + j] = T(0);
156 for (
short i = 0, is = 0; i <
n_rows; ++i, is +=
TROWS) {
159 int ih =
read_ih[i] + weight_h;
160 int iw =
read_iw[i] + weight_w;
163 if ((n < params->N) && (ih >= 0 && ih <
params->
iS[0]) &&
164 (iw >= 0 && iw < params->iS[1])) {
169 for (
short j = 0; j < n_channels; ++j) {
170 dst[is *
dst_ld + j] = curr_src[j];
174 for (
short j = n_channels; j <
vec_size; ++j) {
175 dst[is *
dst_ld + j] = T(0);
182 for (
short j = 0; j <
vec_size; ++j) {
183 dst[is *
dst_ld + j] = T(0);
202 short tgp_padding = 0>
240 const device T* src_,
245 uint simd_group_id [[simdgroup_index_in_threadgroup]],
246 uint simd_lane_id [[thread_index_in_simdgroup]])
247 :
src_ld(params_->wt_strides[0]),
248 thread_idx(simd_group_id * 32 + simd_lane_id),
267 for (
short j = 0; j <
vec_size; j++) {
268 dst[i *
dst_ld + j] = T(0);
281 for (
short j = 0; j < n_channels; j++) {
286 for (
short j = n_channels; j <
vec_size; j++) {
287 dst[i *
dst_ld + j] = T(0);
294 for (
short j = 0; j < n_channels; j++) {
299 for (
short j = n_channels; j <
vec_size; j++) {
300 dst[i *
dst_ld + j] = T(0);
304 for (
short j = 0; j <
vec_size; j++) {
305 dst[i *
dst_ld + j] = T(0);
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
#define STEEL_CONST
Definition defines.h:3
const int oS[NDIM]
Definition params.h:12
const int iS[NDIM]
Definition params.h:10
const int kdil[NDIM]
Definition params.h:15
const int str[NDIM]
Definition params.h:13
const size_t wt_strides[NDIM+2]
Definition params.h:18
const bool flip
Definition params.h:21
const size_t in_strides[NDIM+2]
Definition params.h:17
const int wS[NDIM]
Definition params.h:11
const int O
Definition params.h:9
const int pad[NDIM]
Definition params.h:14
Definition loader_channel_n.h:17
STEEL_CONST short vec_size
Definition loader_channel_n.h:19
STEEL_CONST short n_channels
Definition loader_channel_n.h:18
STEEL_CONST short excess
Definition loader_channel_n.h:20
Definition loader_channel_n.h:203
STEEL_CONST short vec_size
Definition loader_channel_n.h:210
METAL_FUNC void load_unsafe() const
Definition loader_channel_n.h:259
threadgroup T * dst
Definition loader_channel_n.h:228
METAL_FUNC void next()
Definition loader_channel_n.h:313
int weight_hw
Definition loader_channel_n.h:233
STEEL_CONST short TROWS
Definition loader_channel_n.h:214
const bool do_read
Definition loader_channel_n.h:236
const device T * src
Definition loader_channel_n.h:229
STEEL_CONST short BCOLS
Definition loader_channel_n.h:206
const int read_n
Definition loader_channel_n.h:235
const int src_ld
Definition loader_channel_n.h:220
STEEL_CONST short dst_ld
Definition loader_channel_n.h:209
const short thread_idx
Definition loader_channel_n.h:223
STEEL_CONST short BROWS
Definition loader_channel_n.h:205
STEEL_CONST short TCOLS
Definition loader_channel_n.h:213
const short bj
Definition loader_channel_n.h:225
METAL_FUNC Conv2DWeightBlockLoaderSmallChannels(const device T *src_, threadgroup T *dst_, const int2 offsets, const constant MLXConvParams< 2 > *params_, const constant ImplicitGemmConv2DParams *gemm_params_, uint simd_group_id, uint simd_lane_id)
Definition loader_channel_n.h:239
const short bi
Definition loader_channel_n.h:224
STEEL_CONST short n_rows
Definition loader_channel_n.h:217
const constant MLXConvParams< 2 > * params
Definition loader_channel_n.h:231