mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 09:51:17 +08:00
Remove masks from BlockLoader and clear out load case for invalid thread (#634)
This commit is contained in:
parent
d40a04f8dc
commit
316ff490b3
@ -89,20 +89,9 @@ struct GEMMKernel {
|
|||||||
// Appease the compiler
|
// Appease the compiler
|
||||||
(void)l;
|
(void)l;
|
||||||
|
|
||||||
thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
|
short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
|
||||||
thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];
|
|
||||||
|
|
||||||
if (!M_aligned) {
|
short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
|
||||||
short2 tile_dims_A =
|
|
||||||
transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
|
|
||||||
loader_a.set_mask(tile_dims_A, mask_A);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!N_aligned) {
|
|
||||||
short2 tile_dims_B =
|
|
||||||
transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
|
|
||||||
loader_b.set_mask(tile_dims_B, mask_B);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
@ -110,13 +99,13 @@ struct GEMMKernel {
|
|||||||
if (M_aligned) {
|
if (M_aligned) {
|
||||||
loader_a.load_unsafe();
|
loader_a.load_unsafe();
|
||||||
} else {
|
} else {
|
||||||
loader_a.load_safe(mask_A);
|
loader_a.load_safe(tile_dims_A);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (N_aligned) {
|
if (N_aligned) {
|
||||||
loader_b.load_unsafe();
|
loader_b.load_unsafe();
|
||||||
} else {
|
} else {
|
||||||
loader_b.load_safe(mask_B);
|
loader_b.load_safe(tile_dims_B);
|
||||||
}
|
}
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
@ -137,11 +126,8 @@ struct GEMMKernel {
|
|||||||
short2 tile_dims_B_last =
|
short2 tile_dims_B_last =
|
||||||
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
|
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
|
||||||
|
|
||||||
loader_a.set_mask(tile_dims_A_last, mask_A);
|
loader_a.load_safe(tile_dims_A_last);
|
||||||
loader_b.set_mask(tile_dims_B_last, mask_B);
|
loader_b.load_safe(tile_dims_B_last);
|
||||||
|
|
||||||
loader_a.load_safe(mask_A);
|
|
||||||
loader_b.load_safe(mask_B);
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
@ -218,14 +204,8 @@ struct GEMMKernel {
|
|||||||
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
|
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
|
||||||
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
|
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
|
||||||
|
|
||||||
thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
|
loader_a.load_safe(tile_dims_A);
|
||||||
thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];
|
loader_b.load_safe(tile_dims_B);
|
||||||
|
|
||||||
loader_a.set_mask(tile_dims_A, mask_A);
|
|
||||||
loader_b.set_mask(tile_dims_B, mask_B);
|
|
||||||
|
|
||||||
loader_a.load_safe(mask_A);
|
|
||||||
loader_b.load_safe(mask_B);
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
@ -112,14 +112,8 @@ template <typename T,
|
|||||||
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
|
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
|
||||||
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
|
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
|
||||||
|
|
||||||
thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
|
loader_a.load_safe(tile_dims_A);
|
||||||
thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];
|
loader_b.load_safe(tile_dims_B);
|
||||||
|
|
||||||
loader_a.set_mask(tile_dims_A, mask_A);
|
|
||||||
loader_b.set_mask(tile_dims_B, mask_B);
|
|
||||||
|
|
||||||
loader_a.load_safe(mask_A);
|
|
||||||
loader_b.load_safe(mask_B);
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
@ -67,24 +67,22 @@ struct BlockLoader {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Load from device memory into threadgroup memory - without bound checking */
|
|
||||||
METAL_FUNC void set_mask(
|
|
||||||
thread const short2& src_tile_dims,
|
|
||||||
thread bool mask[n_rows][vec_size]) {
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short i = 0; i < n_rows; i++) {
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short j = 0; j < vec_size; j++) {
|
|
||||||
mask[i][j] =
|
|
||||||
((bi + i) < src_tile_dims.y) && ((bj + j) < src_tile_dims.x);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Load from device memory into threadgroup memory - with bound checking */
|
/* Load from device memory into threadgroup memory - with bound checking */
|
||||||
METAL_FUNC void load_safe(short2 src_tile_dim) const {
|
METAL_FUNC void load_safe(short2 src_tile_dim) const {
|
||||||
src_tile_dim = src_tile_dim - short2(bj, bi);
|
src_tile_dim = src_tile_dim - short2(bj, bi);
|
||||||
|
|
||||||
|
// Skip loading if thread has no valid reads
|
||||||
|
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < BROWS; i += TROWS) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short j = 0; j < vec_size; j++) {
|
||||||
|
dst[i * dst_ld + j] = T(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Use fast thread memory for bound checks
|
// Use fast thread memory for bound checks
|
||||||
bool tmp_idx[vec_size];
|
bool tmp_idx[vec_size];
|
||||||
T tmp_val[vec_size];
|
T tmp_val[vec_size];
|
||||||
@ -117,39 +115,6 @@ struct BlockLoader {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Load from device memory into threadgroup memory - with bound checking */
|
|
||||||
METAL_FUNC void load_safe(const thread bool mask[n_rows][vec_size]) const {
|
|
||||||
T tmp_val[vec_size];
|
|
||||||
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short i = 0, ii = 0; i < BROWS; i += TROWS, ii++) {
|
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
|
||||||
// Use fast thread memory for bound checks
|
|
||||||
|
|
||||||
// Read valid indices into tmp_val
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short j = 0; j < vec_size; j++) {
|
|
||||||
tmp_val[j] = src[(mask[ii][j] ? i * src_ld + j : 0)];
|
|
||||||
}
|
|
||||||
|
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
|
||||||
|
|
||||||
// Zero out uneeded values
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short j = 0; j < vec_size; j++) {
|
|
||||||
tmp_val[j] = mask[ii][j] ? tmp_val[j] : T(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
|
||||||
|
|
||||||
// Copy values to threadgroup memory
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short j = 0; j < vec_size; j++) {
|
|
||||||
dst[i * dst_ld + j] = tmp_val[j];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Iteration helper */
|
/* Iteration helper */
|
||||||
METAL_FUNC void next() {
|
METAL_FUNC void next() {
|
||||||
src += tile_stride;
|
src += tile_stride;
|
||||||
|
Loading…
Reference in New Issue
Block a user