diff --git a/mlx/backend/common/masked_mm.cpp b/mlx/backend/common/masked_mm.cpp
index 655e260e5..44a471168 100644
--- a/mlx/backend/common/masked_mm.cpp
+++ b/mlx/backend/common/masked_mm.cpp
@@ -17,24 +17,25 @@ namespace mlx::core {
 
 namespace {
 
-template <typename T>
+template <typename T, typename mask_t>
 inline void mask_matrix(
     T* data,
-    const bool* mask,
+    const mask_t* mask,
     int block_size,
     const int X,
     const int Y,
     const size_t X_data_str,
     const size_t Y_data_str,
     const size_t X_mask_str,
-    const size_t Y_mask_str) {
+    const size_t Y_mask_str,
+    const size_t mask_offset) {
   int tX = (X + block_size - 1) / block_size;
   int tY = (Y + block_size - 1) / block_size;
 
   for (int i = 0; i < tX; i++) {
     for (int j = 0; j < tY; j++) {
-      bool do_mask = mask[i * X_mask_str + j * Y_mask_str];
-      if (!do_mask) {
+      mask_t do_mask = mask[mask_offset + i * X_mask_str + j * Y_mask_str];
+      if (do_mask != 1) {
         int loc_x = i * block_size;
         int loc_y = j * block_size;
         T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;
@@ -43,7 +44,11 @@ inline void mask_matrix(
         int size_x = std::min(block_size, X - loc_x);
         int size_y = std::min(block_size, Y - loc_y);
         for (int ii = 0; ii < size_x; ii++) {
           for (int jj = 0; jj < size_y; jj++) {
-            data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
+            if constexpr (std::is_same_v<mask_t, bool>) {
+              data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
+            } else {
+              data_block[ii * X_data_str + jj * Y_data_str] *= do_mask;
+            }
           }
         }
       }
@@ -62,36 +67,39 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
   auto& a_pre = inputs[0];
   auto& b_pre = inputs[1];
-  auto& out_mask = inputs[2];
 
-  auto check_transpose = [](const array& arr, bool do_copy) {
-    auto stx = arr.strides()[arr.ndim() - 2];
-    auto sty = arr.strides()[arr.ndim() - 1];
-    if (stx == arr.shape(-1) && sty == 1) {
-      if (do_copy) {
-        array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-        copy(arr, arr_copy, CopyType::Vector);
-        return std::make_tuple(false, stx, arr_copy);
-      }
-      return std::make_tuple(false, stx, arr);
-    } else if (stx == 1 && sty == arr.shape(-2)) {
-      if (do_copy) {
-        array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-        copy(arr, arr_copy, CopyType::Vector);
-        return std::make_tuple(true, sty, arr_copy);
-      }
-      return std::make_tuple(true, sty, arr);
-    } else {
-      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-      copy(arr, arr_copy, CopyType::General);
-      size_t stx = arr.shape(-1);
-      return std::make_tuple(false, stx, arr_copy);
-    }
-  };
+  auto check_transpose =
+      [](const array& arr, bool do_copy, bool expand_all = false) {
+        auto stx = arr.strides()[arr.ndim() - 2];
+        auto sty = arr.strides()[arr.ndim() - 1];
+        if (!expand_all && stx == arr.shape(-1) && sty == 1) {
+          if (do_copy) {
+            array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+            copy(arr, arr_copy, CopyType::Vector);
+            return std::make_tuple(false, stx, arr_copy);
+          }
+          return std::make_tuple(false, stx, arr);
+        } else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {
+          if (do_copy) {
+            array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+            copy(arr, arr_copy, CopyType::Vector);
+            return std::make_tuple(true, sty, arr_copy);
+          }
+          return std::make_tuple(true, sty, arr);
+        } else {
+          array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+          copy(arr, arr_copy, CopyType::General);
+          size_t stx = arr.shape(-1);
+          return std::make_tuple(false, stx, arr_copy);
+        }
+      };
 
   bool has_op_mask = inputs.size() > 3;
-  auto [a_transposed, lda, a] = check_transpose(a_pre, has_op_mask);
-  auto [b_transposed, ldb, b] = check_transpose(b_pre, has_op_mask);
+  bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;
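[Note] For readers skimming the CPU path above: a rough, illustrative Python sketch of the semantics the updated `mask_matrix` implements. It is not part of the diff; `apply_block_mask` and the NumPy setup are invented for this note. A boolean mask zeroes the blocks whose mask entry is false, while an integer or float mask scales each block by its entry; blocks whose mask value is exactly 1 are left untouched, matching the `do_mask != 1` early-out above.

```python
import numpy as np

def apply_block_mask(x, mask, block_size):
    """Zero or scale block_size x block_size tiles of x according to mask."""
    out = np.array(x, dtype=np.float32, copy=True)
    tm, tn = mask.shape
    for i in range(tm):
        for j in range(tn):
            blk = out[i * block_size : (i + 1) * block_size,
                      j * block_size : (j + 1) * block_size]
            if mask.dtype == np.bool_:
                if not mask[i, j]:
                    blk[...] = 0.0       # boolean mask: drop the block
            elif mask[i, j] != 1:
                blk *= mask[i, j]        # float/int mask: scale the block
    return out
```

The Metal kernel and the Python reference in the tests below follow the same convention.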
+ auto [a_transposed, lda, a] = + check_transpose(a_pre, has_op_mask, inputs.back().dtype() != bool_); + auto [b_transposed, ldb, b] = + check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_); size_t M = a.shape(-2); size_t N = b.shape(-1); @@ -114,27 +122,42 @@ void BlockMaskedMM::eval(const std::vector& inputs, array& out) { int Y, size_t X_data_str, size_t Y_data_str) { - const bool* mask_ptr = mask.data() + - elem_to_loc(mask.shape(-1) * mask.shape(-2) * batch_idx, - mask.shape(), - mask.strides()); + size_t mask_offset = elem_to_loc( + mask.shape(-1) * mask.shape(-2) * batch_idx, + mask.shape(), + mask.strides()); size_t X_mask_str = mask.strides()[mask.ndim() - 2]; size_t Y_mask_str = mask.strides()[mask.ndim() - 1]; - return mask_matrix( - data, - mask_ptr, - block_size, - X, - Y, - X_data_str, - Y_data_str, - X_mask_str, - Y_mask_str); + if (mask.dtype() == bool_) { + return mask_matrix( + data, + mask.data(), + block_size, + X, + Y, + X_data_str, + Y_data_str, + X_mask_str, + Y_mask_str, + mask_offset); + } else { + return mask_matrix( + data, + mask.data(), + block_size, + X, + Y, + X_data_str, + Y_data_str, + X_mask_str, + Y_mask_str, + mask_offset); + } }; - for (int i = 0; i < (a.size() / (M * K)); ++i) { + for (int i = 0; i < (out.size() / (M * size_t(N))); ++i) { // Adjust pointer float* ai = a.data() + elem_to_loc(M * K * i, a.shape(), a.strides()); @@ -144,7 +167,7 @@ void BlockMaskedMM::eval(const std::vector& inputs, array& out) { // Zero out blocks in a and b if needed if (has_op_mask) { - auto& a_mask = inputs[3]; + auto& a_mask = inputs[inputs.size() - 2]; mask_array( a_mask, ai, @@ -155,7 +178,7 @@ void BlockMaskedMM::eval(const std::vector& inputs, array& out) { a_transposed ? 1 : lda, a_transposed ? lda : 1); - auto& b_mask = inputs[4]; + auto& b_mask = inputs[inputs.size() - 1]; mask_array( b_mask, bi, @@ -186,7 +209,9 @@ void BlockMaskedMM::eval(const std::vector& inputs, array& out) { ); // Zero out blocks in out - mask_array(out_mask, ci, block_size_, i, M, N, N, 1); + if (has_out_mask) { + mask_array(inputs[2], ci, block_size_, i, M, N, N, 1); + } } } diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal index 586595782..d39b4b005 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal @@ -11,8 +11,38 @@ using namespace mlx::steel; // GEMM kernels /////////////////////////////////////////////////////////////////////////////// +struct _NoMask { + char x; + + constexpr METAL_FUNC operator bool() { + return true; + } + constexpr METAL_FUNC operator bool() const threadgroup { + return true; + } + constexpr METAL_FUNC operator bool() const device { + return true; + } + constexpr METAL_FUNC operator bool() const constant { + return true; + } +}; + +template +struct ScaleOp { + OutT scale; + + METAL_FUNC OutT apply(InT x) const { + return static_cast(x) * scale; + } +}; + +typedef struct _NoMask nomask_t; + template < typename T, + typename out_mask_t, + typename op_mask_t, int BM, int BN, int BK, @@ -21,8 +51,7 @@ template < bool transpose_a, bool transpose_b, bool MN_aligned, - bool K_aligned, - bool has_operand_mask = false> + bool K_aligned> [[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void block_masked_gemm( const device T* A [[buffer(0)]], @@ -31,9 +60,9 @@ block_masked_gemm( const constant GEMMParams* params [[buffer(4)]], const constant 
int* batch_shape [[buffer(6)]], const constant size_t* batch_strides [[buffer(7)]], - const device bool* out_mask [[buffer(10)]], - const device bool* lhs_mask [[buffer(11)]], - const device bool* rhs_mask [[buffer(12)]], + const device out_mask_t* out_mask [[buffer(10)]], + const device op_mask_t* lhs_mask [[buffer(11)]], + const device op_mask_t* rhs_mask [[buffer(12)]], const constant int* mask_strides [[buffer(13)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], @@ -42,6 +71,21 @@ block_masked_gemm( // Appease the compiler (void)lid; + static_assert( + BM == BN, + "block_masked_gemm must have the same block M and block N size"); + static_assert(BM % BK == 0, "block_masked_gemm must have BM % BK == 0"); + + constexpr bool has_operand_mask = !metal::is_same_v; + constexpr bool has_output_mask = !metal::is_same_v; + + constexpr bool has_mul_operand_mask = + has_operand_mask && !metal::is_same_v; + constexpr bool has_mul_output_mask = + has_output_mask && !metal::is_same_v; + + constexpr short k_mask_factor = short(BM / BK); + using gemm_kernel = GEMMKernel< T, T, @@ -63,15 +107,19 @@ block_masked_gemm( return; } + const constant size_t* mask_batch_strides = + batch_strides + 2 * params->batch_ndim; + if (params->batch_ndim > 1) { - const constant size_t* mask_batch_strides = - batch_strides + 2 * params->batch_ndim; - out_mask += - elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim); + if (has_output_mask) { + out_mask += elem_to_loc( + tid.z, batch_shape, mask_batch_strides, params->batch_ndim); + + mask_batch_strides += params->batch_ndim; + } if (has_operand_mask) { - const constant size_t* mask_strides_lhs = - mask_batch_strides + params->batch_ndim; + const constant size_t* mask_strides_lhs = mask_batch_strides; const constant size_t* mask_strides_rhs = mask_strides_lhs + params->batch_ndim; @@ -86,10 +134,14 @@ block_masked_gemm( rhs_mask += batch_offsets.y; } } else { - out_mask += tid.z * batch_strides[2 * params->batch_ndim]; + if (has_output_mask) { + out_mask += tid.z * mask_batch_strides[0]; + mask_batch_strides += params->batch_ndim; + } + if (has_operand_mask) { - lhs_mask += tid.z * batch_strides[3 * params->batch_ndim]; - rhs_mask += tid.z * batch_strides[4 * params->batch_ndim]; + lhs_mask += tid.z * mask_batch_strides[0]; + rhs_mask += tid.z * mask_batch_strides[params->batch_ndim]; } } @@ -121,44 +173,69 @@ block_masked_gemm( B += transpose_b ? c_col_long * params->ldb : c_col_long; D += c_row_long * params->ldd + c_col_long; - bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]]; + const constant int* out_mask_strides = mask_strides; + const constant int* lhs_mask_strides = + mask_strides + (has_output_mask ? 2 : 0); + const constant int* rhs_mask_strides = + lhs_mask_strides + (has_operand_mask ? 2 : 0); - // Write zeros and return - if (!mask_out) { - constexpr short tgp_size = WM * WN * 32; - constexpr short vec_size = 4; + const int out_mask_offset = !has_output_mask + ? 0 + : tid_y * out_mask_strides[1] + tid_x * out_mask_strides[0]; + int lhs_mask_offset = !has_operand_mask ? 0 : tid_y * lhs_mask_strides[1]; + int rhs_mask_offset = !has_operand_mask ? 0 : tid_x * rhs_mask_strides[0]; + const int lhs_mask_step = !has_operand_mask ? 0 : lhs_mask_strides[0]; + const int rhs_mask_step = !has_operand_mask ? 
0 : rhs_mask_strides[1]; + short k_factor_cnt = k_mask_factor; - // Tile threads in threadgroup - constexpr short TN = BN / vec_size; - constexpr short TM = tgp_size / TN; + ScaleOp out_mask_op; + ScaleOp lhs_mask_op; + ScaleOp rhs_mask_op; - const short thread_idx = simd_group_id * 32 + simd_lane_id; - const short bi = thread_idx / TN; - const short bj = vec_size * (thread_idx % TN); + if (has_output_mask) { + auto mask_out = out_mask[out_mask_offset]; - D += bi * params->ldd + bj; - - short tgp_bm = min(BM, params->M - c_row); - short tgp_bn = min(BN, params->N - c_col); - - if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { - for (short ti = 0; ti < BM; ti += TM) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - D[ti * params->ldd + j] = T(0.); - } - } - } else { - short jmax = tgp_bn - bj; - jmax = jmax < vec_size ? jmax : vec_size; - for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) { - for (short j = 0; j < jmax; j++) { - D[ti * params->ldd + j] = T(0.); - } - } + if (has_mul_output_mask) { + out_mask_op.scale = float(mask_out); } - return; + // Write zeros and return + if (!mask_out) { + constexpr short tgp_size = WM * WN * 32; + constexpr short vec_size = 4; + + // Tile threads in threadgroup + constexpr short TN = BN / vec_size; + constexpr short TM = tgp_size / TN; + + const short thread_idx = simd_group_id * 32 + simd_lane_id; + const short bi = thread_idx / TN; + const short bj = vec_size * (thread_idx % TN); + + D += bi * params->ldd + bj; + + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + + if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { + for (short ti = 0; ti < BM; ti += TM) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + D[ti * params->ldd + j] = T(0.); + } + } + } else { + short jmax = tgp_bn - bj; + jmax = jmax < vec_size ? jmax : vec_size; + for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) { + for (short j = 0; j < jmax; j++) { + D[ti * params->ldd + j] = T(0.); + } + } + } + + return; + } } threadgroup_barrier(mem_flags::mem_none); @@ -166,8 +243,6 @@ block_masked_gemm( // Prepare threadgroup mma operation thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id); - int gemm_k_iterations = params->gemm_k_iterations_aligned; - threadgroup T As[gemm_kernel::tgp_mem_size_a]; threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; @@ -177,21 +252,88 @@ block_masked_gemm( thread typename gemm_kernel::loader_b_t loader_b( B, params->ldb, Bs, simd_group_id, simd_lane_id); + // Prepare threadgroup bounds + const short tgp_bm = + MN_aligned ? short(BM) : short(min(BM, params->M - c_row)); + const short tgp_bn = + MN_aligned ? short(BN) : short(min(BN, params->N - c_col)); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + /////////////////////////////////////////////////////////////////////////////// + // Do unaligned K iterations first + if (!K_aligned) { + const int k_last = params->gemm_k_iterations_aligned * BK; + const int mask_idx_last = k_last / BM; + + if (!has_operand_mask || + (bool(lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]) && + bool(rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]))) { + if (has_mul_operand_mask) { + lhs_mask_op.scale = + lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]; + rhs_mask_op.scale = + rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]; + } + + // Move loader source ahead to end + const int k_remain = params->K - k_last; + const size_t k_jump_a = + transpose_a ? 
params->lda * size_t(k_last) : size_t(k_last); + const size_t k_jump_b = + transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); + + loader_a.src += k_jump_a; + loader_b.src += k_jump_b; + + // Load tile + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + if (has_mul_operand_mask) { + loader_a.apply_inplace_op(lhs_mask_op); + loader_b.apply_inplace_op(rhs_mask_op); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Do matmul + mma_op.mma(As, Bs); + + // Reset source back to start + loader_a.src -= k_jump_a; + loader_b.src -= k_jump_b; + } + } + /////////////////////////////////////////////////////////////////////////////// // MNK aligned loop if (MN_aligned) { - for (int k = 0; k < gemm_k_iterations; k++) { + for (; gemm_k_iterations > 0; gemm_k_iterations--) { threadgroup_barrier(mem_flags::mem_threadgroup); if (!has_operand_mask || - (lhs_mask - [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] && - rhs_mask - [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { + (bool(lhs_mask[lhs_mask_offset]) && + bool(rhs_mask[rhs_mask_offset]))) { + if (has_mul_operand_mask) { + lhs_mask_op.scale = lhs_mask[lhs_mask_offset]; + rhs_mask_op.scale = rhs_mask[rhs_mask_offset]; + } + // Load elements into threadgroup loader_a.load_unsafe(); loader_b.load_unsafe(); + if (has_mul_operand_mask) { + loader_a.apply_inplace_op(lhs_mask_op); + loader_b.apply_inplace_op(rhs_mask_op); + } + threadgroup_barrier(mem_flags::mem_threadgroup); // Multiply and accumulate threadgroup elements @@ -201,29 +343,15 @@ block_masked_gemm( // Prepare for next iteration loader_a.next(); loader_b.next(); + + k_factor_cnt--; + lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0; + rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0; + k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt; } - threadgroup_barrier(mem_flags::mem_none); - - // Loop tail - if (!K_aligned) { - if (!has_operand_mask || - (lhs_mask - [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] && - rhs_mask - [(params->K / BM) * mask_strides[5] + - tid_x * mask_strides[4]])) { - int lbk = params->K - params->gemm_k_iterations_aligned * BK; - short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); - short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk); - - loader_a.load_safe(tile_dims_A); - loader_b.load_safe(tile_dims_B); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - mma_op.mma(As, Bs); - } + if (has_mul_output_mask) { + mma_op.apply_epilogue(out_mask_op); } // Store results to device memory @@ -233,24 +361,25 @@ block_masked_gemm( } /////////////////////////////////////////////////////////////////////////////// // MN unaligned loop - else { // Loop over K - unaligned case - short tgp_bm = min(BM, params->M - c_row); - short tgp_bn = min(BN, params->N - c_col); - short lbk = params->K - params->gemm_k_iterations_aligned * BK; + else { + const bool M_aligned = (tgp_bm == BM); + const bool N_aligned = (tgp_bn == BN); - bool M_aligned = (tgp_bm == BM); - bool N_aligned = (tgp_bn == BN); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); - short2 tile_dims_A = transpose_a ? 
short2(tgp_bm, BK) : short2(BK, tgp_bm); - short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); - - for (int k = 0; k < gemm_k_iterations; k++) { + for (; gemm_k_iterations > 0; gemm_k_iterations--) { threadgroup_barrier(mem_flags::mem_threadgroup); if (!has_operand_mask || - (lhs_mask - [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] && - rhs_mask - [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { + (bool(lhs_mask[lhs_mask_offset]) && + bool(rhs_mask[rhs_mask_offset]))) { + if (has_mul_operand_mask) { + lhs_mask_op.scale = lhs_mask[lhs_mask_offset]; + rhs_mask_op.scale = rhs_mask[rhs_mask_offset]; + } + // Load elements into threadgroup if (M_aligned) { loader_a.load_unsafe(); @@ -264,6 +393,11 @@ block_masked_gemm( loader_b.load_safe(tile_dims_B); } + if (has_mul_operand_mask) { + loader_a.apply_inplace_op(lhs_mask_op); + loader_b.apply_inplace_op(rhs_mask_op); + } + threadgroup_barrier(mem_flags::mem_threadgroup); // Multiply and accumulate threadgroup elements @@ -273,29 +407,15 @@ block_masked_gemm( // Prepare for next iteration loader_a.next(); loader_b.next(); + + k_factor_cnt--; + lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0; + rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0; + k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt; } - if (!K_aligned) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (!has_operand_mask || - (lhs_mask - [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] && - rhs_mask - [(params->K / BM) * mask_strides[5] + - tid_x * mask_strides[4]])) { - short2 tile_dims_A_last = - transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); - short2 tile_dims_B_last = - transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk); - - loader_a.load_safe(tile_dims_A_last); - loader_b.load_safe(tile_dims_B_last); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - mma_op.mma(As, Bs); - } + if (has_mul_output_mask) { + mma_op.apply_epilogue(out_mask_op); } if (M_aligned && N_aligned) { @@ -311,6 +431,10 @@ block_masked_gemm( /////////////////////////////////////////////////////////////////////////////// #define instantiate_gemm( \ + outmaskname, \ + outmasktype, \ + opmaskname, \ + opmasktype, \ tname, \ trans_a, \ trans_b, \ @@ -326,15 +450,15 @@ block_masked_gemm( aname, \ mn_aligned, \ kname, \ - k_aligned, \ - omname, \ - op_mask) \ - template [[host_name("steel_block_masked_gemm_" #tname "_" #iname "_" #oname \ + k_aligned) \ + template [[host_name("steel_gemm_block_outmask_" #outmaskname \ + "_opmask_" #opmaskname "_" #tname "_" #iname "_" #oname \ "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \ - "_MN_" #aname "_K_" #kname \ - "_op_mask_" #omname)]] [[kernel]] void \ + "_MN_" #aname "_K_" #kname)]] [[kernel]] void \ block_masked_gemm< \ itype, \ + outmasktype, \ + opmasktype, \ bm, \ bn, \ bk, \ @@ -343,17 +467,16 @@ block_masked_gemm( trans_a, \ trans_b, \ mn_aligned, \ - k_aligned, \ - op_mask>( \ + k_aligned>( \ const device itype* A [[buffer(0)]], \ const device itype* B [[buffer(1)]], \ device itype* D [[buffer(3)]], \ const constant GEMMParams* params [[buffer(4)]], \ const constant int* batch_shape [[buffer(6)]], \ const constant size_t* batch_strides [[buffer(7)]], \ - const device bool* out_mask [[buffer(10)]], \ - const device bool* lhs_mask [[buffer(11)]], \ - const device bool* rhs_mask [[buffer(12)]], \ + const device outmasktype* out_mask [[buffer(10)]], \ + const device opmasktype* lhs_mask [[buffer(11)]], \ + const device opmasktype* 
rhs_mask [[buffer(12)]], \ const constant int* mask_strides [[buffer(13)]], \ uint simd_lane_id [[thread_index_in_simdgroup]], \ uint simd_group_id [[simdgroup_index_in_threadgroup]], \ @@ -361,9 +484,15 @@ block_masked_gemm( uint3 lid [[thread_position_in_threadgroup]]); // clang-format off -#define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ - instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, N, false) \ - instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, T, true) // clang-format on +#define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ + instantiate_gemm(bool_, bool, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ + instantiate_gemm(iname, itype, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ + instantiate_gemm(bool_, bool, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ + instantiate_gemm(iname, itype, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ + instantiate_gemm(nomask, nomask_t, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ + instantiate_gemm(nomask, nomask_t, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ + instantiate_gemm(bool_, bool, nomask, nomask_t, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ + instantiate_gemm(iname, itype, nomask, nomask_t, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) // clang-format on // clang-format off #define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ diff --git a/mlx/backend/metal/kernels/steel/gemm/loader.h b/mlx/backend/metal/kernels/steel/gemm/loader.h index 743dc7a55..aa6e8107d 100644 --- a/mlx/backend/metal/kernels/steel/gemm/loader.h +++ b/mlx/backend/metal/kernels/steel/gemm/loader.h @@ -58,6 +58,18 @@ struct BlockLoader { dst(dst_ + bi * dst_ld + bj), src(src_ + bi * src_ld + bj) {} + /* Apply operation to threadgroup without bound checking */ + template + METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]); + } + } + } + /* Load from device memory into threadgroup memory - without bound checking */ METAL_FUNC void load_unsafe() const { STEEL_PRAGMA_UNROLL diff --git a/mlx/backend/metal/kernels/steel/gemm/mma.h b/mlx/backend/metal/kernels/steel/gemm/mma.h index 0fab5b0b2..8214ad723 100644 --- a/mlx/backend/metal/kernels/steel/gemm/mma.h +++ b/mlx/backend/metal/kernels/steel/gemm/mma.h @@ -198,6 +198,24 @@ struct BlockMMA { } } + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) { + // Loop over all 
simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = results[i * TN + j].thread_elements(); + + // Apply epilogue + accum[0] = epilogue_op.apply(accum[0]); + accum[1] = epilogue_op.apply(accum[1]); + } + } + } + /* Apply epilogue */ template METAL_FUNC void apply_epilogue( diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 4de064905..c9273532e 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -1307,7 +1307,7 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { // Check and collapse batch dimensions bool has_op_mask = inputs.size() > 3; - auto& out_mask = inputs[2]; + bool has_out_mask = inputs.size() == 3 || inputs.size() == 5; std::vector batch_shape{1}; size_t A_batch_str = 0; @@ -1350,14 +1350,17 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { int wm = 2, wn = 2; // Prepare kernel name + std::string out_mask_nm = has_out_mask ? type_to_name(inputs[2]) : "nomask"; + std::string op_mask_nm = has_op_mask ? type_to_name(inputs.back()) : "nomask"; + std::ostringstream kname; - kname << "steel_block_masked_gemm_" << (transpose_a ? 't' : 'n') + kname << "steel_gemm_block_outmask_" << out_mask_nm << "_opmask_" + << op_mask_nm << "_" << (transpose_a ? 't' : 'n') << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_MN_" << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_" - << ((K % bk == 0) ? "t" : "n") << "aligned" << "_op_mask_" - << (has_op_mask ? "T" : "N"); + << ((K % bk == 0) ? 
"t" : "n") << "aligned"; // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); @@ -1397,17 +1400,23 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); std::vector mask_strides; - mask_strides.push_back(*(out_mask.strides().end() - 1)); - mask_strides.push_back(*(out_mask.strides().end() - 2)); + + if (has_out_mask) { + auto& out_mask = inputs[2]; + mask_strides.push_back(*(out_mask.strides().end() - 1)); + mask_strides.push_back(*(out_mask.strides().end() - 2)); + + compute_encoder.set_input_array(out_mask, 10); + } if (has_op_mask) { - auto& lhs_mask = inputs[3]; + auto& lhs_mask = inputs[2 + has_out_mask]; mask_strides.push_back(*(lhs_mask.strides().end() - 1)); mask_strides.push_back(*(lhs_mask.strides().end() - 2)); compute_encoder.set_input_array(lhs_mask, 11); - auto& rhs_mask = inputs[4]; + auto& rhs_mask = inputs[3 + has_out_mask]; mask_strides.push_back(*(rhs_mask.strides().end() - 1)); mask_strides.push_back(*(rhs_mask.strides().end() - 2)); @@ -1424,7 +1433,6 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { set_vector_bytes(compute_encoder, batch_shape, 6); set_vector_bytes(compute_encoder, batch_strides, 7); - compute_encoder.set_input_array(out_mask, 10); set_vector_bytes(compute_encoder, mask_strides, 13); compute_encoder.dispatchThreadgroups(grid_dims, group_dims); diff --git a/mlx/ops.cpp b/mlx/ops.cpp index ae0c3a258..6606cb5f1 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3870,48 +3870,60 @@ array block_masked_mm( int tn = (N + block_size - 1) / block_size; int tk = (K + block_size - 1) / block_size; + std::vector inputs = {a, b}; + // Broadcast and astype mask auto broadcast_mask = [](array mask, std::vector& bs_shape, int y, int x, + Dtype mask_dtype, StreamOrDevice s) { int nd_bsx = bs_shape.size(); bs_shape[nd_bsx - 2] = y; bs_shape[nd_bsx - 1] = x; - mask = astype(mask, bool_, s); + mask = astype(mask, mask_dtype, s); return broadcast_to(mask, bs_shape, s); }; // Out mask - array mask_out_p = mask_out.value_or(array({true})); - if (in_a_ndim == 1 || in_b_ndim == 1) { - std::vector ex_dims; - if (in_a_ndim == 1) - ex_dims.push_back(-2); - if (in_b_ndim == 1) - ex_dims.push_back(-1); - mask_out_p = expand_dims(mask_out_p, ex_dims, s); - } - mask_out_p = broadcast_mask(mask_out_p, bsx_shape, tm, tn, s); + if (mask_out.has_value()) { + array mask_out_p = mask_out.value_or(array({true})); + if (in_a_ndim == 1 || in_b_ndim == 1) { + std::vector ex_dims; + if (in_a_ndim == 1) + ex_dims.push_back(-2); + if (in_b_ndim == 1) + ex_dims.push_back(-1); + mask_out_p = expand_dims(mask_out_p, ex_dims, s); + } + auto maskout_dtype = mask_out_p.dtype() == bool_ ? bool_ : out_type; + mask_out_p = + broadcast_mask(mask_out_p, bsx_shape, tm, tn, maskout_dtype, s); - std::vector inputs = {a, b, mask_out_p}; + inputs.push_back(mask_out_p); + } // Operand masks if (has_operand_mask) { - // LHS mask + // Pull masks array mask_lhs_p = mask_lhs.value_or(array({true})); + array mask_rhs_p = mask_rhs.value_or(array({true})); + auto mask_dtype = + (mask_lhs_p.dtype() == bool_ && mask_rhs_p.dtype() == bool_) ? 
bool_ + : out_type; + + // LHS mask if (in_a_ndim == 1) { mask_lhs_p = expand_dims(mask_lhs_p, -2, s); } - mask_lhs_p = broadcast_mask(mask_lhs_p, bsx_shape, tm, tk, s); + mask_lhs_p = broadcast_mask(mask_lhs_p, bsx_shape, tm, tk, mask_dtype, s); // RHS mask - array mask_rhs_p = mask_rhs.value_or(array({true})); if (in_b_ndim == 1) { - mask_rhs_p = expand_dims(mask_lhs_p, -1, s); + mask_rhs_p = expand_dims(mask_rhs_p, -1, s); } - mask_rhs_p = broadcast_mask(mask_rhs_p, bsx_shape, tk, tn, s); + mask_rhs_p = broadcast_mask(mask_rhs_p, bsx_shape, tk, tn, mask_dtype, s); inputs.push_back(mask_lhs_p); inputs.push_back(mask_rhs_p); diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 4f43a72d6..1bbc963c6 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3487,42 +3487,251 @@ std::vector BlockMaskedMM::vjp( const std::vector& cotangents, const std::vector& argnums, const std::vector&) { + ///////////////////////////////////////////////////////////////////////////// + // The operation that is done w/o intermediates by the primitive is + // - tm = (M + block_size - 1) // block_size; MP = tm * block_size; + // - tn = (N + block_size - 1) // block_size; NP = tn * block_size; + // - tm = (K + block_size - 1) // block_size; KP = tk * block_size; + // - mask_b <- mask broadcasted to block sizes + // - A_m = A [..., M, K] * mask_b_lhs [..., MP, KP] + // - B_m = B [..., K, N] * mask_b_rhs [..., KP, MP] + // - C = A_m [..., M, K] @ B_m [..., K, N] + // - C_m = C [..., M, N] * mask_b_out [..., MP, NP] + // + // The grads are therefore + // - dC_m = cotan [..., M, N] + // - dmask_b_out = cotan [..., M, N] * C [..., M, N] + // - dC = cotan [..., M, N] * mask_b_out [..., MP, NP] + // - dA_m = dC [..., M, N] @ B_m.T [..., N, K] + // - dB_m = A_m.T [..., K, M] @ dC [..., M, N] + // - dA = dA_m * mask_b_lhs [..., MP, KP] + // - dB = dB_m * mask_b_rhs [..., KP, MP] + // - dmask_b_lhs = dA_m [..., M, K] * A [..., M, K] // need [..., MP, KP] + // - dmask_b_rhs = dB_m [..., K, N] * B [..., K, N] // need [..., KP, NP] + // + // Observations: + // * If dmask_b_lhs is not needed, then dA can be calulated in one go as a + // as a block_masked_mm with mask_b_lhs as the out_mask without needing to + // materialize the intermediate dA_m. Similar for dB. + // * If dmask_b_lhs is needed, we need to materialize dA_m directly and then + // point-wise multiply with A. But the output needs to be padded + std::vector vjps; auto& cotan = cotangents[0]; std::vector reorder(cotan.ndim()); std::iota(reorder.begin(), reorder.end(), 0); std::iter_swap(reorder.end() - 1, reorder.end() - 2); + bool has_op_mask = primals.size() > 3; + bool has_out_mask = primals.size() == 3 || primals.size() == 5; + + const int op_mask_idx = has_out_mask ? 
3 : 2; + bool needs_lhs_mask_vjp = has_op_mask; + bool needs_rhs_mask_vjp = has_op_mask; + bool needs_lhs_vjp = false; + bool needs_rhs_vjp = false; + + for (auto arg : argnums) { + needs_lhs_vjp = arg == 0; + needs_rhs_vjp = arg == 1; + needs_lhs_mask_vjp = arg == op_mask_idx; + needs_rhs_mask_vjp = arg == op_mask_idx + 1; + } + + if ((needs_lhs_mask_vjp && primals[op_mask_idx].dtype() == bool_) || + (needs_rhs_mask_vjp && primals[op_mask_idx + 1].dtype() == bool_)) { + throw std::invalid_argument( + "[BlockMaskedMM] Cannot calculate VJP with respect to boolean masks."); + } + + auto expand_mask = [&](array mask, int Y, int X) { + // Exapnd mask + auto mask_reshape = mask.shape(); + mask = expand_dims(mask, {-3, -1}, stream()); + auto mask_shape = mask.shape(); + int mask_ndim = mask_shape.size(); + + // Broadcast mask + mask_shape[mask_ndim - 1] = block_size_; + mask_shape[mask_ndim - 3] = block_size_; + mask = broadcast_to(mask, mask_shape, stream()); + + // Reshape mask to squeeze in braodcasted dims + mask_ndim = mask_reshape.size(); + mask_reshape[mask_ndim - 2] *= block_size_; + mask_reshape[mask_ndim - 1] *= block_size_; + mask = reshape(mask, mask_reshape, stream()); + + // Slice mask + mask_reshape[mask_ndim - 2] = Y; + mask_reshape[mask_ndim - 1] = X; + mask = slice(mask, std::vector(mask_ndim, 0), mask_reshape, stream()); + + return mask; + }; + + array zero = array(0, cotan.dtype()); + + auto multiply_pad_reduce = [&](array p, array q, int align_Y, int align_X) { + // Multiply with cotan + auto r = multiply(p, q, stream()); + + // Pad if needed + if ((align_Y != 0) || (align_X != 0)) { + r = pad(r, {-2, -1}, {0, 0}, {align_Y, align_X}, zero, stream()); + } + + // Reshape + std::vector r_reshape(r.shape().begin(), r.shape().end() - 2); + r_reshape.push_back(r.shape(-2) / block_size_); + r_reshape.push_back(block_size_); + r_reshape.push_back(r.shape(-1) / block_size_); + r_reshape.push_back(block_size_); + r = reshape(r, r_reshape, stream()); + + // Reduce + return sum(r, {-3, -1}, false, stream()); + }; + + // Prepare for padding if needed + int M = cotan.shape(-2); + int N = cotan.shape(-1); + int K = primals[0].shape(-1); + int align_M = (M % block_size_); + int align_N = (N % block_size_); + int align_K = (K % block_size_); + + // Potential intermediates + auto unmasked_lhs_grad = primals[0]; + auto unmasked_rhs_grad = primals[1]; + + bool unmasked_lhs_grad_calculated = false; + bool unmasked_rhs_grad_calculated = false; + for (auto arg : argnums) { if (arg == 0) { // M X N * (K X N).T -> M X K auto b_t = transpose(primals[1], reorder, stream()); - auto out_mask = primals[2]; - auto lhs_mask = - has_op_mask ? std::make_optional(primals[3]) : std::nullopt; + auto out_mask = + has_out_mask ? std::make_optional(primals[2]) : std::nullopt; + auto lhs_mask = has_op_mask && !needs_lhs_mask_vjp + ? std::make_optional(primals[op_mask_idx]) + : std::nullopt; auto rhs_mask_t = has_op_mask - ? std::make_optional(transpose(primals[4], reorder, stream())) + ? 
std::make_optional( + transpose(primals[op_mask_idx + 1], reorder, stream())) : std::nullopt; auto grad = block_masked_mm( cotan, b_t, block_size_, lhs_mask, out_mask, rhs_mask_t, stream()); + if (needs_lhs_mask_vjp) { + unmasked_lhs_grad = grad; + unmasked_lhs_grad_calculated = true; + auto exp_mask = expand_mask(primals[op_mask_idx], M, K); + grad = multiply(grad, exp_mask, stream()); + } + vjps.push_back(grad); } else if (arg == 1) { // (M X K).T * M X N -> K X N auto a_t = transpose(primals[0], reorder, stream()); - auto out_mask = primals[2]; + auto out_mask = + has_out_mask ? std::make_optional(primals[2]) : std::nullopt; auto lhs_mask_t = has_op_mask - ? std::make_optional(transpose(primals[3], reorder, stream())) + ? std::make_optional( + transpose(primals[op_mask_idx], reorder, stream())) + : std::nullopt; + auto rhs_mask = has_op_mask && !needs_rhs_mask_vjp + ? std::make_optional(primals[op_mask_idx + 1]) : std::nullopt; - auto rhs_mask = - has_op_mask ? std::make_optional(primals[4]) : std::nullopt; auto grad = block_masked_mm( a_t, cotan, block_size_, rhs_mask, lhs_mask_t, out_mask, stream()); + if (needs_rhs_mask_vjp) { + unmasked_rhs_grad = grad; + unmasked_rhs_grad_calculated = true; + auto exp_mask = expand_mask(primals[op_mask_idx + 1], K, N); + grad = multiply(grad, exp_mask, stream()); + } + vjps.push_back(grad); + + } else if (arg == 2 && has_out_mask) { + // Produce the forward result + auto lhs_mask = has_op_mask + ? std::make_optional(primals[op_mask_idx]) + : std::nullopt; + auto rhs_mask = has_op_mask + ? std::make_optional(primals[op_mask_idx + 1]) + : std::nullopt; + + auto C = block_masked_mm( + primals[0], + primals[1], + block_size_, + primals[2], + lhs_mask, + rhs_mask, + stream()); + + // Multiply, Pad and Reduce if needed + auto grad = multiply_pad_reduce(cotan, C, align_M, align_N); + vjps.push_back(grad); + + } else if (arg == op_mask_idx && has_op_mask) { + if (!unmasked_lhs_grad_calculated) { + // (M X K).T * M X N -> K X N + auto b_t = transpose(primals[1], reorder, stream()); + auto out_mask = + has_out_mask ? std::make_optional(primals[2]) : std::nullopt; + auto rhs_mask_t = + transpose(primals[op_mask_idx + 1], reorder, stream()); + + unmasked_lhs_grad = block_masked_mm( + cotan, + b_t, + block_size_, + std::nullopt, + out_mask, + rhs_mask_t, + stream()); + + unmasked_lhs_grad_calculated = true; + } + + // Multiply, Pad and Reduce if needed + auto grad = + multiply_pad_reduce(primals[0], unmasked_lhs_grad, align_M, align_K); + vjps.push_back(grad); + + } else if (arg == op_mask_idx + 1 && has_op_mask) { + if (!unmasked_rhs_grad_calculated) { + // (M X K).T * M X N -> K X N + auto a_t = transpose(primals[0], reorder, stream()); + auto out_mask = + has_out_mask ? 
std::make_optional(primals[2]) : std::nullopt; + auto lhs_mask_t = transpose(primals[op_mask_idx], reorder, stream()); + + unmasked_rhs_grad = block_masked_mm( + a_t, + cotan, + block_size_, + std::nullopt, + lhs_mask_t, + out_mask, + stream()); + + unmasked_rhs_grad_calculated = true; + } + + // Multiply, Pad and Reduce if needed + auto grad = + multiply_pad_reduce(primals[1], unmasked_rhs_grad, align_K, align_N); + vjps.push_back(grad); + } else { throw std::invalid_argument( "[BlockMaskedMM] Cannot calculate VJP with respect to masks."); diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index 52aef4868..c5ae5eaf5 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -682,7 +682,7 @@ class TestBlas(mlx_tests.MLXTestCase): self.assertEqual(c.shape, (0, 0)) def test_block_masked_matmul(self): - def np_block_masked_mm( + def ref_block_masked_mm( a, b, block_size, out_mask=None, lhs_mask=None, rhs_mask=None ): # Get mask adjusted shapes @@ -690,33 +690,81 @@ class TestBlas(mlx_tests.MLXTestCase): N = b.shape[-1] K = a.shape[-1] + bsx_shape = np.broadcast_shapes(a.shape[:-2], b.shape[:-2]) + # Expand mask dims def expand_mask(mask, block_size, Y, X): - mask = np.expand_dims(mask, (-3, -1)) - mask_shape = list(mask.shape) + mask = mx.expand_dims(mask, (-3, -1)) + mask_shape = list(bsx_shape) + list(mask.shape[-4:]) mask_shape[-1] = block_size x = mask_shape[-2] * block_size mask_shape[-3] = block_size y = mask_shape[-4] * block_size - mask = np.broadcast_to(mask, mask_shape) + mask = mx.broadcast_to(mask, mask_shape) mask_shape = mask_shape[:-4] + [y, x] return mask.reshape(mask_shape)[..., :Y, :X] + a_masked = a + b_masked = b + if lhs_mask is not None: - lhs_mask = expand_mask(lhs_mask, block_size, M, K) - a = lhs_mask * a + lhs_mask = expand_mask(lhs_mask, block_size, M, K).astype(mx.float32) + a_masked = lhs_mask * a_masked if rhs_mask is not None: - rhs_mask = expand_mask(rhs_mask, block_size, K, N) - b = rhs_mask * b + rhs_mask = expand_mask(rhs_mask, block_size, K, N).astype(mx.float32) + b_masked = rhs_mask * b_masked - out = a @ b + out = a_masked @ b_masked if out_mask is not None: - out_mask = expand_mask(out_mask, block_size, M, N) + out_mask = expand_mask(out_mask, block_size, M, N).astype(mx.float32) out = out * out_mask return out + def run_test(a, b, block_size, out_mask, a_mask, b_mask, cotan): + def f_ref(a_, b_): + return ref_block_masked_mm(a_, b_, block_size, out_mask, a_mask, b_mask) + + def f_test(a_, b_): + return mx.block_masked_mm(a_, b_, block_size, out_mask, a_mask, b_mask) + + out_ref, dout_ref = mx.vjp(f_ref, [a, b], [cotan]) + out_test, dout_test = mx.vjp(f_test, [a, b], [cotan]) + + mx.eval((out_ref, dout_ref, out_test, dout_test)) + + self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-5).item()) + + def run_test_mask_vjp(a, b, block_size, out_mask, a_mask, b_mask, cotan): + def f_ref(a_, b_, a_mask_, b_mask_): + return ref_block_masked_mm( + a_, b_, block_size, out_mask, a_mask_, b_mask_ + ) + + def f_test(a_, b_, a_mask_, b_mask_): + return mx.block_masked_mm( + a_, b_, block_size, out_mask, a_mask_, b_mask_ + ) + + out_ref, dout_ref = mx.vjp(f_ref, [a, b, a_mask, b_mask], [cotan]) + out_test, dout_test = mx.vjp(f_test, [a, b, a_mask, b_mask], [cotan]) + + mx.eval((out_ref, dout_ref, out_test, dout_test)) + + self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-5).item()) + + for r, t in zip(dout_ref, dout_test): + self.assertEqual(r.shape, t.shape) + self.assertTrue(mx.allclose(r, t, atol=1e-4).item()) + 
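[Note] The `make_mask` helper added just below pairs a boolean block mask with a float block mask drawn from the same normal sample: entries that are negative become `True` in the boolean mask and are zeroed in the float mask, so the float mask ends up containing exact zeros (fully dropped blocks) plus positive scales. A standalone NumPy restatement of that construction (the function name, default arguments, and seed are illustrative only):

```python
import numpy as np

def make_mask_np(tm, tn, batch=(), dtype=np.float32, seed=0):
    rng = np.random.default_rng(seed)
    m = rng.normal(size=batch + (tm, tn)).astype(dtype)
    bool_mask = m < 0.0   # boolean block mask
    m[bool_mask] = 0.0    # float block mask: exact zeros plus positive scales
    return bool_mask, m
```

The float masks are what `run_test_mask_vjp` differentiates through; the VJP added in `primitives.cpp` rejects boolean operand masks explicitly.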
+ def make_mask(tm_, tn_, batch, np_dtype): + arr_np_mask = np.random.normal(size=batch + (tm_, tn_)).astype(np_dtype) + arr_np_bool_mask = arr_np_mask < 0.0 + arr_np_mask[arr_np_bool_mask] = 0.0 + + return mx.array(arr_np_bool_mask), mx.array(arr_np_mask) + def test_shape( M, N, @@ -737,49 +785,49 @@ class TestBlas(mlx_tests.MLXTestCase): batch_A=batch_A, batch_B=batch_B, ): - tm = (M + block_size - 1) // block_size - tn = (N + block_size - 1) // block_size - tk = (K + block_size - 1) // block_size + batch_out = np.broadcast_shapes(batch_A, batch_B) + cotan = mx.ones(batch_out + (M, N)) a_np = np.random.normal(size=batch_A + (M, K)).astype(np_dtype) b_np = np.random.normal(size=batch_B + (K, N)).astype(np_dtype) - batch_out = np.broadcast_shapes(batch_A, batch_B) + a_mx = mx.array(a_np) + b_mx = mx.array(b_np) - a_np_mask = np.random.normal(size=batch_A + (tm, tk)) < 0.0 - b_np_mask = np.random.normal(size=batch_B + (tk, tn)) < 0.0 - out_np_mask = np.random.normal(size=batch_out + (tm, tn)) < 0.0 + tm = (M + block_size - 1) // block_size + tn = (N + block_size - 1) // block_size + tk = (K + block_size - 1) // block_size - a_mx, b_mx, a_mx_mask, b_mx_mask, out_mx_mask = map( - mx.array, (a_np, b_np, a_np_mask, b_np_mask, out_np_mask) + a_mx_bool_mask, a_mx_mask = make_mask(tm, tk, batch_A, np_dtype) + b_mx_bool_mask, b_mx_mask = make_mask(tk, tn, batch_B, np_dtype) + out_mx_bool_mask, out_mx_mask = make_mask(tm, tn, batch_out, np_dtype) + + # Boolean block masks + run_test( + a_mx, + b_mx, + block_size, + out_mx_bool_mask, + a_mx_bool_mask, + b_mx_bool_mask, + cotan, + ) + run_test(a_mx, b_mx, block_size, out_mx_bool_mask, None, None, cotan) + run_test( + a_mx, b_mx, block_size, None, a_mx_bool_mask, b_mx_bool_mask, cotan ) - if transpose: - b_np = np.random.normal(size=batch_B + (N, K)).astype(np_dtype) - b_mx = mx.array(b_np) - - b_np = np.swapaxes(b_np, -2, -1) - b_mx = mx.swapaxes(b_mx, -2, -1) - - out_np = np_block_masked_mm( - a_np, b_np, block_size, out_np_mask, a_np_mask, b_np_mask + # Float block masks + run_test( + a_mx, b_mx, block_size, out_mx_mask, a_mx_mask, b_mx_mask, cotan ) - out_mx = mx.block_masked_mm( - a_mx, b_mx, block_size, out_mx_mask, a_mx_mask, b_mx_mask + run_test(a_mx, b_mx, block_size, None, a_mx_mask, b_mx_mask, cotan) + run_test_mask_vjp( + a_mx, b_mx, block_size, out_mx_mask, a_mx_mask, b_mx_mask, cotan ) - self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5)) - - out_np = np_block_masked_mm(a_np, b_np, block_size, out_np_mask) - out_mx = mx.block_masked_mm(a_mx, b_mx, block_size, out_mx_mask) - self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5)) - - out_np = np_block_masked_mm( - a_np, b_np, block_size, None, a_np_mask, b_np_mask + run_test_mask_vjp( + a_mx, b_mx, block_size, None, a_mx_mask, b_mx_mask, cotan ) - out_mx = mx.block_masked_mm( - a_mx, b_mx, block_size, None, a_mx_mask, b_mx_mask - ) - self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5)) shapes = ( (16, 16, 16, 32), @@ -789,11 +837,10 @@ class TestBlas(mlx_tests.MLXTestCase): ) for M, N, K, block_size in shapes: - test_shape(M, N, K, block_size, transpose=False) - test_shape(M, N, K, block_size, transpose=True) + test_shape(M, N, K, block_size) # Test broadcasting - test_shape(64, 64, 64, 32, transpose=False, batch_A=(1, 2), batch_B=(2, 2)) + test_shape(64, 64, 64, 32, batch_A=(1, 2), batch_B=(2, 2)) # Test gemv a_np = np.random.normal(size=(64, 64)).astype(np.float32)
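[Note] To close, a small usage sketch of the extended op. This is not from the diff; the shapes, mask values, and variable names are illustrative, and block_size 32 mirrors the tests. With float masks each block of the product is scaled rather than only kept or dropped, and gradients with respect to a float mask are now defined, while the VJP still rejects boolean operand masks:

```python
import mlx.core as mx

block_size = 32
a = mx.random.normal((64, 64))
b = mx.random.normal((64, 64))

# 2 x 2 grid of blocks: booleans keep/drop blocks, floats scale them.
out_mask = mx.array([[1.0, 0.0], [0.5, 2.0]])
lhs_mask = mx.array([[True, False], [True, True]])
rhs_mask = mx.array([[1.0, 1.0], [0.0, 3.0]])

c = mx.block_masked_mm(a, b, block_size, out_mask, lhs_mask, rhs_mask)

# Gradient with respect to the float output mask.
def loss(m):
    return mx.block_masked_mm(a, b, block_size, m, lhs_mask, rhs_mask).sum()

g = mx.grad(loss)(out_mask)  # same shape as out_mask, i.e. (2, 2)
```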