22 short tgp_padding = 0>
71 uint simd_group_id [[simdgroup_index_in_threadgroup]],
72 uint simd_lane_id [[thread_index_in_simdgroup]])
73 :
thread_idx(simd_group_id * 32 + simd_lane_id),
84 for (
short i = 0; i <
n_rows; ++i) {
85 int offset_nhw = offsets.y +
bi + i *
TROWS;
108 for (
short i = 0, is = 0; i <
n_rows; ++i, is +=
TROWS) {
124 if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) &&
125 (iw_dil >= 0 && iw < params->iS[1])) {
127 for (
short j = 0; j <
vec_size; ++j) {
128 dst[is *
dst_ld + j] = (
src[i])[offset + j];
135 for (
short j = 0; j <
vec_size; ++j) {
136 dst[is *
dst_ld + j] = T(0);
145 if (weight_w < params->wS[1]) {
152 if (weight_h < params->wS[0]) {
159 for (
short i = 0; i <
n_rows; i++) {
171 short tgp_padding = 0>
180 (BN == 8) ? 1 : (tgp_size / (
BROWS *
BCOLS) >= 8 ? 8 : 4);
214 const device T* src_,
219 const short base_wh_,
220 const short base_ww_,
221 uint simd_group_id [[simdgroup_index_in_threadgroup]],
222 uint simd_lane_id [[thread_index_in_simdgroup]])
223 :
src_ld(params_ -> wt_strides[0]),
224 thread_idx(simd_group_id * 32 + simd_lane_id),
244 for (
short i = 0; i < BN; i +=
TROWS) {
246 for (
short j = 0; j <
vec_size; j++) {
251 for (
short i = 0; i < BN; i +=
TROWS) {
254 for (
short j = 0; j <
vec_size; j++) {
259 for (
short j = 0; j <
vec_size; j++) {
260 dst[i *
dst_ld + j] = T(0);
270 if (weight_w < params->wS[1]) {
277 if (weight_h < params->wS[0]) {
const int kdil[NDIM]
Definition params.h:15
const int str[NDIM]
Definition params.h:13
const size_t wt_strides[NDIM+2]
Definition params.h:18
const bool flip
Definition params.h:21
const size_t in_strides[NDIM+2]
Definition params.h:17
const int wS[NDIM]
Definition params.h:11
const int O
Definition params.h:9
const int pad[NDIM]
Definition params.h:14
const int idil[NDIM]
Definition params.h:16
const int f_out_jump_w
Definition params.h:48
const int f_wgt_jump_h
Definition params.h:44
const int f_wgt_jump_w
Definition params.h:45
const int f_out_jump_h
Definition params.h:47
const int adj_out_w
Definition params.h:51
const int adj_out_hw
Definition params.h:52
Definition loader_general.h:172
STEEL_CONST short BROWS
Definition loader_general.h:174
const short thread_idx
Definition loader_general.h:193
STEEL_CONST short vec_size
Definition loader_general.h:179
METAL_FUNC void next()
Definition loader_general.h:268
STEEL_CONST short BCOLS
Definition loader_general.h:175
const int start_row
Definition loader_general.h:210
const short base_ww
Definition loader_general.h:205
const short bi
Definition loader_general.h:194
const device T * src
Definition loader_general.h:199
short weight_h
Definition loader_general.h:207
const int src_ld
Definition loader_general.h:190
const short base_wh
Definition loader_general.h:204
short weight_w
Definition loader_general.h:208
threadgroup T * dst
Definition loader_general.h:198
METAL_FUNC void load_unsafe() const
Definition loader_general.h:238
const constant Conv2DGeneralJumpParams * jump_params
Definition loader_general.h:202
STEEL_CONST short dst_ld
Definition loader_general.h:178
STEEL_CONST short n_rows
Definition loader_general.h:187
STEEL_CONST short TROWS
Definition loader_general.h:184
const short bj
Definition loader_general.h:195
METAL_FUNC Conv2DWeightBlockLoaderGeneral(const device T *src_, threadgroup T *dst_, const int2 offsets, const constant MLXConvParams< 2 > *params_, const constant Conv2DGeneralJumpParams *jump_params_, const short base_wh_, const short base_ww_, uint simd_group_id, uint simd_lane_id)
Definition loader_general.h:213
const constant MLXConvParams< 2 > * params
Definition loader_general.h:201
STEEL_CONST short TCOLS
Definition loader_general.h:183