mlx/mlx/backend/metal/kernels/gemv_masked.h

// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/steel/utils.h"
using namespace metal;
#define MLX_MTL_CONST static constant constexpr const
#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
struct _NoMask {
char x;
constexpr METAL_FUNC operator bool() {
return true;
}
constexpr METAL_FUNC operator bool() const threadgroup {
return true;
}
constexpr METAL_FUNC operator bool() const device {
return true;
}
constexpr METAL_FUNC operator bool() const constant {
return true;
}
};
typedef struct _NoMask nomask_t;
template <typename OutT, typename InT = OutT>
struct ScaleOp {
OutT scale;
METAL_FUNC OutT apply(InT x) const {
return static_cast<OutT>(x) * scale;
}
};
template <
typename T,
typename out_mask_t,
typename op_mask_t,
const int BM, /* Threadgroup rows (in simdgroups) */
const int BN, /* Threadgroup cols (in simdgroups) */
const int SM, /* Simdgroup rows (in threads) */
const int SN, /* Simdgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
struct GEMVKernel {
MLX_MTL_CONST int threadsM = BM * SM;
MLX_MTL_CONST int threadsN = BN * SN;
MLX_MTL_CONST int blockM = threadsM * TM;
MLX_MTL_CONST int blockN = threadsN * TN;
static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
static_assert(
SN == 8 || SN == 16 || SN == 32,
"gemv block must have a width of 8, 16, or 32");
static_assert(blockN >= blockM, "Masked gemv must have blockN >= blockM");
MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
MLX_MTL_CONST bool has_mul_operand_mask =
has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
MLX_MTL_CONST bool has_mul_output_mask =
has_output_mask && !metal::is_same_v<out_mask_t, bool>;
// - The matrix of size (M = out_vec_size, K = in_vec_size) is divided
// into blocks of (blockM, blockN) that are distributed among threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each threadgroup has (threadsN, threadsM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM rows
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for
// the block
// 3. At the end, each thread has accumulated results over all blocks across
// the rows. These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated blockM outputs
//
// Edge case handling:
// - The threadgroup with the largest tid has blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results
// remain zero)
// * The last thread that partially overlaps with the matrix is shifted
// inwards such that the thread block fits exactly in the matrix
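//
// A worked example (a hypothetical tile configuration, not one prescribed by
// this header): with BM = 2, BN = 2, SM = 4, SN = 8, TM = 4, TN = 4 the
// threadgroup has threadsM = 8 thread rows and threadsN = 16 thread columns,
// so blockM = 32 and blockN = 64. Each threadgroup then produces 32 output
// rows and marches over in_vec in chunks of 64 elements.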
MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN * (blockM + TM) : 0;
MLX_MTL_CONST bool needs_tgp_reduction = BN > 1;
static METAL_FUNC void
load_unsafe(const device T* src, thread T dst[TN], const int src_offset = 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
dst[tn] = src[src_offset + tn];
}
}
static METAL_FUNC void load_safe(
const device T* src,
thread T dst[TN],
const int src_offset = 0,
const int src_size = TN) {
if (src_offset + TN <= src_size) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
dst[tn] = src[src_offset + tn];
}
} else { // Edgecase
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
dst[tn] = src_offset + tn < src_size ? src[src_offset + tn] : 0;
}
}
}
static METAL_FUNC void run(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& matrix_ld [[buffer(6)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
// Appease compiler
(void)lid;
// Thread local accumulation results
thread T result[TM] = {0};
thread T inter[TN];
thread T v_coeff[TN];
const int thrM = SN != 32 ? simd_lid / SN : 0;
const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
const int sgN = BN != 1 ? (simd_gid % BN) : 0;
const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid);
const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0;
int bm = (simdM + thrM) * TM;
int bn = (simdN + thrN) * TN;
// Block position
int out_row = tid.x * blockM + bm;
// Exit simdgroup if rows out of bound
if (out_row >= out_vec_size)
return;
// Adjust tail simdgroup to ensure in bound reads
out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
// Prepare mask offsets
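// mask_strides packs the 2-d strides of each supplied mask back to back:
// [out_mask strides (2)] [mat_mask strides (2)] [vec_mask strides (2)],
// with a segment omitted when the corresponding mask is nomask_t.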
const constant int* out_mask_strides = mask_strides;
const constant int* mat_mask_strides =
mask_strides + (has_output_mask ? 2 : 0);
const constant int* vec_mask_strides =
mat_mask_strides + (has_operand_mask ? 2 : 0);
const int m_block_idx = blockN > blockM ? out_row / blockN : int(tid.x);
const int out_mask_offset =
!has_output_mask ? 0 : m_block_idx * out_mask_strides[1];
int mat_mask_offset =
!has_operand_mask ? 0 : m_block_idx * mat_mask_strides[1];
int vec_mask_offset = 0;
const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[0];
const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[1];
T out_scale{1};
// Check output mask
if (has_output_mask) {
auto mask_out = out_mask[out_mask_offset];
// Write zeros and return if mask is 0
if (!mask_out) {
if (simdN == 0 && thrN == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
out_vec[out_row + tm] = T(0.);
}
}
return;
}
// Store scalar if multiplicative mask
if (has_mul_output_mask) {
out_scale = T(mask_out);
}
}
// Advance matrix
mat += out_row * matrix_ld;
// Prepare for loop
constexpr const uniform<int> loop_stride = make_uniform(blockN);
const uniform<int> in_size = make_uniform(in_vec_size);
const uniform<int> n_iter = in_size / loop_stride;
const uniform<int> last_iter = loop_stride * n_iter;
const uniform<int> leftover = in_size - last_iter;
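// For example, with in_vec_size = 100 and blockN = 64 (a hypothetical
// configuration): n_iter = 1, last_iter = 64, and leftover = 36, so the
// trailing 36 columns are handled by the load_safe path below.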
// Loop over in_vec in blocks of blockN
for (int i = 0; i < n_iter; ++i) {
if (!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset]))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
load_unsafe(in_vec, v_coeff, bn);
// Apply scale
if (has_mul_operand_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
v_coeff[tn] *= block_scale;
}
}
// Per thread work loop
int mat_offset = 0;
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
// Load for the row
load_unsafe(mat, inter, mat_offset + bn);
// Accumulate results
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn];
}
mat_offset += matrix_ld;
}
}
bn += blockN;
mat_mask_offset += mat_mask_step;
vec_mask_offset += vec_mask_step;
}
if (leftover > 0 &&
(!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset])))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
load_safe(in_vec, v_coeff, bn, in_size);
// Apply scale
if (has_mul_operand_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
v_coeff[tn] *= block_scale;
}
}
// Per thread work loop
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
// Load for the row
load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
// Accumulate results
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn];
}
}
}
// Apply out scale
if (has_mul_output_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
result[tm] *= out_scale;
}
}
// Simdgroup accumulations
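// Each output row is accumulated across the SN threads of a simdgroup row.
// The shuffle-down tree halves the lane stride every step, so after
// log2(SN) steps the first lane of the row (thrN == 0) holds the full sum.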
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
MLX_MTL_PRAGMA_UNROLL
for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) {
result[tm] += simd_shuffle_down(result[tm], sn);
}
}
// Threadgroup accumulation results
if (needs_tgp_reduction) {
threadgroup T* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
if (thrN == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
tgp_results[tm] = result[tm];
}
threadgroup_barrier(mem_flags::mem_none);
if (sgN == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int sgn = 1; sgn < BN; sgn++) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
result[tm] += tgp_results[sgn * (blockM + TM) + tm];
}
}
}
}
}
// Write outputs
if (simdN == 0 && thrN == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
out_vec[out_row + tm] = result[tm];
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
typename out_mask_t,
typename op_mask_t,
const int BM, /* Threadgroup rows (in simdgroups) */
const int BN, /* Threadgroup cols (in simdgroups) */
const int SM, /* Simdgroup rows (in threads) */
const int SN, /* Simdgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
struct GEMVTKernel {
MLX_MTL_CONST int threadsM = BM * SM;
MLX_MTL_CONST int threadsN = BN * SN;
MLX_MTL_CONST int blockM = threadsM * TM;
MLX_MTL_CONST int blockN = threadsN * TN;
static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
MLX_MTL_CONST bool has_mul_operand_mask =
has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
MLX_MTL_CONST bool has_mul_output_mask =
has_output_mask && !metal::is_same_v<out_mask_t, bool>;
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided
// into blocks of (blockM, blockN) that are distributed among threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each threadgroup has (threadsN, threadsM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then accumulates its local result for the block
// 3. At the end, each thread has accumulated results over all blocks across
// the rows. These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated blockN outputs
//
// Edge case handling:
// - The threadgroup with the largest tid has blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results
// remain zero)
// * The last thread that partially overlaps with the matrix is shifted
// inwards such that the thread block fits exactly in the matrix
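//
// A worked example (a hypothetical tile configuration, not one prescribed by
// this header): with BM = 2, BN = 2, SM = 4, SN = 8, TM = 4, TN = 4 the
// threadgroup covers blockN = 64 output columns and walks over in_vec in
// chunks of blockM = 32 rows, with the BM = 2 simdgroup rows combined through
// threadgroup memory at the end.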
MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM * (blockN + TN) : 0;
MLX_MTL_CONST bool needs_tgp_reduction = BM > 1;
static METAL_FUNC void run(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& matrix_ld [[buffer(6)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
// Appease compiler
(void)lid;
// Thread local accumulation results
T result[TN] = {0};
T inter[TN];
T v_coeff[TM];
const int thrM = SN != 32 ? simd_lid / SN : 0;
const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid);
const int sgN = BN != 1 ? (simd_gid % BN) : 0;
const int simdM = SM * sgM;
const int simdN = SN * sgN;
int cm = (simdM + thrM);
int cn = (simdN + thrN);
int bm = cm * TM;
int bn = cn * TN;
int out_col = tid.x * blockN + bn;
// Prepare mask offsets
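// mask_strides is packed the same way as in GEMVKernel::run, but here the
// fixed block index advances along the output columns and the loop steps
// along the rows, so the stride components [0] and [1] swap roles relative
// to the row-major case above.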
const constant int* out_mask_strides = mask_strides;
const constant int* mat_mask_strides =
out_mask_strides + (has_output_mask ? 2 : 0);
const constant int* vec_mask_strides =
mat_mask_strides + (has_operand_mask ? 2 : 0);
const int n_block_idx = blockM > blockN ? out_col / blockM : int(tid.x);
const int out_mask_offset =
!has_output_mask ? 0 : n_block_idx; // * out_mask_strides[0];
int mat_mask_offset =
!has_operand_mask ? 0 : n_block_idx * mat_mask_strides[0];
int vec_mask_offset = 0;
const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[1];
const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[0];
T out_scale{1};
// Check output mask
if (has_output_mask) {
auto mask_out = out_mask[out_mask_offset];
// Write zeros and return if mask is 0
if (!mask_out) {
if (cm == 0 && out_col < out_vec_size) {
if (out_col + TN <= out_vec_size) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
out_vec[out_col + tn] = T(0.);
}
} else {
for (int tn = 0; tn < TN && (out_col + tn) < out_vec_size; tn++) {
out_vec[out_col + tn] = T(0.);
}
}
}
return;
}
// Store scalar if multiplicative mask
if (has_mul_output_mask) {
out_scale = T(mask_out);
}
}
// Prepare for loop
constexpr const uniform<int> loop_stride = make_uniform(blockM);
const uniform<int> in_size = make_uniform(in_vec_size);
const uniform<int> n_iter = in_size / loop_stride;
const uniform<int> last_iter = loop_stride * n_iter;
const uniform<int> leftover = in_size - last_iter;
// Edgecase handling
if (out_col < out_vec_size) {
out_col = (out_col + TN) <= out_vec_size ? out_col : out_vec_size - TN;
// Per thread accumulation main loop
for (int i = 0; i < n_iter; ++i) {
// Adding a threadgroup_barrier improves performance slightly
// This is possibly because it helps the kernel exploit the cache better
threadgroup_barrier(mem_flags::mem_none);
if (!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset]))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
v_coeff[tm] = in_vec[bm + tm];
}
// Apply scale
if (has_mul_operand_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
v_coeff[tm] *= block_scale;
}
}
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * matrix_ld + out_col + tn];
}
for (int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
}
bm += blockM;
mat_mask_offset += mat_mask_step;
vec_mask_offset += vec_mask_step;
}
if (leftover > 0 &&
(!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset])))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
v_coeff[tm] = in_vec[bm + tm];
if (has_mul_operand_mask) {
v_coeff[tm] *= block_scale;
}
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * matrix_ld + out_col + tn];
}
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
}
}
// Apply out scale
if (has_mul_output_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tn] *= out_scale;
}
}
// Simdgroup accumulations
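// Partial sums for a column live in lanes that are SN apart (one per thread
// row of the simdgroup), so the reduction shuffles down by multiples of SN;
// after log2(SM) steps the top thread row (thrM == 0) holds the column totals.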
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
MLX_MTL_PRAGMA_UNROLL
for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) {
result[tn] += simd_shuffle_down(result[tn], SN * sm);
}
}
// Threadgroup accumulation results
if (needs_tgp_reduction) {
threadgroup T* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
if (thrM == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
tgp_results[tn] = result[tn];
}
threadgroup_barrier(mem_flags::mem_none);
if (sgM == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int sgm = 1; sgm < BM; sgm++) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tn] += tgp_results[sgm * (blockN + TN) + tn];
}
}
}
}
}
// Threadgroup accumulation and writing out results
if (cm == 0 && out_col < out_vec_size) {
MLX_MTL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
out_vec[out_col + j] = result[j];
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
/// Matrix vector multiplication
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
typename out_mask_t,
typename op_mask_t,
const int BM, /* Threadgroup rows (in simdgroups) */
const int BN, /* Threadgroup cols (in simdgroups) */
const int SM, /* Simdgroup rows (in threads) */
const int SN, /* Simdgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoNCBatch> /* Batch ndim > 1 */
[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_masked(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& matrix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant int64_t* vector_batch_stride [[buffer(11)]],
const constant int64_t* matrix_batch_stride [[buffer(12)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
const constant int64_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel =
GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
threadgroup T tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
// Update batch offsets
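// mask_batch_strides packs batch_ndim entries per supplied mask, in order:
// [out_mask batch strides] [mat_mask batch strides] [vec_mask batch strides],
// which is why the pointer is advanced by batch_ndim once the output-mask
// offset has been applied.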
if (kDoNCBatch) {
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
if (has_output_mask) {
out_mask +=
elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
mask_batch_strides += batch_ndim;
}
if (has_operand_mask) {
const constant auto* mask_strides_mat = mask_batch_strides;
const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
mat_mask += batch_offsets.x;
vec_mask += batch_offsets.y;
}
} else {
in_vec += tid.z * vector_batch_stride[0];
mat += tid.z * matrix_batch_stride[0];
if (has_output_mask) {
out_mask += tid.z * mask_batch_strides[0];
mask_batch_strides += batch_ndim;
}
if (has_operand_mask) {
mat_mask += tid.z * mask_batch_strides[0];
vec_mask += tid.z * mask_batch_strides[batch_ndim];
}
}
out_vec += tid.z * out_vec_size;
gemv_kernel::run(
mat,
in_vec,
out_vec,
in_vec_size,
out_vec_size,
matrix_ld,
out_mask,
mat_mask,
vec_mask,
mask_strides,
gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
tid,
lid,
simd_gid,
simd_lid);
}
///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
typename out_mask_t,
typename op_mask_t,
const int BM, /* Threadgroup rows (in simdgroups) */
const int BN, /* Threadgroup cols (in simdgroups) */
const int SM, /* Simdgroup rows (in threads) */
const int SN, /* Simdgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoNCBatch> /* Batch ndim > 1 */
[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_t_masked(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& matrix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant int64_t* vector_batch_stride [[buffer(11)]],
const constant int64_t* matrix_batch_stride [[buffer(12)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
const constant int64_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel =
GEMVTKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
threadgroup T tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
// Update batch offsets
if (kDoNCBatch) {
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
if (has_output_mask) {
out_mask +=
elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
mask_batch_strides += batch_ndim;
}
if (has_operand_mask) {
const constant auto* mask_strides_mat = mask_batch_strides;
const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
mat_mask += batch_offsets.x;
vec_mask += batch_offsets.y;
}
} else {
in_vec += tid.z * vector_batch_stride[0];
mat += tid.z * matrix_batch_stride[0];
if (has_output_mask) {
out_mask += tid.z * mask_batch_strides[0];
mask_batch_strides += batch_ndim;
}
if (has_operand_mask) {
mat_mask += tid.z * mask_batch_strides[0];
vec_mask += tid.z * mask_batch_strides[batch_ndim];
}
}
out_vec += tid.z * out_vec_size;
gemv_kernel::run(
mat,
in_vec,
out_vec,
in_vec_size,
out_vec_size,
matrix_ld,
out_mask,
mat_mask,
vec_mask,
mask_strides,
gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
tid,
lid,
simd_gid,
simd_lid);
}
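// A minimal sketch of how a translation unit might instantiate these kernels.
// The host_name strings and the tile configuration below are assumptions for
// illustration, not the instantiations MLX actually emits; it is kept as a
// comment so that including this header does not itself emit kernels.
//
// template [[host_name("gemv_masked_float32_bool_bool")]] [[kernel]]
// decltype(gemv_masked<float, bool, bool, 2, 2, 4, 8, 4, 4, false>)
//     gemv_masked<float, bool, bool, 2, 2, 4, 8, 4, 4, false>;
//
// template [[host_name("gemv_t_masked_float32_bool_bool")]] [[kernel]]
// decltype(gemv_t_masked<float, bool, bool, 2, 2, 4, 8, 4, 4, false>)
//     gemv_t_masked<float, bool, bool, 2, 2, 4, 8, 4, 4, false>;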