diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 1e934befd..a890e7f3b 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -29,6 +29,7 @@ Operations atleast_2d atleast_3d broadcast_to + block_masked_mm ceil clip concatenate diff --git a/mlx/backend/accelerate/matmul.cpp b/mlx/backend/accelerate/matmul.cpp index b26a16bab..6113223a4 100644 --- a/mlx/backend/accelerate/matmul.cpp +++ b/mlx/backend/accelerate/matmul.cpp @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include @@ -196,6 +196,40 @@ inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) { return matmul_bnns_general(a_pre, b_pre, out); } +template +inline void mask_matrix( + T* data, + const bool* mask, + int tile_size, + const int X, + const int Y, + const size_t X_data_str, + const size_t Y_data_str, + const size_t X_mask_str, + const size_t Y_mask_str) { + int tX = (X + tile_size - 1) / tile_size; + int tY = (Y + tile_size - 1) / tile_size; + + for (int i = 0; i < tX; i++) { + for (int j = 0; j < tY; j++) { + bool do_mask = mask[i * X_mask_str + j * Y_mask_str]; + if (!do_mask) { + int loc_x = i * tile_size; + int loc_y = j * tile_size; + T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str; + + int size_x = std::min(tile_size, X - loc_x); + int size_y = std::min(tile_size, Y - loc_y); + for (int ii = 0; ii < size_x; ii++) { + for (int jj = 0; jj < size_y; jj++) { + data_block[ii * X_data_str + jj * Y_data_str] = T(0.); + } + } + } + } + } +} + } // namespace void Matmul::eval_cpu(const std::vector& inputs, array& out) { diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index d8fa52b19..7a3610717 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -31,6 +31,7 @@ DEFAULT(ArgPartition) DEFAULT(ArgReduce) DEFAULT(ArgSort) DEFAULT(AsStrided) +DEFAULT(BlockMaskedMM) DEFAULT(Broadcast) DEFAULT(Ceil) DEFAULT(Concatenate) diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index a872d5b99..ea0babf18 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -41,6 +41,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index 219d52ad3..a0cc42984 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -41,6 +41,7 @@ DEFAULT(ArgSort) DEFAULT(AsType) DEFAULT(AsStrided) DEFAULT(Broadcast) +DEFAULT(BlockMaskedMM) DEFAULT_MULTI(DivMod) DEFAULT(Ceil) DEFAULT(Concatenate) diff --git a/mlx/backend/common/masked_mm.cpp b/mlx/backend/common/masked_mm.cpp new file mode 100644 index 000000000..52711d512 --- /dev/null +++ b/mlx/backend/common/masked_mm.cpp @@ -0,0 +1,193 @@ +// Copyright © 2024 Apple Inc. 
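+//
+// CPU fallback for BlockMaskedMM: when operand masks are provided, the
+// (copied) inputs have their masked block_size x block_size tiles zeroed
+// out, a regular cblas_sgemm is run per batch entry, and masked tiles of
+// the output are then zeroed according to the output mask. Only float32 is
+// supported on this path.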
+ +#ifdef ACCELERATE_NEW_LAPACK +#include +#else +#include +#endif + +#include + +#include "mlx/array.h" +#include "mlx/backend/common/copy.h" +#include "mlx/backend/common/utils.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +namespace { + +template +inline void mask_matrix( + T* data, + const bool* mask, + int block_size, + const int X, + const int Y, + const size_t X_data_str, + const size_t Y_data_str, + const size_t X_mask_str, + const size_t Y_mask_str) { + int tX = (X + block_size - 1) / block_size; + int tY = (Y + block_size - 1) / block_size; + + for (int i = 0; i < tX; i++) { + for (int j = 0; j < tY; j++) { + bool do_mask = mask[i * X_mask_str + j * Y_mask_str]; + if (!do_mask) { + int loc_x = i * block_size; + int loc_y = j * block_size; + T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str; + + int size_x = std::min(block_size, X - loc_x); + int size_y = std::min(block_size, Y - loc_y); + for (int ii = 0; ii < size_x; ii++) { + for (int jj = 0; jj < size_y; jj++) { + data_block[ii * X_data_str + jj * Y_data_str] = T(0.); + } + } + } + } + } +} + +} // namespace + +void BlockMaskedMM::eval(const std::vector& inputs, array& out) { + if (out.dtype() != float32) { + throw std::runtime_error( + "[BlockMaskedMM::eval] Currently only supports float32."); + } + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + auto& a_pre = inputs[0]; + auto& b_pre = inputs[1]; + auto& out_mask = inputs[2]; + + auto check_transpose = [](const array& arr, bool do_copy) { + auto stx = arr.strides()[arr.ndim() - 2]; + auto sty = arr.strides()[arr.ndim() - 1]; + if (stx == arr.shape(-1) && sty == 1) { + if (do_copy) { + array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); + copy(arr, arr_copy, CopyType::Vector); + return std::make_tuple(false, stx, arr_copy); + } + return std::make_tuple(false, stx, arr); + } else if (stx == 1 && sty == arr.shape(-2)) { + if (do_copy) { + array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); + copy(arr, arr_copy, CopyType::Vector); + return std::make_tuple(true, sty, arr_copy); + } + return std::make_tuple(true, sty, arr); + } else { + array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); + copy(arr, arr_copy, CopyType::General); + size_t stx = arr.shape(-1); + return std::make_tuple(false, stx, arr_copy); + } + }; + + bool has_op_mask = inputs.size() > 3; + auto [a_transposed, lda, a] = check_transpose(a_pre, has_op_mask); + auto [b_transposed, ldb, b] = check_transpose(b_pre, has_op_mask); + + size_t M = a.shape(-2); + size_t N = b.shape(-1); + size_t K = a.shape(-1); + + if (M == 0 || N == 0) { + return; + } + + if (K == 0) { + std::memset(static_cast(out.data()), 0, out.nbytes()); + return; + } + + auto mask_array = [](const array& mask, + float* data, + int block_size, + int batch_idx, + int X, + int Y, + size_t X_data_str, + size_t Y_data_str) { + const bool* mask_ptr = mask.data() + + elem_to_loc(mask.shape(-1) * mask.shape(-2) * batch_idx, + mask.shape(), + mask.strides()); + + size_t X_mask_str = mask.strides()[mask.ndim() - 2]; + size_t Y_mask_str = mask.strides()[mask.ndim() - 1]; + + return mask_matrix( + data, + mask_ptr, + block_size, + X, + Y, + X_data_str, + Y_data_str, + X_mask_str, + Y_mask_str); + }; + + for (int i = 0; i < (a.size() / (M * K)); ++i) { + // Adjust pointer + float* ai = + a.data() + elem_to_loc(M * K * i, a.shape(), a.strides()); + float* bi = + b.data() + elem_to_loc(K * N * i, b.shape(), b.strides()); + float* ci = out.data() + M * N * i; + + // Zero out blocks in a and b if needed + if (has_op_mask) 
{ + auto& a_mask = inputs[3]; + mask_array( + a_mask, + ai, + block_size_, + i, + M, + K, + a_transposed ? 1 : lda, + a_transposed ? lda : 1); + + auto& b_mask = inputs[4]; + mask_array( + b_mask, + bi, + block_size_, + i, + K, + N, + b_transposed ? 1 : ldb, + b_transposed ? ldb : 1); + } + + // Do matmul + cblas_sgemm( + CblasRowMajor, + a_transposed ? CblasTrans : CblasNoTrans, // transA + b_transposed ? CblasTrans : CblasNoTrans, // transB + M, + N, + K, + 1.0, // alpha + ai, + lda, + bi, + ldb, + 0.0, // beta + ci, + out.shape(-1) // ldc + ); + + // Zero out blocks in out + mask_array(out_mask, ci, block_size_, i, M, N, N, 1); + } +} + +} // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal new file mode 100644 index 000000000..78ef0f212 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal @@ -0,0 +1,323 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" + +using namespace metal; +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +template +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void block_masked_gemm( + const device T *A [[buffer(0)]], + const device T *B [[buffer(1)]], + device T *D [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + const constant int* batch_shape [[buffer(6)]], + const constant size_t* batch_strides [[buffer(7)]], + const device bool *out_mask [[buffer(10)]], + const device bool *lhs_mask [[buffer(11)]], + const device bool *rhs_mask [[buffer(12)]], + const constant int* mask_strides [[buffer(13)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + + // Appease the compiler + (void)lid; + + using gemm_kernel = GEMMKernel; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + if(params->batch_ndim > 1) { + const constant size_t* mask_batch_strides = batch_strides + 2 * params->batch_ndim; + out_mask += elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim); + + if(has_operand_mask) { + const constant size_t* mask_strides_lhs = mask_batch_strides + params->batch_ndim; + const constant size_t* mask_strides_rhs = mask_strides_lhs + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, mask_strides_lhs, mask_strides_rhs, params->batch_ndim); + + lhs_mask += batch_offsets.x; + rhs_mask += batch_offsets.y; + } + } + + // Adjust for batch + if(params->batch_ndim > 1) { + const constant size_t* A_bstrides = batch_strides; + const constant size_t* B_bstrides = batch_strides + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); + + A += batch_offsets.x; + B += batch_offsets.y; + + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + } + + D += 
params->batch_stride_d * tid.z; + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + + A += transpose_a ? c_row : c_row * params->lda; + B += transpose_b ? c_col * params->ldb : c_col; + D += c_row * params->ldd + c_col; + + + bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]]; + + // Write zeros and return + if(!mask_out) { + constexpr short tgp_size = WM * WN * 32; + constexpr short vec_size = 4; + + // Tile threads in threadgroup + constexpr short TN = BN / vec_size; + constexpr short TM = tgp_size / TN; + + const short thread_idx = simd_group_id * 32 + simd_lane_id; + const short bi = thread_idx / TN; + const short bj = vec_size * (thread_idx % TN); + + D += bi * params->ldd + bj; + + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + + if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { + for (short ti = 0; ti < BM; ti += TM) { + STEEL_PRAGMA_UNROLL + for(short j = 0; j < vec_size; j++) { + D[ti * params->ldd + j] = T(0.); + } + } + } else { + short jmax = tgp_bn - bj; + jmax = jmax < vec_size ? jmax : vec_size; + for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) { + for(short j = 0; j < jmax; j++) { + D[ti * params->ldd + j] = T(0.); + } + } + } + + return; + } + + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Prepare threadgroup loading operations + thread typename gemm_kernel::loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread typename gemm_kernel::loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (MN_aligned) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + if(!has_operand_mask || + (lhs_mask[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] && + rhs_mask[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + } + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Loop tail + if (!K_aligned) { + + if(!has_operand_mask || + (lhs_mask[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] && + rhs_mask[(params->K / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { + + int lbk = params->K - params->gemm_k_iterations_aligned * BK; + short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); + short2 tile_dims_B = transpose_b ? 
short2(lbk, BN) : short2(BN, lbk); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + + } + } + + // Store results to device memory + mma_op.store_result(D, params->ldd); + return; + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short lbk = params->K - params->gemm_k_iterations_aligned * BK; + + bool M_aligned = (tgp_bm == BM); + bool N_aligned = (tgp_bn == BN); + + short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); + short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); + + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if(!has_operand_mask || + (lhs_mask[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] && + rhs_mask[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { + + // Load elements into threadgroup + if (M_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(tile_dims_A); + } + + if (N_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe(tile_dims_B); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + } + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + if (!K_aligned) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + if(!has_operand_mask || + (lhs_mask[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] && + rhs_mask[(params->K / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { + + short2 tile_dims_A_last = + transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); + short2 tile_dims_B_last = + transpose_b ? 
short2(lbk, tgp_bn) : short2(tgp_bn, lbk); + + loader_a.load_safe(tile_dims_A_last); + loader_b.load_safe(tile_dims_B_last); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + + } + } + + if(M_aligned && N_aligned) { + mma_op.store_result(D, params->ldd); + } else { + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + } + } +} + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernel initializations +/////////////////////////////////////////////////////////////////////////////// + +#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, omname, op_mask) \ + template [[host_name("steel_block_masked_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_op_mask_" #omname)]] \ + [[kernel]] void block_masked_gemm( \ + const device itype *A [[buffer(0)]], \ + const device itype *B [[buffer(1)]], \ + device itype *D [[buffer(3)]], \ + const constant GEMMParams* params [[buffer(4)]], \ + const constant int* batch_shape [[buffer(6)]], \ + const constant size_t* batch_strides [[buffer(7)]], \ + const device bool *out_mask [[buffer(10)]], \ + const device bool *lhs_mask [[buffer(11)]], \ + const device bool *rhs_mask [[buffer(12)]], \ + const constant int* mask_strides [[buffer(13)]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]]); + +#define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, N, false) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, T, true) + +#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \ + instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \ + instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \ + instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) + +#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) + +#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) + +instantiate_gemm_shapes_helper(float16, half, float16, half); 
+instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t); +instantiate_gemm_shapes_helper(float32, float, float32, float); \ No newline at end of file diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 856f4d67e..1019f6737 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -1064,4 +1064,181 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { return; } +void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { + using namespace mlx::steel; + // assert(inputs.size() == 2); + if (!issubdtype(out.dtype(), floating)) { + throw std::runtime_error( + "[matmul] Does not yet support non-floating point types."); + } + auto& s = stream(); + auto& d = metal::device(s.device); + + auto& a_pre = inputs[0]; + auto& b_pre = inputs[1]; + // Return 0s if either input is empty + if (a_pre.size() == 0 || b_pre.size() == 0) { + array zero = array(0, a_pre.dtype()); + copy_gpu(zero, out, CopyType::Scalar, s); + auto command_buffer = d.get_command_buffer(s.index); + command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {}); + return; + } + + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + ///////////////////////////////////////////////////////////////////////////// + // Init checks and prep + + // Keep a vector with copies to be cleared in the completed buffer to release + // the arrays + std::vector copies; + auto check_transpose = [&copies, &s](const array& arr) { + auto stx = arr.strides()[arr.ndim() - 2]; + auto sty = arr.strides()[arr.ndim() - 1]; + if (sty == 1) { + return std::make_tuple(false, stx, arr); + } else if (stx == 1) { + return std::make_tuple(true, sty, arr); + } else { + array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); + copy_gpu(arr, arr_copy, CopyType::General, s); + copies.push_back(arr_copy); + size_t stx = arr.shape(-1); + return std::make_tuple(false, stx, arr_copy); + } + }; + + auto [transpose_a, a_cols, a] = check_transpose(a_pre); + auto [transpose_b, b_cols, b] = check_transpose(b_pre); + + int lda = a_cols; + int ldb = b_cols; + + int M = a.shape(-2); + int N = b.shape(-1); + int K = a.shape(-1); + + ///////////////////////////////////////////////////////////////////////////// + // Check and collapse batch dimensions + + auto [batch_shape, A_batch_stride, B_batch_stride] = collapse_batches(a, b); + + auto batch_size_out = out.size() / (M * N); + int matrix_stride_out = M * N; + + ///////////////////////////////////////////////////////////////////////////// + // Regular kernel dispatch + + // Determine dispatch kernel + int bm = block_size_, bn = block_size_, bk = 16; + int wm = 2, wn = 2; + + // Prepare kernel name + std::ostringstream kname; + kname << "steel_block_masked_gemm_" << (transpose_a ? 't' : 'n') + << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" + << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk + << "_wm" << wm << "_wn" << wn << "_MN_" + << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_" + << ((K % bk == 0) ? "t" : "n") << "aligned" << "_op_mask_" + << (inputs.size() > 3 ? "T" : "N"); + + // Encode and dispatch kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + // Use problem size to determine threadblock swizzle + int tn = (N + bn - 1) / bn; + int tm = (M + bm - 1) / bm; + + // TODO: Explore device-based tuning for swizzle + int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 
0 : 2); + + // Prepare steel matmul params + GEMMParams params{ + /* const int M = */ M, + /* const int N = */ N, + /* const int K = */ K, + /* const int lda = */ lda, + /* const int ldb = */ ldb, + /* const int ldd = */ N, + /* const int tiles_n = */ tn, + /* const int tiles_m = */ tm, + /* const int batch_stride_a = */ int(A_batch_stride.back()), + /* const int batch_stride_b = */ int(B_batch_stride.back()), + /* const int batch_stride_d = */ matrix_stride_out, + /* const int swizzle_log = */ swizzle_log, + /* const int gemm_k_iterations_aligned = */ (K / bk), + /* const int batch_ndim = */ int(batch_shape.size())}; + + // Prepare launch grid params + int tile = 1 << swizzle_log; + tm = (tm + tile - 1) / tile; + tn = tn * tile; + + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); + + std::vector batch_strides = A_batch_stride; + batch_strides.insert( + batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); + + std::vector mask_strides; + + auto& out_mask = inputs[2]; + mask_strides.push_back(*(out_mask.strides().end() - 1)); + mask_strides.push_back(*(out_mask.strides().end() - 2)); + + batch_strides.insert( + batch_strides.end(), + out_mask.strides().begin(), + out_mask.strides().end() - 2); + + if (inputs.size() > 3) { + auto& lhs_mask = inputs[3]; + mask_strides.push_back(*(lhs_mask.strides().end() - 1)); + mask_strides.push_back(*(lhs_mask.strides().end() - 2)); + + batch_strides.insert( + batch_strides.end(), + lhs_mask.strides().begin(), + lhs_mask.strides().end() - 2); + + compute_encoder.set_input_array(lhs_mask, 11); + + auto& rhs_mask = inputs[4]; + mask_strides.push_back(*(rhs_mask.strides().end() - 1)); + mask_strides.push_back(*(rhs_mask.strides().end() - 2)); + + batch_strides.insert( + batch_strides.end(), + rhs_mask.strides().begin(), + rhs_mask.strides().end() - 2); + + compute_encoder.set_input_array(rhs_mask, 12); + } + + // Launch kernel + compute_encoder.set_input_array(a, 0); + compute_encoder.set_input_array(b, 1); + compute_encoder.set_output_array(out, 3); + + compute_encoder->setBytes(¶ms, sizeof(GEMMParams), 4); + + set_vector_bytes(compute_encoder, batch_shape, 6); + set_vector_bytes(compute_encoder, batch_strides, 7); + + compute_encoder.set_input_array(out_mask, 10); + set_vector_bytes(compute_encoder, mask_strides, 13); + + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + + // Clear copies + d.get_command_buffer(s.index)->addCompletedHandler( + [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); + return; +} + } // namespace mlx::core diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index 211ccbf9d..4891415a3 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -99,6 +99,7 @@ NO_GPU(Subtract) NO_GPU_MULTI(SVD) NO_GPU(Tan) NO_GPU(Tanh) +NO_GPU(BlockMaskedMM) NO_GPU(Transpose) NO_GPU(Inverse) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 80cd40127..ed92831a7 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3572,6 +3572,172 @@ array addmm( return out; } +/** Compute matrix product with tile-level masking */ +array block_masked_mm( + array a, + array b, + int block_size, + std::optional mask_out /* = std::nullopt */, + std::optional mask_lhs /* = std::nullopt */, + std::optional mask_rhs /* = std::nullopt */, + StreamOrDevice s /* = {} */) { + // If no masks, just perform regular matmul + if (!mask_out && !mask_lhs && !mask_rhs) { + return matmul(a, b, s); + } + + bool has_out_mask = 
mask_out.has_value();
+  bool has_operand_mask = mask_lhs.has_value() || mask_rhs.has_value();
+
+  // Check valid tile sizes
+  // TODO: Add support for 16x16 tile
+  if (block_size != 32 && block_size != 64) {
+    std::ostringstream msg;
+    msg << "[block_masked_mm] Only block_sizes 32, 64 are supported. "
+        << "Got block size " << block_size << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  // Do shape checks for operands
+  int in_a_ndim = a.ndim();
+  int in_b_ndim = b.ndim();
+
+  if (a.ndim() == 0 || b.ndim() == 0) {
+    throw std::invalid_argument(
+        "[block_masked_mm] Got 0 dimension input. Inputs must "
+        "have at least one dimension.");
+  }
+
+  if (a.ndim() == 1) {
+    // Insert a singleton dim in the beginning
+    a = reshape(a, {1, -1}, s);
+  }
+  if (b.ndim() == 1) {
+    // Insert a singleton dim at the end
+    b = reshape(b, {-1, 1}, s);
+  }
+
+  if (a.shape(-1) != b.shape(-2)) {
+    std::ostringstream msg;
+    msg << "[block_masked_mm] Last dimension of first input with shape "
+        << a.shape() << " must match second to last dimension of"
+        << " second input with shape " << b.shape() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  // Type promotion
+  auto out_type = result_type(a, b);
+  if (!issubdtype(out_type, floating)) {
+    std::ostringstream msg;
+    msg << "[block_masked_mm] Only real floating point types are supported but "
+        << a.dtype() << " and " << b.dtype()
+        << " were provided which results in " << out_type
+        << ", which is not a real floating point type.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  a = astype(a, out_type, s);
+  b = astype(b, out_type, s);
+
+  // Handle broadcasting
+  std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2);
+  std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2);
+
+  auto bsx_shape = broadcast_shapes(bsx_a, bsx_b);
+
+  bsx_shape.push_back(1);
+  bsx_shape.push_back(1);
+  int nd = bsx_shape.size();
+
+  int M = a.shape(-2);
+  int N = b.shape(-1);
+  int K = a.shape(-1);
+
+  // Prepare A
+  bsx_shape[nd - 2] = M;
+  bsx_shape[nd - 1] = K;
+  a = broadcast_to(a, bsx_shape, s);
+
+  // Prepare B
+  bsx_shape[nd - 2] = K;
+  bsx_shape[nd - 1] = N;
+  b = broadcast_to(b, bsx_shape, s);
+
+  // Get output shape
+  auto out_shape = bsx_shape;
+  out_shape[nd - 2] = M;
+  out_shape[nd - 1] = N;
+
+  // Determine mask shape requirements
+  int tm = (M + block_size - 1) / block_size;
+  int tn = (N + block_size - 1) / block_size;
+  int tk = (K + block_size - 1) / block_size;
+
+  // Broadcast and astype mask
+  auto broadcast_mask = [](array mask,
+                           std::vector<int>& bs_shape,
+                           int y,
+                           int x,
+                           StreamOrDevice s) {
+    int nd_bsx = bs_shape.size();
+    bs_shape[nd_bsx - 2] = y;
+    bs_shape[nd_bsx - 1] = x;
+    mask = astype(mask, bool_, s);
+    return broadcast_to(mask, bs_shape, s);
+  };
+
+  // Out mask
+  array mask_out_p = mask_out.value_or(array({true}));
+  if (in_a_ndim == 1 || in_b_ndim == 1) {
+    std::vector<int> ex_dims;
+    if (in_a_ndim == 1)
+      ex_dims.push_back(-2);
+    if (in_b_ndim == 1)
+      ex_dims.push_back(-1);
+    mask_out_p = expand_dims(mask_out_p, ex_dims, s);
+  }
+  mask_out_p = broadcast_mask(mask_out_p, bsx_shape, tm, tn, s);
+
+  std::vector<array> inputs = {a, b, mask_out_p};
+
+  // Operand masks
+  if (has_operand_mask) {
+    // LHS mask
+    array mask_lhs_p = mask_lhs.value_or(array({true}));
+    if (in_a_ndim == 1) {
+      mask_lhs_p = expand_dims(mask_lhs_p, -2, s);
+    }
+    mask_lhs_p = broadcast_mask(mask_lhs_p, bsx_shape, tm, tk, s);
+
+    // RHS mask
+    array mask_rhs_p = mask_rhs.value_or(array({true}));
+    if (in_b_ndim == 1) {
+      mask_rhs_p = expand_dims(mask_rhs_p, -1, s);
+    }
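+    // (tk, tn) == (ceil(K / block_size), ceil(N / block_size)); the RHS mask
+    // is broadcast over the K x N tile grid of b.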
+ mask_rhs_p = broadcast_mask(mask_rhs_p, bsx_shape, tk, tn, s); + + inputs.push_back(mask_lhs_p); + inputs.push_back(mask_rhs_p); + } + + // Caculate array + auto out = array( + out_shape, + out_type, + std::make_shared(to_stream(s), block_size), + std::move(inputs)); + + // Remove the possibly inserted singleton dimensions + if (in_a_ndim == 1 || in_b_ndim == 1) { + out_shape.erase( + out_shape.end() - ((in_a_ndim == 1) ? 2 : 1), + out_shape.end() - ((in_b_ndim == 1) ? 0 : 1)); + out = reshape(out, out_shape, s); + } + + return out; +} + array diagonal( const array& a, int offset /* = 0 */, diff --git a/mlx/ops.h b/mlx/ops.h index 4b18635e2..f5d8b3f65 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1185,6 +1185,16 @@ array addmm( const float& beta = 1.f, StreamOrDevice s = {}); +/** Compute matrix product with block masking */ +array block_masked_mm( + array a, + array b, + int block_size, + std::optional mask_out = std::nullopt, + std::optional mask_lhs = std::nullopt, + std::optional mask_rhs = std::nullopt, + StreamOrDevice s = {}); + /** Extract a diagonal or construct a diagonal array */ array diagonal( const array& a, diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 31b64a965..3cad50422 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3272,6 +3272,59 @@ std::pair, std::vector> Tanh::vmap( return {{tanh(inputs[0], stream())}, axes}; } +std::vector BlockMaskedMM::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector&) { + std::vector vjps; + auto& cotan = cotangents[0]; + std::vector reorder(cotan.ndim()); + std::iota(reorder.begin(), reorder.end(), 0); + std::iter_swap(reorder.end() - 1, reorder.end() - 2); + bool has_op_mask = primals.size() > 3; + for (auto arg : argnums) { + if (arg == 0) { + // M X N * (K X N).T -> M X K + auto b_t = transpose(primals[1], reorder, stream()); + auto out_mask = primals[2]; + auto lhs_mask = + has_op_mask ? std::make_optional(primals[3]) : std::nullopt; + auto rhs_mask_t = has_op_mask + ? std::make_optional(transpose(primals[4], reorder, stream())) + : std::nullopt; + + auto grad = block_masked_mm( + cotan, b_t, block_size_, lhs_mask, out_mask, rhs_mask_t, stream()); + + vjps.push_back(grad); + + } else if (arg == 1) { + // (M X K).T * M X N -> K X N + auto a_t = transpose(primals[0], reorder, stream()); + auto out_mask = primals[2]; + auto lhs_mask_t = has_op_mask + ? std::make_optional(transpose(primals[3], reorder, stream())) + : std::nullopt; + auto rhs_mask = + has_op_mask ? 
std::make_optional(primals[4]) : std::nullopt; + + auto grad = block_masked_mm( + a_t, cotan, block_size_, rhs_mask, lhs_mask_t, out_mask, stream()); + + vjps.push_back(grad); + } else { + vjps.push_back(zeros_like(primals[arg], stream())); + } + } + return vjps; +} + +bool BlockMaskedMM::is_equivalent(const Primitive& other) const { + const BlockMaskedMM& a_other = static_cast(other); + return (block_size_ == a_other.block_size_); +} + std::vector Transpose::vjp( const std::vector& primals, const std::vector& cotangents, diff --git a/mlx/primitives.h b/mlx/primitives.h index ff2f01bc1..e90564724 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -443,6 +443,29 @@ class AsStrided : public UnaryPrimitive { void eval(const std::vector& inputs, array& out); }; +class BlockMaskedMM : public UnaryPrimitive { + public: + explicit BlockMaskedMM(Stream stream, int block_size) + : UnaryPrimitive(stream), block_size_(block_size){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + DEFINE_PRINT(BlockMaskedMM) + bool is_equivalent(const Primitive& other) const override; + + private: + int block_size_; + + void eval(const std::vector& inputs, array& out); +}; + class Broadcast : public UnaryPrimitive { public: explicit Broadcast(Stream stream, const std::vector& shape) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index a4ca29d6e..6a36faff6 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3645,6 +3645,44 @@ void init_ops(nb::module_& m) { Returns: array: ``alpha * (a @ b) + beta * c`` )pbdoc"); + m.def( + "block_masked_mm", + &block_masked_mm, + nb::arg(), + nb::arg(), + "block_size"_a = 64, + "mask_out"_a = nb::none(), + "mask_lhs"_a = nb::none(), + "mask_rhs"_a = nb::none(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def block_masked_mm(a: array, b: array, /, block_size: int = 64, mask_out: array, mask_lhs: array, mask_rhs: array, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Matrix multiplication with block masking. + + Perform the (possibly batched) matrix multiplication of two arrays and with blocks + of size ``block_size x block_size`` optionally masked out. + + Assuming ``a`` with shape (..., `M`, `K`) and b with shape (..., `K`, `N`) + + * ``lhs_mask`` must have shape (..., :math:`\lceil` `M` / ``block_size`` :math:`\rceil`, :math:`\lceil` `K` / ``block_size`` :math:`\rceil`) + + * ``rhs_mask`` must have shape (..., :math:`\lceil` `K` / ``block_size`` :math:`\rceil`, :math:`\lceil` `N` / ``block_size`` :math:`\rceil`) + + * ``out_mask`` must have shape (..., :math:`\lceil` `M` / ``block_size`` :math:`\rceil`, :math:`\lceil` `N` / ``block_size`` :math:`\rceil`) + + Note: Only ``block_size=64`` and ``block_size=32`` are currently supported + + Args: + a (array): Input array or scalar. + b (array): Input array or scalar. + block_size (int): Size of blocks to be masked. 
Must be ``32`` or ``64`` (default: ``64``) + mask_out (array, optional): Boolean mask for output (default: ``None``) + mask_lhs (array, optional): Boolean mask for a (default: ``None``) + mask_rhs (array, optional): Boolean mask for b (default: ``None``) + + )pbdoc"); m.def( "diagonal", &diagonal, diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index 0d3417dc1..ef435e51e 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -1,4 +1,4 @@ -# Copyright © 2023 Apple Inc. +# Copyright © 2023-2024 Apple Inc. import math import unittest @@ -681,6 +681,119 @@ class TestBlas(mlx_tests.MLXTestCase): mx.eval(c) self.assertEqual(c.shape, (0, 0)) + def test_block_masked_matmul(self): + def np_block_masked_mm( + a, b, block_size, out_mask=None, lhs_mask=None, rhs_mask=None + ): + # Get mask adjusted shapes + M = a.shape[-2] + N = b.shape[-1] + K = a.shape[-1] + + # Expand mask dims + def expand_mask(mask, block_size, Y, X): + mask = np.expand_dims(mask, (-3, -1)) + mask_shape = list(mask.shape) + mask_shape[-1] = block_size + x = mask_shape[-2] * block_size + mask_shape[-3] = block_size + y = mask_shape[-4] * block_size + mask = np.broadcast_to(mask, mask_shape) + mask_shape = mask_shape[:-4] + [y, x] + return mask.reshape(mask_shape)[..., :Y, :X] + + if lhs_mask is not None: + lhs_mask = expand_mask(lhs_mask, block_size, M, K) + a = lhs_mask * a + + if rhs_mask is not None: + rhs_mask = expand_mask(rhs_mask, block_size, K, N) + b = rhs_mask * b + + out = a @ b + + if out_mask is not None: + out_mask = expand_mask(out_mask, block_size, M, N) + out = out * out_mask + return out + + def test_shape(M, N, K, block_size, transpose=False, np_dtype=np.float32): + with self.subTest( + M=M, + N=N, + K=K, + block_size=block_size, + np_dtype=np_dtype, + transpose=transpose, + ): + tm = (M + block_size - 1) // block_size + tn = (N + block_size - 1) // block_size + tk = (K + block_size - 1) // block_size + + a_np = np.random.normal(size=(M, K)).astype(np_dtype) + b_np = np.random.normal(size=(K, N)).astype(np_dtype) + + a_np_mask = np.random.normal(size=(tm, tk)) < 0.0 + b_np_mask = np.random.normal(size=(tk, tn)) < 0.0 + out_np_mask = np.random.normal(size=(tm, tn)) < 0.0 + + a_mx, b_mx, a_mx_mask, b_mx_mask, out_mx_mask = map( + mx.array, (a_np, b_np, a_np_mask, b_np_mask, out_np_mask) + ) + + if transpose: + b_np = np.random.normal(size=(N, K)).astype(np_dtype) + b_mx = mx.array(b_np) + + b_np = b_np.T + b_mx = b_mx.T + + out_np = np_block_masked_mm( + a_np, b_np, block_size, out_np_mask, a_np_mask, b_np_mask + ) + out_mx = mx.block_masked_mm( + a_mx, b_mx, block_size, out_mx_mask, a_mx_mask, b_mx_mask + ) + self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5)) + + out_np = np_block_masked_mm(a_np, b_np, block_size, out_np_mask) + out_mx = mx.block_masked_mm(a_mx, b_mx, block_size, out_mx_mask) + self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5)) + + out_np = np_block_masked_mm( + a_np, b_np, block_size, None, a_np_mask, b_np_mask + ) + out_mx = mx.block_masked_mm( + a_mx, b_mx, block_size, None, a_mx_mask, b_mx_mask + ) + self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5)) + + shapes = ( + (16, 16, 16, 32), + (64, 64, 16, 32), + (128, 128, 128, 32), + (256, 256, 128, 64), + ) + + for M, N, K, block_size in shapes: + test_shape(M, N, K, block_size, transpose=False) + test_shape(M, N, K, block_size, transpose=True) + + # Test gemv + a_np = np.random.normal(size=(64, 64)).astype(np.float32) + b_np = np.random.normal(size=(64,)).astype(np.float32) + mask_np = 
np.array([True, False]).astype(np.bool_)
+
+        a_mx = mx.array(a_np)
+        b_mx = mx.array(b_np)
+        mask_mx = mx.array(mask_np)
+
+        c_mx = mx.block_masked_mm(a_mx, b_mx, 32, mask_mx)
+        c_np = a_np @ b_np
+        c_np[32:] = 0.0
+
+        self.assertTrue(np.allclose(c_mx, c_np, atol=1e-5))
+

 if __name__ == "__main__":
     unittest.main()
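A minimal usage sketch of the new op, based on the binding and tests above (mask shapes per the docstring: one boolean per block_size x block_size tile; the random masks here are only illustrative):

import mlx.core as mx

M, N, K, bs = 128, 128, 128, 32
a = mx.random.normal((M, K))
b = mx.random.normal((K, N))

# One boolean per tile: ceil(dim / bs) entries along each axis
lhs_mask = mx.random.uniform(shape=(M // bs, K // bs)) < 0.5
rhs_mask = mx.random.uniform(shape=(K // bs, N // bs)) < 0.5
out_mask = mx.random.uniform(shape=(M // bs, N // bs)) < 0.5

# Blocks of a / b whose mask entry is False are treated as zero, and
# output blocks whose out_mask entry is False are written as zero.
out = mx.block_masked_mm(a, b, bs, out_mask, lhs_mask, rhs_mask)
mx.eval(out)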