mirror of https://github.com/ml-explore/mlx.git
parent 107ba2891a
commit b18468bf81
@ -29,6 +29,7 @@ Operations
   atleast_2d
   atleast_3d
   broadcast_to
   block_masked_mm
   ceil
   clip
   concatenate
@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.

#include <cassert>
@ -196,6 +196,40 @@ inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
  return matmul_bnns_general(a_pre, b_pre, out);
}

template <typename T>
inline void mask_matrix(
    T* data,
    const bool* mask,
    int tile_size,
    const int X,
    const int Y,
    const size_t X_data_str,
    const size_t Y_data_str,
    const size_t X_mask_str,
    const size_t Y_mask_str) {
  int tX = (X + tile_size - 1) / tile_size;
  int tY = (Y + tile_size - 1) / tile_size;

  for (int i = 0; i < tX; i++) {
    for (int j = 0; j < tY; j++) {
      bool do_mask = mask[i * X_mask_str + j * Y_mask_str];
      if (!do_mask) {
        int loc_x = i * tile_size;
        int loc_y = j * tile_size;
        T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;

        int size_x = std::min(tile_size, X - loc_x);
        int size_y = std::min(tile_size, Y - loc_y);
        for (int ii = 0; ii < size_x; ii++) {
          for (int jj = 0; jj < size_y; jj++) {
            data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
          }
        }
      }
    }
  }
}

} // namespace
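For reference, a hedged NumPy sketch of what mask_matrix does (illustrative names, not MLX code): every tile whose mask entry is false is zeroed, and trailing tiles that hang off the edge of the matrix are clipped to the remaining rows/columns, mirroring the size_x/size_y bounds above.

import numpy as np

def mask_matrix_ref(data, mask, tile_size):
    # data: (X, Y) array, mask: (ceil(X/tile_size), ceil(Y/tile_size)) bools.
    X, Y = data.shape
    for i in range((X + tile_size - 1) // tile_size):
        for j in range((Y + tile_size - 1) // tile_size):
            if not mask[i, j]:
                # Slicing clips the last, possibly partial, tile automatically.
                data[i * tile_size:(i + 1) * tile_size,
                     j * tile_size:(j + 1) * tile_size] = 0.0
    return data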

void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
@ -31,6 +31,7 @@ DEFAULT(ArgPartition)
DEFAULT(ArgReduce)
DEFAULT(ArgSort)
DEFAULT(AsStrided)
DEFAULT(BlockMaskedMM)
DEFAULT(Broadcast)
DEFAULT(Ceil)
DEFAULT(Concatenate)
@ -41,6 +41,7 @@ target_sources(
  ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
@ -41,6 +41,7 @@ DEFAULT(ArgSort)
DEFAULT(AsType)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(BlockMaskedMM)
DEFAULT_MULTI(DivMod)
DEFAULT(Ceil)
DEFAULT(Concatenate)
mlx/backend/common/masked_mm.cpp (new file, 193 lines)
@ -0,0 +1,193 @@
// Copyright © 2024 Apple Inc.

#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#endif

#include <cstring>

#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"

namespace mlx::core {

namespace {

template <typename T>
inline void mask_matrix(
    T* data,
    const bool* mask,
    int block_size,
    const int X,
    const int Y,
    const size_t X_data_str,
    const size_t Y_data_str,
    const size_t X_mask_str,
    const size_t Y_mask_str) {
  int tX = (X + block_size - 1) / block_size;
  int tY = (Y + block_size - 1) / block_size;

  for (int i = 0; i < tX; i++) {
    for (int j = 0; j < tY; j++) {
      bool do_mask = mask[i * X_mask_str + j * Y_mask_str];
      if (!do_mask) {
        int loc_x = i * block_size;
        int loc_y = j * block_size;
        T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;

        int size_x = std::min(block_size, X - loc_x);
        int size_y = std::min(block_size, Y - loc_y);
        for (int ii = 0; ii < size_x; ii++) {
          for (int jj = 0; jj < size_y; jj++) {
            data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
          }
        }
      }
    }
  }
}

} // namespace

void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
  if (out.dtype() != float32) {
    throw std::runtime_error(
        "[BlockMaskedMM::eval] Currently only supports float32.");
  }
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto& a_pre = inputs[0];
  auto& b_pre = inputs[1];
  auto& out_mask = inputs[2];

  auto check_transpose = [](const array& arr, bool do_copy) {
    auto stx = arr.strides()[arr.ndim() - 2];
    auto sty = arr.strides()[arr.ndim() - 1];
    if (stx == arr.shape(-1) && sty == 1) {
      if (do_copy) {
        array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
        copy(arr, arr_copy, CopyType::Vector);
        return std::make_tuple(false, stx, arr_copy);
      }
      return std::make_tuple(false, stx, arr);
    } else if (stx == 1 && sty == arr.shape(-2)) {
      if (do_copy) {
        array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
        copy(arr, arr_copy, CopyType::Vector);
        return std::make_tuple(true, sty, arr_copy);
      }
      return std::make_tuple(true, sty, arr);
    } else {
      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
      copy(arr, arr_copy, CopyType::General);
      size_t stx = arr.shape(-1);
      return std::make_tuple(false, stx, arr_copy);
    }
  };

  bool has_op_mask = inputs.size() > 3;
  auto [a_transposed, lda, a] = check_transpose(a_pre, has_op_mask);
  auto [b_transposed, ldb, b] = check_transpose(b_pre, has_op_mask);

  size_t M = a.shape(-2);
  size_t N = b.shape(-1);
  size_t K = a.shape(-1);

  if (M == 0 || N == 0) {
    return;
  }

  if (K == 0) {
    std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
    return;
  }

  auto mask_array = [](const array& mask,
                       float* data,
                       int block_size,
                       int batch_idx,
                       int X,
                       int Y,
                       size_t X_data_str,
                       size_t Y_data_str) {
    const bool* mask_ptr = mask.data<bool>() +
        elem_to_loc(mask.shape(-1) * mask.shape(-2) * batch_idx,
                    mask.shape(),
                    mask.strides());

    size_t X_mask_str = mask.strides()[mask.ndim() - 2];
    size_t Y_mask_str = mask.strides()[mask.ndim() - 1];

    return mask_matrix(
        data,
        mask_ptr,
        block_size,
        X,
        Y,
        X_data_str,
        Y_data_str,
        X_mask_str,
        Y_mask_str);
  };

  for (int i = 0; i < (a.size() / (M * K)); ++i) {
    // Adjust pointer
    float* ai =
        a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides());
    float* bi =
        b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides());
    float* ci = out.data<float>() + M * N * i;

    // Zero out blocks in a and b if needed
    if (has_op_mask) {
      auto& a_mask = inputs[3];
      mask_array(
          a_mask,
          ai,
          block_size_,
          i,
          M,
          K,
          a_transposed ? 1 : lda,
          a_transposed ? lda : 1);

      auto& b_mask = inputs[4];
      mask_array(
          b_mask,
          bi,
          block_size_,
          i,
          K,
          N,
          b_transposed ? 1 : ldb,
          b_transposed ? ldb : 1);
    }

    // Do matmul
    cblas_sgemm(
        CblasRowMajor,
        a_transposed ? CblasTrans : CblasNoTrans, // transA
        b_transposed ? CblasTrans : CblasNoTrans, // transB
        M,
        N,
        K,
        1.0, // alpha
        ai,
        lda,
        bi,
        ldb,
        0.0, // beta
        ci,
        out.shape(-1) // ldc
    );

    // Zero out blocks in out
    mask_array(out_mask, ci, block_size_, i, M, N, N, 1);
  }
}

} // namespace mlx::core
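End to end, the CPU path above is equivalent, up to strides and batching, to the following hedged NumPy reference (illustrative names, unbatched case only): expand each block-level mask to element level, zero the masked operand blocks, run a plain GEMM, then zero the masked output blocks.

import numpy as np

def block_masked_mm_ref(a, b, block_size, out_mask, lhs_mask=None, rhs_mask=None):
    # Expand a block-level boolean mask to element level and crop to (rows, cols).
    def expand(mask, rows, cols):
        full = np.kron(mask, np.ones((block_size, block_size), dtype=bool))
        return full[:rows, :cols]

    M, K = a.shape
    N = b.shape[-1]
    if lhs_mask is not None:
        a = a * expand(lhs_mask, M, K)
    if rhs_mask is not None:
        b = b * expand(rhs_mask, K, N)
    out = a @ b
    return out * expand(out_mask, M, N)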
@ -0,0 +1,323 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"

using namespace metal;
using namespace mlx::steel;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

template <typename T,
          int BM,
          int BN,
          int BK,
          int WM,
          int WN,
          bool transpose_a,
          bool transpose_b,
          bool MN_aligned,
          bool K_aligned,
          bool has_operand_mask = false>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void block_masked_gemm(
    const device T *A [[buffer(0)]],
    const device T *B [[buffer(1)]],
    device T *D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant size_t* batch_strides [[buffer(7)]],
    const device bool *out_mask [[buffer(10)]],
    const device bool *lhs_mask [[buffer(11)]],
    const device bool *rhs_mask [[buffer(12)]],
    const constant int* mask_strides [[buffer(13)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  // Appease the compiler
  (void)lid;

  using gemm_kernel = GEMMKernel<T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;

  const int tid_y = ((tid.y) << params->swizzle_log) +
      ((tid.x) & ((1 << params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> params->swizzle_log;

  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }

  if (params->batch_ndim > 1) {
    const constant size_t* mask_batch_strides = batch_strides + 2 * params->batch_ndim;
    out_mask += elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);

    if (has_operand_mask) {
      const constant size_t* mask_strides_lhs = mask_batch_strides + params->batch_ndim;
      const constant size_t* mask_strides_rhs = mask_strides_lhs + params->batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z, batch_shape, mask_strides_lhs, mask_strides_rhs, params->batch_ndim);

      lhs_mask += batch_offsets.x;
      rhs_mask += batch_offsets.y;
    }
  }

  // Adjust for batch
  if (params->batch_ndim > 1) {
    const constant size_t* A_bstrides = batch_strides;
    const constant size_t* B_bstrides = batch_strides + params->batch_ndim;

    ulong2 batch_offsets = elem_to_loc_broadcast(
        tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);

    A += batch_offsets.x;
    B += batch_offsets.y;

  } else {
    A += params->batch_stride_a * tid.z;
    B += params->batch_stride_b * tid.z;
  }

  D += params->batch_stride_d * tid.z;

  // Find block in A, B, C
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;

  A += transpose_a ? c_row : c_row * params->lda;
  B += transpose_b ? c_col * params->ldb : c_col;
  D += c_row * params->ldd + c_col;

  bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]];

  // Write zeros and return
  if (!mask_out) {
    constexpr short tgp_size = WM * WN * 32;
    constexpr short vec_size = 4;

    // Tile threads in threadgroup
    constexpr short TN = BN / vec_size;
    constexpr short TM = tgp_size / TN;

    const short thread_idx = simd_group_id * 32 + simd_lane_id;
    const short bi = thread_idx / TN;
    const short bj = vec_size * (thread_idx % TN);

    D += bi * params->ldd + bj;

    short tgp_bm = min(BM, params->M - c_row);
    short tgp_bn = min(BN, params->N - c_col);

    if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
      for (short ti = 0; ti < BM; ti += TM) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          D[ti * params->ldd + j] = T(0.);
        }
      }
    } else {
      short jmax = tgp_bn - bj;
      jmax = jmax < vec_size ? jmax : vec_size;
      for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
        for (short j = 0; j < jmax; j++) {
          D[ti * params->ldd + j] = T(0.);
        }
      }
    }

    return;
  }

  threadgroup_barrier(mem_flags::mem_none);

  // Prepare threadgroup mma operation
  thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);

  int gemm_k_iterations = params->gemm_k_iterations_aligned;

  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  // Prepare threadgroup loading operations
  thread typename gemm_kernel::loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
  thread typename gemm_kernel::loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

  ///////////////////////////////////////////////////////////////////////////////
  // MNK aligned loop
  if (MN_aligned) {
    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      if (!has_operand_mask ||
          (lhs_mask[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
           rhs_mask[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);
      }

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    threadgroup_barrier(mem_flags::mem_none);

    // Loop tail
    if (!K_aligned) {
      if (!has_operand_mask ||
          (lhs_mask[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
           rhs_mask[(params->K / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
        int lbk = params->K - params->gemm_k_iterations_aligned * BK;
        short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
        short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);

        loader_a.load_safe(tile_dims_A);
        loader_b.load_safe(tile_dims_B);

        threadgroup_barrier(mem_flags::mem_threadgroup);

        mma_op.mma(As, Bs);
      }
    }

    // Store results to device memory
    mma_op.store_result(D, params->ldd);
    return;
  }
  ///////////////////////////////////////////////////////////////////////////////
  // MN unaligned loop
  else { // Loop over K - unaligned case
    short tgp_bm = min(BM, params->M - c_row);
    short tgp_bn = min(BN, params->N - c_col);
    short lbk = params->K - params->gemm_k_iterations_aligned * BK;

    bool M_aligned = (tgp_bm == BM);
    bool N_aligned = (tgp_bn == BN);

    short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
    short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);

    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      if (!has_operand_mask ||
          (lhs_mask[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
           rhs_mask[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
        // Load elements into threadgroup
        if (M_aligned) {
          loader_a.load_unsafe();
        } else {
          loader_a.load_safe(tile_dims_A);
        }

        if (N_aligned) {
          loader_b.load_unsafe();
        } else {
          loader_b.load_safe(tile_dims_B);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);
      }

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    if (!K_aligned) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      if (!has_operand_mask ||
          (lhs_mask[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
           rhs_mask[(params->K / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
        short2 tile_dims_A_last =
            transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
        short2 tile_dims_B_last =
            transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);

        loader_a.load_safe(tile_dims_A_last);
        loader_b.load_safe(tile_dims_B_last);

        threadgroup_barrier(mem_flags::mem_threadgroup);

        mma_op.mma(As, Bs);
      }
    }

    if (M_aligned && N_aligned) {
      mma_op.store_result(D, params->ldd);
    } else {
      mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
    }
  }
}

///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////

#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, omname, op_mask) \
  template [[host_name("steel_block_masked_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_op_mask_" #omname)]] \
  [[kernel]] void block_masked_gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned, op_mask>( \
      const device itype *A [[buffer(0)]], \
      const device itype *B [[buffer(1)]], \
      device itype *D [[buffer(3)]], \
      const constant GEMMParams* params [[buffer(4)]], \
      const constant int* batch_shape [[buffer(6)]], \
      const constant size_t* batch_strides [[buffer(7)]], \
      const device bool *out_mask [[buffer(10)]], \
      const device bool *lhs_mask [[buffer(11)]], \
      const device bool *rhs_mask [[buffer(12)]], \
      const constant int* mask_strides [[buffer(13)]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]]);

#define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, N, false) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, T, true)

#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
  instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
  instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
  instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)

#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)

#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2)

instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float);
@ -1064,4 +1064,181 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  return;
}

void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  using namespace mlx::steel;
  // assert(inputs.size() == 2);
  if (!issubdtype(out.dtype(), floating)) {
    throw std::runtime_error(
        "[matmul] Does not yet support non-floating point types.");
  }
  auto& s = stream();
  auto& d = metal::device(s.device);

  auto& a_pre = inputs[0];
  auto& b_pre = inputs[1];
  // Return 0s if either input is empty
  if (a_pre.size() == 0 || b_pre.size() == 0) {
    array zero = array(0, a_pre.dtype());
    copy_gpu(zero, out, CopyType::Scalar, s);
    auto command_buffer = d.get_command_buffer(s.index);
    command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
    return;
  }

  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  /////////////////////////////////////////////////////////////////////////////
  // Init checks and prep

  // Keep a vector with copies to be cleared in the completed buffer to release
  // the arrays
  std::vector<array> copies;
  auto check_transpose = [&copies, &s](const array& arr) {
    auto stx = arr.strides()[arr.ndim() - 2];
    auto sty = arr.strides()[arr.ndim() - 1];
    if (sty == 1) {
      return std::make_tuple(false, stx, arr);
    } else if (stx == 1) {
      return std::make_tuple(true, sty, arr);
    } else {
      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
      copy_gpu(arr, arr_copy, CopyType::General, s);
      copies.push_back(arr_copy);
      size_t stx = arr.shape(-1);
      return std::make_tuple(false, stx, arr_copy);
    }
  };

  auto [transpose_a, a_cols, a] = check_transpose(a_pre);
  auto [transpose_b, b_cols, b] = check_transpose(b_pre);

  int lda = a_cols;
  int ldb = b_cols;

  int M = a.shape(-2);
  int N = b.shape(-1);
  int K = a.shape(-1);

  /////////////////////////////////////////////////////////////////////////////
  // Check and collapse batch dimensions

  auto [batch_shape, A_batch_stride, B_batch_stride] = collapse_batches(a, b);

  auto batch_size_out = out.size() / (M * N);
  int matrix_stride_out = M * N;

  /////////////////////////////////////////////////////////////////////////////
  // Regular kernel dispatch

  // Determine dispatch kernel
  int bm = block_size_, bn = block_size_, bk = 16;
  int wm = 2, wn = 2;

  // Prepare kernel name
  std::ostringstream kname;
  kname << "steel_block_masked_gemm_" << (transpose_a ? 't' : 'n')
        << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
        << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
        << "_wm" << wm << "_wn" << wn << "_MN_"
        << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
        << ((K % bk == 0) ? "t" : "n") << "aligned" << "_op_mask_"
        << (inputs.size() > 3 ? "T" : "N");

  // Encode and dispatch kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());
  compute_encoder->setComputePipelineState(kernel);

  // Use problem size to determine threadblock swizzle
  int tn = (N + bn - 1) / bn;
  int tm = (M + bm - 1) / bm;

  // TODO: Explore device-based tuning for swizzle
  int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);

  // Prepare steel matmul params
  GEMMParams params{
      /* const int M = */ M,
      /* const int N = */ N,
      /* const int K = */ K,
      /* const int lda = */ lda,
      /* const int ldb = */ ldb,
      /* const int ldd = */ N,
      /* const int tiles_n = */ tn,
      /* const int tiles_m = */ tm,
      /* const int batch_stride_a = */ int(A_batch_stride.back()),
      /* const int batch_stride_b = */ int(B_batch_stride.back()),
      /* const int batch_stride_d = */ matrix_stride_out,
      /* const int swizzle_log = */ swizzle_log,
      /* const int gemm_k_iterations_aligned = */ (K / bk),
      /* const int batch_ndim = */ int(batch_shape.size())};

  // Prepare launch grid params
  int tile = 1 << swizzle_log;
  tm = (tm + tile - 1) / tile;
  tn = tn * tile;

  MTL::Size group_dims = MTL::Size(32, wn, wm);
  MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);

  std::vector<size_t> batch_strides = A_batch_stride;
  batch_strides.insert(
      batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());

  std::vector<int> mask_strides;

  auto& out_mask = inputs[2];
  mask_strides.push_back(*(out_mask.strides().end() - 1));
  mask_strides.push_back(*(out_mask.strides().end() - 2));

  batch_strides.insert(
      batch_strides.end(),
      out_mask.strides().begin(),
      out_mask.strides().end() - 2);

  if (inputs.size() > 3) {
    auto& lhs_mask = inputs[3];
    mask_strides.push_back(*(lhs_mask.strides().end() - 1));
    mask_strides.push_back(*(lhs_mask.strides().end() - 2));

    batch_strides.insert(
        batch_strides.end(),
        lhs_mask.strides().begin(),
        lhs_mask.strides().end() - 2);

    compute_encoder.set_input_array(lhs_mask, 11);

    auto& rhs_mask = inputs[4];
    mask_strides.push_back(*(rhs_mask.strides().end() - 1));
    mask_strides.push_back(*(rhs_mask.strides().end() - 2));

    batch_strides.insert(
        batch_strides.end(),
        rhs_mask.strides().begin(),
        rhs_mask.strides().end() - 2);

    compute_encoder.set_input_array(rhs_mask, 12);
  }

  // Launch kernel
  compute_encoder.set_input_array(a, 0);
  compute_encoder.set_input_array(b, 1);
  compute_encoder.set_output_array(out, 3);

  compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);

  set_vector_bytes(compute_encoder, batch_shape, 6);
  set_vector_bytes(compute_encoder, batch_strides, 7);

  compute_encoder.set_input_array(out_mask, 10);
  set_vector_bytes(compute_encoder, mask_strides, 13);

  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);

  // Clear copies
  d.get_command_buffer(s.index)->addCompletedHandler(
      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
  return;
}

} // namespace mlx::core
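One detail worth spelling out: the kernel reads mask_strides[0..5], and the host packs them per mask with the innermost (column) stride first. A hedged Python sketch of that packing order; the helper name is hypothetical and not part of MLX:

def pack_mask_strides(out_strides, lhs_strides=None, rhs_strides=None):
    # [0], [1]: out_mask column/row strides, indexed by (tid_x, tid_y) in the kernel.
    packed = [out_strides[-1], out_strides[-2]]
    if lhs_strides is not None and rhs_strides is not None:
        # [2], [3]: lhs_mask strides, indexed by (k-block, tid_y).
        packed += [lhs_strides[-1], lhs_strides[-2]]
        # [4], [5]: rhs_mask strides, indexed by (tid_x, k-block).
        packed += [rhs_strides[-1], rhs_strides[-2]]
    return packed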
@ -99,6 +99,7 @@ NO_GPU(Subtract)
NO_GPU_MULTI(SVD)
NO_GPU(Tan)
NO_GPU(Tanh)
NO_GPU(BlockMaskedMM)
NO_GPU(Transpose)
NO_GPU(Inverse)
mlx/ops.cpp (+166 lines)
@ -3572,6 +3572,172 @@ array addmm(
  return out;
}

/** Compute matrix product with tile-level masking */
array block_masked_mm(
    array a,
    array b,
    int block_size,
    std::optional<array> mask_out /* = std::nullopt */,
    std::optional<array> mask_lhs /* = std::nullopt */,
    std::optional<array> mask_rhs /* = std::nullopt */,
    StreamOrDevice s /* = {} */) {
  // If no masks, just perform regular matmul
  if (!mask_out && !mask_lhs && !mask_rhs) {
    return matmul(a, b, s);
  }

  bool has_out_mask = mask_out.has_value();
  bool has_operand_mask = mask_lhs.has_value() || mask_rhs.has_value();

  // Check valid tile sizes
  // TODO: Add support for 16x16 tile
  if (block_size != 32 && block_size != 64) {
    std::ostringstream msg;
    msg << "[block_masked_mm] Only block_sizes 32, 64 are supported. "
        << "Got block size " << block_size << ".";
    throw std::invalid_argument(msg.str());
  }

  // Do shape checks for operands
  int in_a_ndim = a.ndim();
  int in_b_ndim = b.ndim();

  if (a.ndim() == 0 || b.ndim() == 0) {
    throw std::invalid_argument(
        "[block_masked_mm] Got 0 dimension input. Inputs must "
        "have at least one dimension.");
  }

  if (a.ndim() == 1) {
    // Insert a singleton dim in the beginning
    a = reshape(a, {1, -1}, s);
  }
  if (b.ndim() == 1) {
    // Insert a singleton dim at the end
    b = reshape(b, {-1, 1}, s);
  }

  if (a.shape(-1) != b.shape(-2)) {
    std::ostringstream msg;
    msg << "[block_masked_mm] Last dimension of first input with shape "
        << a.shape() << " must match second to last dimension of"
        << " second input with shape " << b.shape() << ".";
    throw std::invalid_argument(msg.str());
  }

  // Type promotion
  auto out_type = result_type(a, b);
  if (!issubdtype(out_type, floating)) {
    std::ostringstream msg;
    msg << "[block_masked_mm] Only real floating point types are supported but "
        << a.dtype() << " and " << b.dtype()
        << " were provided which results in " << out_type
        << ", which is not a real floating point type.";
    throw std::invalid_argument(msg.str());
  }

  a = astype(a, out_type, s);
  b = astype(b, out_type, s);

  // Handle broadcasting
  std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2);
  std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2);

  auto bsx_shape = broadcast_shapes(bsx_a, bsx_b);

  bsx_shape.push_back(1);
  bsx_shape.push_back(1);
  int nd = bsx_shape.size();

  int M = a.shape(-2);
  int N = b.shape(-1);
  int K = a.shape(-1);

  // Prepare A
  bsx_shape[nd - 2] = M;
  bsx_shape[nd - 1] = K;
  a = broadcast_to(a, bsx_shape, s);

  // Prepare B
  bsx_shape[nd - 2] = K;
  bsx_shape[nd - 1] = N;
  b = broadcast_to(b, bsx_shape, s);

  // Get output shape
  auto out_shape = bsx_shape;
  out_shape[nd - 2] = M;
  out_shape[nd - 1] = N;

  // Determine mask shape requirements
  int tm = (M + block_size - 1) / block_size;
  int tn = (N + block_size - 1) / block_size;
  int tk = (K + block_size - 1) / block_size;

  // Broadcast and astype mask
  auto broadcast_mask = [](array mask,
                           std::vector<int>& bs_shape,
                           int y,
                           int x,
                           StreamOrDevice s) {
    int nd_bsx = bs_shape.size();
    bs_shape[nd_bsx - 2] = y;
    bs_shape[nd_bsx - 1] = x;
    mask = astype(mask, bool_, s);
    return broadcast_to(mask, bs_shape, s);
  };

  // Out mask
  array mask_out_p = mask_out.value_or(array({true}));
  if (in_a_ndim == 1 || in_b_ndim == 1) {
    std::vector<int> ex_dims;
    if (in_a_ndim == 1)
      ex_dims.push_back(-2);
    if (in_b_ndim == 1)
      ex_dims.push_back(-1);
    mask_out_p = expand_dims(mask_out_p, ex_dims, s);
  }
  mask_out_p = broadcast_mask(mask_out_p, bsx_shape, tm, tn, s);

  std::vector<array> inputs = {a, b, mask_out_p};

  // Operand masks
  if (has_operand_mask) {
    // LHS mask
    array mask_lhs_p = mask_lhs.value_or(array({true}));
    if (in_a_ndim == 1) {
      mask_lhs_p = expand_dims(mask_lhs_p, -2, s);
    }
    mask_lhs_p = broadcast_mask(mask_lhs_p, bsx_shape, tm, tk, s);

    // RHS mask
    array mask_rhs_p = mask_rhs.value_or(array({true}));
    if (in_b_ndim == 1) {
      mask_rhs_p = expand_dims(mask_rhs_p, -1, s);
    }
    mask_rhs_p = broadcast_mask(mask_rhs_p, bsx_shape, tk, tn, s);

    inputs.push_back(mask_lhs_p);
    inputs.push_back(mask_rhs_p);
  }

  // Calculate array
  auto out = array(
      out_shape,
      out_type,
      std::make_shared<BlockMaskedMM>(to_stream(s), block_size),
      std::move(inputs));

  // Remove the possibly inserted singleton dimensions
  if (in_a_ndim == 1 || in_b_ndim == 1) {
    out_shape.erase(
        out_shape.end() - ((in_a_ndim == 1) ? 2 : 1),
        out_shape.end() - ((in_b_ndim == 1) ? 0 : 1));
    out = reshape(out, out_shape, s);
  }

  return out;
}
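The tm/tn/tk ceil divisions above define the mask shapes callers must provide. A small illustrative Python helper (not part of MLX) that computes them:

import math

def required_mask_shapes(M, N, K, block_size):
    tm = math.ceil(M / block_size)
    tn = math.ceil(N / block_size)
    tk = math.ceil(K / block_size)
    return {"mask_lhs": (tm, tk), "mask_rhs": (tk, tn), "mask_out": (tm, tn)}

print(required_mask_shapes(100, 50, 128, 32))
# {'mask_lhs': (4, 4), 'mask_rhs': (4, 2), 'mask_out': (4, 2)}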

array diagonal(
    const array& a,
    int offset /* = 0 */,
mlx/ops.h (+10 lines)
@ -1185,6 +1185,16 @@ array addmm(
    const float& beta = 1.f,
    StreamOrDevice s = {});

/** Compute matrix product with block masking */
array block_masked_mm(
    array a,
    array b,
    int block_size,
    std::optional<array> mask_out = std::nullopt,
    std::optional<array> mask_lhs = std::nullopt,
    std::optional<array> mask_rhs = std::nullopt,
    StreamOrDevice s = {});

/** Extract a diagonal or construct a diagonal array */
array diagonal(
    const array& a,
@ -3272,6 +3272,59 @@ std::pair<std::vector<array>, std::vector<int>> Tanh::vmap(
  return {{tanh(inputs[0], stream())}, axes};
}

std::vector<array> BlockMaskedMM::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>&) {
  std::vector<array> vjps;
  auto& cotan = cotangents[0];
  std::vector<int> reorder(cotan.ndim());
  std::iota(reorder.begin(), reorder.end(), 0);
  std::iter_swap(reorder.end() - 1, reorder.end() - 2);
  bool has_op_mask = primals.size() > 3;
  for (auto arg : argnums) {
    if (arg == 0) {
      // M X N * (K X N).T -> M X K
      auto b_t = transpose(primals[1], reorder, stream());
      auto out_mask = primals[2];
      auto lhs_mask =
          has_op_mask ? std::make_optional<array>(primals[3]) : std::nullopt;
      auto rhs_mask_t = has_op_mask
          ? std::make_optional<array>(transpose(primals[4], reorder, stream()))
          : std::nullopt;

      auto grad = block_masked_mm(
          cotan, b_t, block_size_, lhs_mask, out_mask, rhs_mask_t, stream());

      vjps.push_back(grad);

    } else if (arg == 1) {
      // (M X K).T * M X N -> K X N
      auto a_t = transpose(primals[0], reorder, stream());
      auto out_mask = primals[2];
      auto lhs_mask_t = has_op_mask
          ? std::make_optional<array>(transpose(primals[3], reorder, stream()))
          : std::nullopt;
      auto rhs_mask =
          has_op_mask ? std::make_optional<array>(primals[4]) : std::nullopt;

      auto grad = block_masked_mm(
          a_t, cotan, block_size_, rhs_mask, lhs_mask_t, out_mask, stream());

      vjps.push_back(grad);
    } else {
      vjps.push_back(zeros_like(primals[arg], stream()));
    }
  }
  return vjps;
}
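In words: the cotangent with respect to ``a`` is itself a block-masked matmul in which the forward output mask gates the incoming cotangent (it is passed as the lhs operand mask) while the forward lhs mask becomes the output mask, and the rhs mask is transposed along with ``b``. A hedged sketch of taking such a gradient through the Python binding added in this change (assumes an mlx.core build that includes this op):

import mlx.core as mx

a = mx.random.normal((64, 64))
b = mx.random.normal((64, 64))
out_mask = mx.array([[True, False], [False, True]])  # 2x2 tiles for block_size=32

def masked_loss(a, b):
    return mx.block_masked_mm(a, b, 32, out_mask).sum()

grad_a = mx.grad(masked_loss)(a, b)  # routed through BlockMaskedMM::vjp above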

bool BlockMaskedMM::is_equivalent(const Primitive& other) const {
  const BlockMaskedMM& a_other = static_cast<const BlockMaskedMM&>(other);
  return (block_size_ == a_other.block_size_);
}

std::vector<array> Transpose::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
@ -443,6 +443,29 @@ class AsStrided : public UnaryPrimitive {
  void eval(const std::vector<array>& inputs, array& out);
};

class BlockMaskedMM : public UnaryPrimitive {
 public:
  explicit BlockMaskedMM(Stream stream, int block_size)
      : UnaryPrimitive(stream), block_size_(block_size){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::vector<array> vjp(
      const std::vector<array>& primals,
      const std::vector<array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<array>& outputs) override;

  DEFINE_PRINT(BlockMaskedMM)
  bool is_equivalent(const Primitive& other) const override;

 private:
  int block_size_;

  void eval(const std::vector<array>& inputs, array& out);
};

class Broadcast : public UnaryPrimitive {
 public:
  explicit Broadcast(Stream stream, const std::vector<int>& shape)
@ -3645,6 +3645,44 @@ void init_ops(nb::module_& m) {
        Returns:
            array: ``alpha * (a @ b) + beta * c``
      )pbdoc");
  m.def(
      "block_masked_mm",
      &block_masked_mm,
      nb::arg(),
      nb::arg(),
      "block_size"_a = 64,
      "mask_out"_a = nb::none(),
      "mask_lhs"_a = nb::none(),
      "mask_rhs"_a = nb::none(),
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
          "def block_masked_mm(a: array, b: array, /, block_size: int = 64, mask_out: array = None, mask_lhs: array = None, mask_rhs: array = None, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Matrix multiplication with block masking.

        Perform the (possibly batched) matrix multiplication of two arrays, with blocks
        of size ``block_size x block_size`` optionally masked out.

        Assuming ``a`` with shape (..., `M`, `K`) and ``b`` with shape (..., `K`, `N`):

        * ``mask_lhs`` must have shape (..., :math:`\lceil` `M` / ``block_size`` :math:`\rceil`, :math:`\lceil` `K` / ``block_size`` :math:`\rceil`)

        * ``mask_rhs`` must have shape (..., :math:`\lceil` `K` / ``block_size`` :math:`\rceil`, :math:`\lceil` `N` / ``block_size`` :math:`\rceil`)

        * ``mask_out`` must have shape (..., :math:`\lceil` `M` / ``block_size`` :math:`\rceil`, :math:`\lceil` `N` / ``block_size`` :math:`\rceil`)

        Note: Only ``block_size=64`` and ``block_size=32`` are currently supported.

        Args:
            a (array): Input array or scalar.
            b (array): Input array or scalar.
            block_size (int): Size of blocks to be masked. Must be ``32`` or ``64`` (default: ``64``).
            mask_out (array, optional): Boolean mask for output (default: ``None``).
            mask_lhs (array, optional): Boolean mask for ``a`` (default: ``None``).
            mask_rhs (array, optional): Boolean mask for ``b`` (default: ``None``).
      )pbdoc");
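A minimal usage sketch of the binding registered above (assumes an MLX build that includes this op); the mask shapes follow the ceil rules from the docstring:

import mlx.core as mx

a = mx.random.normal((256, 512))
b = mx.random.normal((512, 128))

# block_size=64: ceil(256/64)=4, ceil(512/64)=8, ceil(128/64)=2
mask_lhs = mx.random.uniform(shape=(4, 8)) > 0.5
mask_rhs = mx.random.uniform(shape=(8, 2)) > 0.5
mask_out = mx.random.uniform(shape=(4, 2)) > 0.5

c = mx.block_masked_mm(a, b, 64, mask_out, mask_lhs, mask_rhs)
print(c.shape)  # (256, 128)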
  m.def(
      "diagonal",
      &diagonal,
@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.

import math
import unittest
@ -681,6 +681,119 @@ class TestBlas(mlx_tests.MLXTestCase):
            mx.eval(c)
            self.assertEqual(c.shape, (0, 0))

    def test_block_masked_matmul(self):
        def np_block_masked_mm(
            a, b, block_size, out_mask=None, lhs_mask=None, rhs_mask=None
        ):
            # Get mask adjusted shapes
            M = a.shape[-2]
            N = b.shape[-1]
            K = a.shape[-1]

            # Expand mask dims
            def expand_mask(mask, block_size, Y, X):
                mask = np.expand_dims(mask, (-3, -1))
                mask_shape = list(mask.shape)
                mask_shape[-1] = block_size
                x = mask_shape[-2] * block_size
                mask_shape[-3] = block_size
                y = mask_shape[-4] * block_size
                mask = np.broadcast_to(mask, mask_shape)
                mask_shape = mask_shape[:-4] + [y, x]
                return mask.reshape(mask_shape)[..., :Y, :X]

            if lhs_mask is not None:
                lhs_mask = expand_mask(lhs_mask, block_size, M, K)
                a = lhs_mask * a

            if rhs_mask is not None:
                rhs_mask = expand_mask(rhs_mask, block_size, K, N)
                b = rhs_mask * b

            out = a @ b

            if out_mask is not None:
                out_mask = expand_mask(out_mask, block_size, M, N)
                out = out * out_mask
            return out

        def test_shape(M, N, K, block_size, transpose=False, np_dtype=np.float32):
            with self.subTest(
                M=M,
                N=N,
                K=K,
                block_size=block_size,
                np_dtype=np_dtype,
                transpose=transpose,
            ):
                tm = (M + block_size - 1) // block_size
                tn = (N + block_size - 1) // block_size
                tk = (K + block_size - 1) // block_size

                a_np = np.random.normal(size=(M, K)).astype(np_dtype)
                b_np = np.random.normal(size=(K, N)).astype(np_dtype)

                a_np_mask = np.random.normal(size=(tm, tk)) < 0.0
                b_np_mask = np.random.normal(size=(tk, tn)) < 0.0
                out_np_mask = np.random.normal(size=(tm, tn)) < 0.0

                a_mx, b_mx, a_mx_mask, b_mx_mask, out_mx_mask = map(
                    mx.array, (a_np, b_np, a_np_mask, b_np_mask, out_np_mask)
                )

                if transpose:
                    b_np = np.random.normal(size=(N, K)).astype(np_dtype)
                    b_mx = mx.array(b_np)

                    b_np = b_np.T
                    b_mx = b_mx.T

                out_np = np_block_masked_mm(
                    a_np, b_np, block_size, out_np_mask, a_np_mask, b_np_mask
                )
                out_mx = mx.block_masked_mm(
                    a_mx, b_mx, block_size, out_mx_mask, a_mx_mask, b_mx_mask
                )
                self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))

                out_np = np_block_masked_mm(a_np, b_np, block_size, out_np_mask)
                out_mx = mx.block_masked_mm(a_mx, b_mx, block_size, out_mx_mask)
                self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))

                out_np = np_block_masked_mm(
                    a_np, b_np, block_size, None, a_np_mask, b_np_mask
                )
                out_mx = mx.block_masked_mm(
                    a_mx, b_mx, block_size, None, a_mx_mask, b_mx_mask
                )
                self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))

        shapes = (
            (16, 16, 16, 32),
            (64, 64, 16, 32),
            (128, 128, 128, 32),
            (256, 256, 128, 64),
        )

        for M, N, K, block_size in shapes:
            test_shape(M, N, K, block_size, transpose=False)
            test_shape(M, N, K, block_size, transpose=True)

        # Test gemv
        a_np = np.random.normal(size=(64, 64)).astype(np.float32)
        b_np = np.random.normal(size=(64,)).astype(np.float32)
        mask_np = np.array([True, False]).astype(np.bool_)

        a_mx = mx.array(a_np)
        b_mx = mx.array(b_np)
        mask_mx = mx.array(mask_np)

        c_mx = mx.block_masked_mm(a_mx, b_mx, 32, mask_mx)
        c_np = a_np @ b_np
        c_np[32:] = 0.0

        self.assertTrue(np.allclose(c_mx, c_np, atol=1e-5))


if __name__ == "__main__":
    unittest.main()