Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-25 01:41:17 +08:00)
Remove masks from BlockLoader and clear out load case for invalid thread (#634)

parent d40a04f8dc
commit 316ff490b3
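In brief: call sites previously allocated per-thread boolean masks, filled them with set_mask, and handed them to load_safe; after this change they pass the tile dimensions directly and load_safe performs the bound checks itself. Threads with no valid reads now zero-fill their slice of the threadgroup tile instead of returning early. A minimal sketch of the call-site change, using identifiers from the diff below (surrounding setup elided):

    // Before: precompute a mask in thread memory, then load with it.
    thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
    loader_a.set_mask(tile_dims_A, mask_A);
    loader_a.load_safe(mask_A);

    // After: pass the valid tile extent; bounds are checked inside.
    loader_a.load_safe(tile_dims_A);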
@@ -89,20 +89,9 @@ struct GEMMKernel {
     // Appease the compiler
     (void)l;
 
-    thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
-    thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];
-
-    if (!M_aligned) {
-      short2 tile_dims_A =
-          transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
-      loader_a.set_mask(tile_dims_A, mask_A);
-    }
-
-    if (!N_aligned) {
-      short2 tile_dims_B =
-          transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
-      loader_b.set_mask(tile_dims_B, mask_B);
-    }
+    short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
+    short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
 
     for (int k = 0; k < gemm_k_iterations; k++) {
       threadgroup_barrier(mem_flags::mem_threadgroup);
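A note on orientation, inferred from the set_mask removal further down: short2(a, b) packs a into .x and b into .y, and the loader compares column offsets against .x and row offsets against .y, so the transpose_a ternary simply swaps which of tgp_bm and BK plays columns versus rows. A host-side C++ stand-in (short2 redefined locally; the block sizes are arbitrary examples, not from the diff):

    #include <cstdio>

    struct short2 { short x, y; }; // stand-in for Metal's short2

    int main() {
      bool transpose_a = true;
      short tgp_bm = 48, BK = 16; // example block sizes
      // Mirrors the added line: x = valid columns, y = valid rows of tile A.
      short2 tile_dims_A = transpose_a ? short2{tgp_bm, BK} : short2{BK, tgp_bm};
      std::printf("cols=%d rows=%d\n", tile_dims_A.x, tile_dims_A.y);
      return 0;
    }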
@@ -110,13 +99,13 @@ struct GEMMKernel {
       if (M_aligned) {
         loader_a.load_unsafe();
       } else {
-        loader_a.load_safe(mask_A);
+        loader_a.load_safe(tile_dims_A);
       }
 
       if (N_aligned) {
         loader_b.load_unsafe();
       } else {
-        loader_b.load_safe(mask_B);
+        loader_b.load_safe(tile_dims_B);
       }
 
       threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -137,11 +126,8 @@ struct GEMMKernel {
       short2 tile_dims_B_last =
           transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
 
-      loader_a.set_mask(tile_dims_A_last, mask_A);
-      loader_b.set_mask(tile_dims_B_last, mask_B);
-
-      loader_a.load_safe(mask_A);
-      loader_b.load_safe(mask_B);
+      loader_a.load_safe(tile_dims_A_last);
+      loader_b.load_safe(tile_dims_B_last);
 
       threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -218,14 +204,8 @@ struct GEMMKernel {
     short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
     short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
 
-    thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
-    thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];
-
-    loader_a.set_mask(tile_dims_A, mask_A);
-    loader_b.set_mask(tile_dims_B, mask_B);
-
-    loader_a.load_safe(mask_A);
-    loader_b.load_safe(mask_B);
+    loader_a.load_safe(tile_dims_A);
+    loader_b.load_safe(tile_dims_B);
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -112,14 +112,8 @@ template <typename T,
     short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
     short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
 
-    thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
-    thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];
-
-    loader_a.set_mask(tile_dims_A, mask_A);
-    loader_b.set_mask(tile_dims_B, mask_B);
-
-    loader_a.load_safe(mask_A);
-    loader_b.load_safe(mask_B);
+    loader_a.load_safe(tile_dims_A);
+    loader_b.load_safe(tile_dims_B);
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
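The same simplification is applied verbatim in a second kernel (the hunk against "template <typename T," above): the remainder tile along the K dimension is loaded with a single load_safe call. lbk is not defined in these hunks; a plausible reading, for illustration only, is the leftover K extent after the full BK-sized iterations:

    #include <cstdio>

    int main() {
      // Assumption, not from the diff: lbk = K left over after whole BK blocks.
      int K = 70, BK = 32;
      int gemm_k_iterations = K / BK;       // 2 full iterations
      int lbk = K - gemm_k_iterations * BK; // 6 leftover elements along K
      std::printf("iterations=%d lbk=%d\n", gemm_k_iterations, lbk);
      return 0;
    }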
@@ -67,24 +67,22 @@ struct BlockLoader {
     }
   }
 
-  /* Load from device memory into threadgroup memory - without bound checking */
-  METAL_FUNC void set_mask(
-      thread const short2& src_tile_dims,
-      thread bool mask[n_rows][vec_size]) {
-    STEEL_PRAGMA_UNROLL
-    for (short i = 0; i < n_rows; i++) {
-      STEEL_PRAGMA_UNROLL
-      for (short j = 0; j < vec_size; j++) {
-        mask[i][j] =
-            ((bi + i) < src_tile_dims.y) && ((bj + j) < src_tile_dims.x);
-      }
-    }
-  }
-
   /* Load from device memory into threadgroup memory - with bound checking */
   METAL_FUNC void load_safe(short2 src_tile_dim) const {
     src_tile_dim = src_tile_dim - short2(bj, bi);
 
     // Skip loading if thread has no valid reads
     if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < BROWS; i += TROWS) {
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < vec_size; j++) {
+          dst[i * dst_ld + j] = T(0);
+        }
+      }
       return;
     }
 
     // Use fast thread memory for bound checks
     bool tmp_idx[vec_size];
     T tmp_val[vec_size];
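The zero-fill above is the "clear out load case for invalid thread" part of the commit title: a thread whose offsets (bj, bi) fall entirely outside the valid region used to return immediately, leaving whatever the threadgroup tile held from a previous iteration; the apparent motivation is that the matmul consumes the whole tile, so stale values would leak into the accumulation. A host-side C++ model of the new early-out path (container and parameter names are illustrative):

    #include <vector>

    // Model of the new behavior: a fully out-of-bounds thread writes zeros
    // over its destination rows instead of skipping the store.
    template <typename T>
    void early_out(std::vector<T>& dst, int dst_ld, int BROWS, int TROWS,
                   int vec_size, int tile_x, int tile_y, int bi, int bj) {
      if (tile_x - bj <= 0 || tile_y - bi <= 0) {
        for (int i = 0; i < BROWS; i += TROWS)
          for (int j = 0; j < vec_size; j++)
            dst[i * dst_ld + j] = T(0); // clear, don't leave stale data
      }
    }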
@@ -117,39 +115,6 @@ struct BlockLoader {
     }
   }
 
-  /* Load from device memory into threadgroup memory - with bound checking */
-  METAL_FUNC void load_safe(const thread bool mask[n_rows][vec_size]) const {
-    T tmp_val[vec_size];
-
-    STEEL_PRAGMA_UNROLL
-    for (short i = 0, ii = 0; i < BROWS; i += TROWS, ii++) {
-      simdgroup_barrier(mem_flags::mem_none);
-      // Use fast thread memory for bound checks
-
-      // Read valid indices into tmp_val
-      STEEL_PRAGMA_UNROLL
-      for (short j = 0; j < vec_size; j++) {
-        tmp_val[j] = src[(mask[ii][j] ? i * src_ld + j : 0)];
-      }
-
-      simdgroup_barrier(mem_flags::mem_none);
-
-      // Zero out uneeded values
-      STEEL_PRAGMA_UNROLL
-      for (short j = 0; j < vec_size; j++) {
-        tmp_val[j] = mask[ii][j] ? tmp_val[j] : T(0);
-      }
-
-      simdgroup_barrier(mem_flags::mem_none);
-
-      // Copy values to threadgroup memory
-      STEEL_PRAGMA_UNROLL
-      for (short j = 0; j < vec_size; j++) {
-        dst[i * dst_ld + j] = tmp_val[j];
-      }
-    }
-  }
-
   /* Iteration helper */
   METAL_FUNC void next() {
     src += tile_stride;
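With the mask-based load_safe gone (along with its per-row simdgroup_barriers and the mask arrays in thread memory), the bound check lives only in load_safe(short2), which rebases the tile extent by the thread's offset (src_tile_dim - short2(bj, bi)) before comparing. The two formulations are equivalent: the removed mask stored (bi + i) < dims.y && (bj + j) < dims.x, while the surviving path tests i < dims.y - bi && j < dims.x - bj. A tiny C++ check of that equivalence over a small grid (extents are example values):

    #include <cassert>

    int main() {
      const int dims_x = 6, dims_y = 5; // example valid extent
      for (int bi = 0; bi < 8; bi++)
        for (int bj = 0; bj < 8; bj++)
          for (int i = 0; i < 4; i++)
            for (int j = 0; j < 4; j++) {
              bool masked = (bi + i) < dims_y && (bj + j) < dims_x;
              bool rebased = i < (dims_y - bi) && j < (dims_x - bj);
              assert(masked == rebased);
            }
      return 0;
    }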