22 short tgp_padding = 0>
66 uint simd_group_id [[simdgroup_index_in_threadgroup]],
67 uint simd_lane_id [[thread_index_in_simdgroup]])
68 :
thread_idx(simd_group_id * 32 + simd_lane_id),
79 for (
short i = 0; i <
n_rows; ++i) {
80 int offset_nhw = offsets.y +
bi + i *
TROWS;
81 int n = offset_nhw / out_n_pixels;
82 int hw = offset_nhw % out_n_pixels;
108 for (
short i = 0, is = 0; i <
n_rows; ++i, is +=
TROWS) {
115 if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) &&
116 (iw >= 0 && iw < params->iS[1])) {
118 for (
short j = 0; j <
vec_size; ++j) {
126 for (
short j = 0; j <
vec_size; ++j) {
127 dst[is *
dst_ld + j] = T(0);
137 for (
short i = 0; i <
n_rows; i++) {
148 for (
short i = 0; i <
n_rows; i++) {
158 for (
short i = 0; i <
n_rows; i++) {
170 short tgp_padding = 0>
210 const device T* src_,
215 uint simd_group_id [[simdgroup_index_in_threadgroup]],
216 uint simd_lane_id [[thread_index_in_simdgroup]])
217 :
thread_idx(simd_group_id * 32 + simd_lane_id),
232 for (
short i = 0; i <
n_rows; ++i) {
233 int offset_nhw = offsets.y +
bi + i *
TROWS;
234 int n = offset_nhw / out_n_pixels;
235 int hw = offset_nhw % out_n_pixels;
258 for (
short i = 0; i <
n_rows; ++i) {
263 for (
short kh = 0; kh <
params->
wS[0]; kh++) {
266 for (
short i = 0; i <
n_rows; ++i) {
268 int ih = read_ih[i] + flip_h *
params->
kdil[0];
272 mask_h[i] |= (in_bounds << kh);
276 for (
short kw = 0; kw <
params->
wS[1]; kw++) {
279 for (
short i = 0; i <
n_rows; ++i) {
280 int iw = read_iw[i] + flip_w *
params->
kdil[1];
282 bool in_bounds = iw >= 0 && iw <
params->
iS[1];
284 mask_w[i] |= (in_bounds << kw);
295 for (
short i = 0, is = 0; i <
n_rows; ++i, is +=
TROWS) {
299 for (
short j = 0; j <
vec_size; ++j) {
307 for (
short j = 0; j <
vec_size; ++j) {
308 dst[is *
dst_ld + j] = T(0);
318 for (
short i = 0; i <
n_rows; i++) {
329 for (
short i = 0; i <
n_rows; i++) {
339 for (
short i = 0; i <
n_rows; i++) {
351 short tgp_padding = 0>
360 (BN == 8) ? 1 : (tgp_size / (
BROWS *
BCOLS) >= 8 ? 8 : 4);
390 const device T* src_,
395 uint simd_group_id [[simdgroup_index_in_threadgroup]],
396 uint simd_lane_id [[thread_index_in_simdgroup]])
397 :
src_ld(params_->wt_strides[0]),
398 thread_idx(simd_group_id * 32 + simd_lane_id),
412 for (
short i = 0; i < BN; i +=
TROWS) {
414 for (
short j = 0; j <
vec_size; j++) {
419 for (
short i = 0; i < BN; i +=
TROWS) {
422 for (
short j = 0; j <
vec_size; j++) {
427 for (
short j = 0; j <
vec_size; j++) {
428 dst[i *
dst_ld + j] = T(0);
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
#define STEEL_CONST
Definition defines.h:3
const int oS[NDIM]
Definition params.h:12
const int iS[NDIM]
Definition params.h:10
const int kdil[NDIM]
Definition params.h:15
const int str[NDIM]
Definition params.h:13
const size_t wt_strides[NDIM+2]
Definition params.h:18
const bool flip
Definition params.h:21
const size_t in_strides[NDIM+2]
Definition params.h:17
const int wS[NDIM]
Definition params.h:11
const int O
Definition params.h:9
const int N
Definition params.h:7
const int pad[NDIM]
Definition params.h:14
Definition loader_channel_l.h:352
STEEL_CONST short dst_ld
Definition loader_channel_l.h:358
STEEL_CONST short vec_size
Definition loader_channel_l.h:359
const bool do_read
Definition loader_channel_l.h:386
const constant MLXConvParams< 2 > * params
Definition loader_channel_l.h:381
STEEL_CONST short n_rows
Definition loader_channel_l.h:367
const int read_n
Definition loader_channel_l.h:385
METAL_FUNC void load_unsafe() const
Definition loader_channel_l.h:409
const short bj
Definition loader_channel_l.h:375
const int src_ld
Definition loader_channel_l.h:370
const device T * src
Definition loader_channel_l.h:379
STEEL_CONST short TCOLS
Definition loader_channel_l.h:363
STEEL_CONST short BCOLS
Definition loader_channel_l.h:355
const short bi
Definition loader_channel_l.h:374
STEEL_CONST short TROWS
Definition loader_channel_l.h:364
METAL_FUNC Conv2DWeightBlockLoader(const device T *src_, threadgroup T *dst_, const int2 offsets, const constant MLXConvParams< 2 > *params_, const constant ImplicitGemmConv2DParams *gemm_params_, uint simd_group_id, uint simd_lane_id)
Definition loader_channel_l.h:389
METAL_FUNC void next()
Definition loader_channel_l.h:436
const short thread_idx
Definition loader_channel_l.h:373
int weight_hw
Definition loader_channel_l.h:383
STEEL_CONST short BROWS
Definition loader_channel_l.h:354
threadgroup T * dst
Definition loader_channel_l.h:378
const int inp_jump_h
Definition params.h:35
const int inp_jump_c
Definition params.h:36
const int inp_jump_w
Definition params.h:34