20 short tgp_padding = 0>
69 uint simd_group_id [[simdgroup_index_in_threadgroup]],
70 uint simd_lane_id [[thread_index_in_simdgroup]])
71 :
thread_idx(simd_group_id * 32 + simd_lane_id),
82 for (
short i = 0; i <
n_rows; ++i) {
83 int offset_nhw = offsets.y +
bi + i *
TROWS;
106 for (
short i = 0, is = 0; i <
n_rows; ++i, is +=
TROWS) {
122 if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) &&
123 (iw_dil >= 0 && iw < params->iS[1])) {
125 for (
short j = 0; j <
vec_size; ++j) {
126 dst[is *
dst_ld + j] = (
src[i])[offset + j];
133 for (
short j = 0; j <
vec_size; ++j) {
134 dst[is *
dst_ld + j] = T(0);
143 if (weight_w < params->wS[1]) {
150 if (weight_h < params->wS[0]) {
157 for (
short i = 0; i <
n_rows; i++) {
169 short tgp_padding = 0>
178 (BN == 8) ? 1 : (tgp_size / (
BROWS *
BCOLS) >= 8 ? 8 : 4);
212 const device T* src_,
217 const short base_wh_,
218 const short base_ww_,
219 uint simd_group_id [[simdgroup_index_in_threadgroup]],
220 uint simd_lane_id [[thread_index_in_simdgroup]])
221 :
src_ld(params_->wt_strides[0]),
222 thread_idx(simd_group_id * 32 + simd_lane_id),
242 for (
short i = 0; i < BN; i +=
TROWS) {
244 for (
short j = 0; j <
vec_size; j++) {
249 for (
short i = 0; i < BN; i +=
TROWS) {
252 for (
short j = 0; j <
vec_size; j++) {
257 for (
short j = 0; j <
vec_size; j++) {
258 dst[i *
dst_ld + j] = T(0);
268 if (weight_w < params->wS[1]) {
275 if (weight_h < params->wS[0]) {
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
#define STEEL_CONST
Definition defines.h:3
const int kdil[NDIM]
Definition params.h:15
const int str[NDIM]
Definition params.h:13
const size_t wt_strides[NDIM+2]
Definition params.h:18
const bool flip
Definition params.h:21
const size_t in_strides[NDIM+2]
Definition params.h:17
const int wS[NDIM]
Definition params.h:11
const int O
Definition params.h:9
const int pad[NDIM]
Definition params.h:14
const int idil[NDIM]
Definition params.h:16
const int f_out_jump_w
Definition params.h:48
const int f_wgt_jump_h
Definition params.h:44
const int f_wgt_jump_w
Definition params.h:45
const int f_out_jump_h
Definition params.h:47
const int adj_out_w
Definition params.h:51
const int adj_out_hw
Definition params.h:52
Definition loader_general.h:170
STEEL_CONST short BROWS
Definition loader_general.h:172
const short thread_idx
Definition loader_general.h:191
STEEL_CONST short vec_size
Definition loader_general.h:177
METAL_FUNC void next()
Definition loader_general.h:266
STEEL_CONST short BCOLS
Definition loader_general.h:173
const int start_row
Definition loader_general.h:208
const short base_ww
Definition loader_general.h:203
const short bi
Definition loader_general.h:192
const device T * src
Definition loader_general.h:197
short weight_h
Definition loader_general.h:205
const int src_ld
Definition loader_general.h:188
const short base_wh
Definition loader_general.h:202
short weight_w
Definition loader_general.h:206
threadgroup T * dst
Definition loader_general.h:196
METAL_FUNC void load_unsafe() const
Definition loader_general.h:236
const constant Conv2DGeneralJumpParams * jump_params
Definition loader_general.h:200
STEEL_CONST short dst_ld
Definition loader_general.h:176
STEEL_CONST short n_rows
Definition loader_general.h:185
STEEL_CONST short TROWS
Definition loader_general.h:182
const short bj
Definition loader_general.h:193
METAL_FUNC Conv2DWeightBlockLoaderGeneral(const device T *src_, threadgroup T *dst_, const int2 offsets, const constant MLXConvParams< 2 > *params_, const constant Conv2DGeneralJumpParams *jump_params_, const short base_wh_, const short base_ww_, uint simd_group_id, uint simd_lane_id)
Definition loader_general.h:211
const constant MLXConvParams< 2 > * params
Definition loader_general.h:199
STEEL_CONST short TCOLS
Definition loader_general.h:181