diff --git a/mlx/backend/metal/kernels/steel/gemm/gemm.h b/mlx/backend/metal/kernels/steel/gemm/gemm.h index 3a8f0280c..be70bcacb 100644 --- a/mlx/backend/metal/kernels/steel/gemm/gemm.h +++ b/mlx/backend/metal/kernels/steel/gemm/gemm.h @@ -89,20 +89,9 @@ struct GEMMKernel { // Appease the compiler (void)l; - thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size]; - thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size]; + short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); - if (!M_aligned) { - short2 tile_dims_A = - transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); - loader_a.set_mask(tile_dims_A, mask_A); - } - - if (!N_aligned) { - short2 tile_dims_B = - transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); - loader_b.set_mask(tile_dims_B, mask_B); - } + short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); for (int k = 0; k < gemm_k_iterations; k++) { threadgroup_barrier(mem_flags::mem_threadgroup); @@ -110,13 +99,13 @@ struct GEMMKernel { if (M_aligned) { loader_a.load_unsafe(); } else { - loader_a.load_safe(mask_A); + loader_a.load_safe(tile_dims_A); } if (N_aligned) { loader_b.load_unsafe(); } else { - loader_b.load_safe(mask_B); + loader_b.load_safe(tile_dims_B); } threadgroup_barrier(mem_flags::mem_threadgroup); @@ -137,11 +126,8 @@ struct GEMMKernel { short2 tile_dims_B_last = transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk); - loader_a.set_mask(tile_dims_A_last, mask_A); - loader_b.set_mask(tile_dims_B_last, mask_B); - - loader_a.load_safe(mask_A); - loader_b.load_safe(mask_B); + loader_a.load_safe(tile_dims_A_last); + loader_b.load_safe(tile_dims_B_last); threadgroup_barrier(mem_flags::mem_threadgroup); @@ -218,14 +204,8 @@ struct GEMMKernel { short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk); - thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size]; - thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size]; - - loader_a.set_mask(tile_dims_A, mask_A); - loader_b.set_mask(tile_dims_B, mask_B); - - loader_a.load_safe(mask_A); - loader_b.load_safe(mask_B); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); threadgroup_barrier(mem_flags::mem_threadgroup); diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_addmm.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_addmm.metal index a6889e3c8..b8e131f0e 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_addmm.metal +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_addmm.metal @@ -112,14 +112,8 @@ template