22 short tgp_padding = 0>
66 uint simd_group_id [[simdgroup_index_in_threadgroup]],
67 uint simd_lane_id [[thread_index_in_simdgroup]])
68 :
thread_idx(simd_group_id * 32 + simd_lane_id),
79 for (
short i = 0; i <
n_rows; ++i) {
80 int offset_nhw = offsets.y +
bi + i *
TROWS;
81 int n = offset_nhw / out_n_pixels;
82 int hw = offset_nhw % out_n_pixels;
83 int oh = hw /
params->oS[1];
84 int ow = hw %
params->oS[1];
100 src[i] = src_ + n *
params->in_strides[0] + ih *
params->in_strides[1] +
108 for (
short i = 0, is = 0; i <
n_rows; ++i, is +=
TROWS) {
115 if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) &&
116 (iw >= 0 && iw < params->iS[1])) {
118 for (
short j = 0; j <
vec_size; ++j) {
126 for (
short j = 0; j <
vec_size; ++j) {
137 for (
short i = 0; i <
n_rows; i++) {
148 for (
short i = 0; i <
n_rows; i++) {
158 for (
short i = 0; i <
n_rows; i++) {
170 short tgp_padding = 0>
210 const device T* src_,
215 uint simd_group_id [[simdgroup_index_in_threadgroup]],
216 uint simd_lane_id [[thread_index_in_simdgroup]])
217 :
thread_idx(simd_group_id * 32 + simd_lane_id),
232 for (
short i = 0; i <
n_rows; ++i) {
233 int offset_nhw = offsets.y +
bi + i *
TROWS;
234 int n = offset_nhw / out_n_pixels;
235 int hw = offset_nhw % out_n_pixels;
236 int oh = hw /
params->oS[1];
237 int ow = hw %
params->oS[1];
253 src[i] = src_ + n *
params->in_strides[0] + ih *
params->in_strides[1] +
258 for (
short i = 0; i <
n_rows; ++i) {
263 for (
short kh = 0; kh <
params->wS[0]; kh++) {
264 short flip_h =
params->flip ?
params->wS[0] - kh - 1 : kh;
266 for (
short i = 0; i <
n_rows; ++i) {
268 int ih = read_ih[i] + flip_h *
params->kdil[0];
270 bool in_bounds = n <
params->N && ih >= 0 && ih <
params->iS[0];
272 mask_h[i] |= (in_bounds << kh);
276 for (
short kw = 0; kw <
params->wS[1]; kw++) {
277 short flip_w =
params->flip ?
params->wS[1] - kw - 1 : kw;
279 for (
short i = 0; i <
n_rows; ++i) {
280 int iw = read_iw[i] + flip_w *
params->kdil[1];
282 bool in_bounds = iw >= 0 && iw <
params->iS[1];
284 mask_w[i] |= (in_bounds << kw);
295 for (
short i = 0, is = 0; i <
n_rows; ++i, is +=
TROWS) {
299 for (
short j = 0; j <
vec_size; ++j) {
307 for (
short j = 0; j <
vec_size; ++j) {
318 for (
short i = 0; i <
n_rows; i++) {
329 for (
short i = 0; i <
n_rows; i++) {
339 for (
short i = 0; i <
n_rows; i++) {
351 short tgp_padding = 0>
360 (BN == 8) ? 1 : (tgp_size / (
BROWS *
BCOLS) >= 8 ? 8 : 4);
390 const device T* src_,
395 uint simd_group_id [[simdgroup_index_in_threadgroup]],
396 uint simd_lane_id [[thread_index_in_simdgroup]])
397 :
src_ld(params_->wt_strides[0]),
398 thread_idx(simd_group_id * 32 + simd_lane_id),
412 for (
short i = 0; i < BN; i +=
TROWS) {
414 for (
short j = 0; j <
vec_size; j++) {
419 for (
short i = 0; i < BN; i +=
TROWS) {
422 for (
short j = 0; j <
vec_size; j++) {
427 for (
short j = 0; j <
vec_size; j++) {
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
#define STEEL_CONST
Definition defines.h:3
STEEL_CONST short dst_ld
Definition loader_channel_l.h:358
STEEL_CONST short vec_size
Definition loader_channel_l.h:359
const bool do_read
Definition loader_channel_l.h:386
const constant MLXConvParams< 2 > * params
Definition loader_channel_l.h:381
STEEL_CONST short n_rows
Definition loader_channel_l.h:367
const int read_n
Definition loader_channel_l.h:385
METAL_FUNC void load_unsafe() const
Definition loader_channel_l.h:409
const short bj
Definition loader_channel_l.h:375
const int src_ld
Definition loader_channel_l.h:370
const device T * src
Definition loader_channel_l.h:379
STEEL_CONST short TCOLS
Definition loader_channel_l.h:363
STEEL_CONST short BCOLS
Definition loader_channel_l.h:355
const short bi
Definition loader_channel_l.h:374
STEEL_CONST short TROWS
Definition loader_channel_l.h:364
METAL_FUNC Conv2DWeightBlockLoader(const device T *src_, threadgroup T *dst_, const int2 offsets, const constant MLXConvParams< 2 > *params_, const constant ImplicitGemmConv2DParams *gemm_params_, uint simd_group_id, uint simd_lane_id)
Definition loader_channel_l.h:389
METAL_FUNC void next()
Definition loader_channel_l.h:436
const short thread_idx
Definition loader_channel_l.h:373
int weight_hw
Definition loader_channel_l.h:383
STEEL_CONST short BROWS
Definition loader_channel_l.h:354
threadgroup T * dst
Definition loader_channel_l.h:378