Mirror of https://github.com/ml-explore/mlx.git, synced 2025-06-24 17:31:16 +08:00

Masked gemv (#1211)

This commit is contained in:
parent fe3167d7ea
commit 2d6cd47713
@@ -37,6 +37,7 @@ endfunction(build_kernel)
build_kernel(arg_reduce)
build_kernel(conv steel/conv/params.h)
build_kernel(gemv steel/utils.h)
build_kernel(gemv_masked steel/utils.h)
build_kernel(layer_norm)
build_kernel(random)
build_kernel(rms_norm)
@@ -17,29 +17,250 @@ using namespace metal;
|
||||
|
||||
#define MLX_MTL_CONST static constant constexpr const
|
||||
|
||||
MLX_MTL_CONST int SIMD_SIZE = 32;
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int BM, /* Threadgroup rows (in simdgroups) */
|
||||
const int BN, /* Threadgroup cols (in simdgroups) */
|
||||
const int SM, /* Simdgroup rows (in threads) */
|
||||
const int SN, /* Simdgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN, /* Thread cols (in elements) */
|
||||
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
|
||||
struct GEMVKernel {
|
||||
static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");
|
||||
MLX_MTL_CONST int threadsM = BM * SM;
|
||||
MLX_MTL_CONST int threadsN = BN * SN;
|
||||
|
||||
// - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
|
||||
// into blocks of (BM * TM, BN * TN) divided among threadgroups
|
||||
MLX_MTL_CONST int blockM = threadsM * TM;
|
||||
MLX_MTL_CONST int blockN = threadsN * TN;
|
||||
|
||||
static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
|
||||
|
||||
static_assert(
|
||||
SN == 8 || SN == 16 || SN == 32,
|
||||
"gemv block must have a width of 8, 16, or 32");
|
||||
|
||||
// - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up
|
||||
// into blocks of (blockM, blockN) divided among threadgroups
|
||||
// - Every thread works on a block of (TM, TN)
|
||||
// - We assume each thread group is launched with (BN, BM, 1) threads
|
||||
// - We assume each threadgroup has (threadsN, threadsM, 1) threads
|
||||
//
|
||||
// 1. A thread loads TN elements each from mat along TM contiguous rows
|
||||
// 1. A thread loads TN elements each from mat along TM rows
|
||||
// and the corresponding scalar from the vector
|
||||
// 2. The thread then multiplies and adds to accumulate its local result for
|
||||
// the block
|
||||
// 3. At the end, each thread has accumulated results over all blocks across
|
||||
// the rows. These are then summed up across the threadgroup
|
||||
// 4. Each threadgroup writes its accumulated blockM outputs
|
||||
//
|
||||
// Edge case handling:
|
||||
// - The threadgroup with the largest tid has blocks that exceed the matrix
|
||||
// * The blocks that start outside the matrix are never read (thread results
|
||||
// remain zero)
|
||||
// * The last thread that partially overlaps with the matrix is shifted
|
||||
// inwards such that the thread block fits exactly in the matrix
|
||||
|
||||
MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0;
|
||||
MLX_MTL_CONST bool needs_tgp_reduction = BN > 1;
|
||||
|
||||
static METAL_FUNC void
|
||||
load_unsafe(const device T* src, thread T dst[TN], const int src_offset = 0) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
dst[tn] = src[src_offset + tn];
|
||||
}
|
||||
}
|
||||
|
||||
static METAL_FUNC void load_safe(
|
||||
const device T* src,
|
||||
thread T dst[TN],
|
||||
const int src_offset = 0,
|
||||
const int src_size = TN) {
|
||||
if (src_offset + TN <= src_size) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
dst[tn] = src[src_offset + tn];
|
||||
}
|
||||
} else { // Edgecase
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
dst[tn] = src_offset + tn < src_size ? src[src_offset + tn] : 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static METAL_FUNC void run(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
const device T* bias [[buffer(2)]],
|
||||
device T* out_vec [[buffer(3)]],
|
||||
const constant int& in_vec_size [[buffer(4)]],
|
||||
const constant int& out_vec_size [[buffer(5)]],
|
||||
const constant int& matrix_ld [[buffer(6)]],
|
||||
const constant float& alpha [[buffer(7)]],
|
||||
const constant float& beta [[buffer(8)]],
|
||||
const constant int& bias_stride [[buffer(14)]],
|
||||
threadgroup T* tgp_memory [[threadgroup(0)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
// Appease compiler
|
||||
(void)lid;
|
||||
|
||||
// Thread local accumulation results
|
||||
thread T result[TM] = {0};
|
||||
thread T inter[TN];
|
||||
thread T v_coeff[TN];
|
||||
|
||||
const int thrM = SN != 32 ? simd_lid / SN : 0;
|
||||
const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
|
||||
|
||||
const int sgN = BN != 1 ? (simd_gid % BN) : 0;
|
||||
|
||||
const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid);
|
||||
const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0;
|
||||
|
||||
int bm = (simdM + thrM) * TM;
|
||||
int bn = (simdN + thrN) * TN;
|
||||
|
||||
// Block position
|
||||
int out_row = tid.x * blockM + bm;
|
||||
|
||||
// Exit simdgroup if rows out of bound
|
||||
if (out_row >= out_vec_size)
|
||||
return;
|
||||
|
||||
// Adjust tail simdgroup to ensure in bound reads
|
||||
out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
|
||||
|
||||
// Advance matrix
|
||||
mat += out_row * matrix_ld;
|
||||
|
||||
constexpr const uniform<int> loop_stride = make_uniform(blockN);
|
||||
const uniform<int> in_size = make_uniform(in_vec_size);
|
||||
const uniform<int> n_iter = in_size / loop_stride;
|
||||
const uniform<int> last_iter = loop_stride * n_iter;
|
||||
const uniform<int> leftover = in_size - last_iter;
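// Illustrative example (not part of the kernel): with in_vec_size = 1000 and
// blockN = 128,
//   n_iter = 1000 / 128 = 7,  last_iter = 896,  leftover = 104
// so the main loop below makes 7 full-width passes with load_unsafe and the
// remaining 104 columns are handled by the load_safe tail that follows it.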
|
||||
|
||||
// Loop over in_vec in blocks of blockN
|
||||
for (int i = 0; i < n_iter; ++i) {
|
||||
load_unsafe(in_vec, v_coeff, bn);
|
||||
|
||||
// Per thread work loop
|
||||
int mat_offset = 0;
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
// Load for the row
|
||||
load_unsafe(mat, inter, mat_offset + bn);
|
||||
|
||||
// Accumulate results
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
result[tm] += inter[tn] * v_coeff[tn];
|
||||
}
|
||||
|
||||
mat_offset += matrix_ld;
|
||||
}
|
||||
|
||||
bn += blockN;
|
||||
}
|
||||
|
||||
if (leftover > 0) {
|
||||
load_safe(in_vec, v_coeff, bn, in_size);
|
||||
|
||||
// Per thread work loop
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
// Load for the row
|
||||
load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
|
||||
|
||||
// Accumulate results
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
result[tm] += inter[tn] * v_coeff[tn];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Simdgroup accumulations
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) {
|
||||
result[tm] += simd_shuffle_down(result[tm], sn);
|
||||
}
|
||||
}
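// Illustrative note (not part of the kernel): the loop above is a tree
// reduction across the SN lanes that share an output row. Shown for a width
// of 4 to keep it short (SN is 8, 16, or 32 here); with per-lane partial sums
// {a, b, c, d}:
//   sn = 2:  lane 0 holds a + c,  lane 1 holds b + d
//   sn = 1:  lane 0 holds (a + c) + (b + d)
// After the loop the lane with thrN == 0 owns the full row sum, which is why
// only thrN == 0 takes part in the threadgroup reduction and the write below.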
|
||||
|
||||
// Threadgroup accumulation results
|
||||
if (needs_tgp_reduction) {
|
||||
threadgroup T* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
|
||||
if (thrN == 0) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
tgp_results[tm] = result[tm];
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
if (sgN == 0) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int sgn = 1; sgn < BN; sgn++) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
result[tm] += tgp_results[sgn * (blockM + TM) + tm];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write outputs
|
||||
if (simdN == 0 && thrN == 0) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
if (kDoAxpby) {
|
||||
out_vec[out_row + tm] = static_cast<T>(alpha) * result[tm] +
|
||||
static_cast<T>(beta) * bias[(out_row + tm) * bias_stride];
|
||||
} else {
|
||||
out_vec[out_row + tm] = result[tm];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/// Vector matrix multiplication
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in simdgroups) */
|
||||
const int BN, /* Threadgroup cols (in simdgroups) */
|
||||
const int SM, /* Simdgroup rows (in threads) */
|
||||
const int SN, /* Simdgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN, /* Thread cols (in elements) */
|
||||
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
|
||||
struct GEMVTKernel {
|
||||
MLX_MTL_CONST int threadsM = BM * SM;
|
||||
MLX_MTL_CONST int threadsN = BN * SN;
|
||||
|
||||
MLX_MTL_CONST int blockM = threadsM * TM;
|
||||
MLX_MTL_CONST int blockN = threadsN * TN;
|
||||
|
||||
static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
|
||||
|
||||
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
|
||||
// into blocks of (blockM, blockN) divided among threadgroups
|
||||
// - Every thread works on a block of (TM, TN)
|
||||
// - We assume each threadgroup has (threadsN, threadsM, 1) threads
|
||||
//
|
||||
// 1. A thread loads TN elements each from mat along TM contiguous rows
|
||||
// and the corresponding scalar from the vector
|
||||
// 2. The thread then accumulates its local result for the block
|
||||
// 3. At the end, each thread has accumulated results over all blocks across
|
||||
// the rows. These are then summed up across the threadgroup
|
||||
// 4. Each threadgroup writes its accumulated BN * TN outputs
|
||||
//
|
||||
// Edge case handling:
|
||||
@@ -49,7 +270,8 @@ struct GEMVKernel {
|
||||
// * The last thread that partially overlaps with the matrix is shifted
|
||||
// inwards such that the thread block fits exactly in the matrix
|
||||
|
||||
MLX_MTL_CONST short tgp_mem_size = BN * TN * 2;
|
||||
MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0;
|
||||
MLX_MTL_CONST bool needs_tgp_reduction = BM > 1;
|
||||
|
||||
static METAL_FUNC void run(
|
||||
const device T* mat [[buffer(0)]],
|
||||
@@ -70,230 +292,113 @@ struct GEMVKernel {
|
||||
// Appease compiler
|
||||
(void)lid;
|
||||
|
||||
// Threadgroup in_vec cache
|
||||
threadgroup T* in_vec_block = tgp_memory + simd_lid * TN * 2;
|
||||
|
||||
// Thread local accumulation results
|
||||
thread T result[TM] = {0};
|
||||
thread T inter[TN];
|
||||
thread T v_coeff[TN];
|
||||
|
||||
// Block position
|
||||
int out_row = (tid.x * BM + simd_gid) * TM;
|
||||
|
||||
// Exit simdgroup if rows out of bound
|
||||
if (out_row >= out_vec_size)
|
||||
return;
|
||||
|
||||
// Adjust tail simdgroup to ensure in bound reads
|
||||
out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
|
||||
|
||||
// Advance matrix
|
||||
mat += out_row * marix_ld;
|
||||
|
||||
// Loop over in_vec in blocks of BN * TN
|
||||
for (int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Prefetch in_vector for threadgroup use
|
||||
if (simd_gid == 0) {
|
||||
// Main load loop
|
||||
if (bn + TN <= in_vec_size) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
in_vec_block[tn] = in_vec[bn + tn];
|
||||
}
|
||||
|
||||
} else { // Edgecase
|
||||
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
in_vec_block[tn] = bn + tn < in_vec_size ? in_vec[bn + tn] : 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load for all rows
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
v_coeff[tn] = in_vec_block[tn];
|
||||
}
|
||||
|
||||
// Per thread work loop
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
// Load for the row
|
||||
if (bn + TN <= in_vec_size) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[tm * marix_ld + bn + tn];
|
||||
}
|
||||
|
||||
} else { // Edgecase
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
int col_idx =
|
||||
(bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
|
||||
inter[tn] = mat[tm * marix_ld + col_idx];
|
||||
}
|
||||
}
|
||||
|
||||
// Accumulate results
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
result[tm] += inter[tn] * v_coeff[tn];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Simdgroup accumulations
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
result[tm] = simd_sum(result[tm]);
|
||||
}
|
||||
|
||||
// Write outputs
|
||||
if (simd_lid == 0) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
if (kDoAxpby) {
|
||||
out_vec[out_row + tm] = static_cast<T>(alpha) * result[tm] +
|
||||
static_cast<T>(beta) * bias[(out_row + tm) * bias_stride];
|
||||
} else {
|
||||
out_vec[out_row + tm] = result[tm];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/// Vector matrix multiplication
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN, /* Thread cols (in elements) */
|
||||
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
|
||||
struct GEMVTKernel {
|
||||
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
|
||||
// into blocks of (BM * TM, BN * TN) divided among threadgroups
|
||||
// - Every thread works on a block of (TM, TN)
|
||||
// - We assume each thread group is launched with (BN, BM, 1) threads
|
||||
//
|
||||
// 1. A thread loads TN elements each from mat along TM contiguous rows
|
||||
// and the corresponding scalar from the vector
|
||||
// 2. The thread then accumulates its local result for the block
|
||||
// 3. At the end, each thread has accumulated results over all blocks across
|
||||
// the rows. These are then summed up across the threadgroup
|
||||
// 4. Each threadgroup writes its accumulated BN * TN outputs
|
||||
//
|
||||
// Edge case handling:
|
||||
// - The threadgroup with the largest tid has blocks that exceed the matrix
|
||||
// * The blocks that start outside the matrix are never read (thread results
|
||||
// remain zero)
|
||||
// * The last thread that partially overlaps with the matrix is shifted
|
||||
// inwards such that the thread block fits exactly in the matrix
|
||||
|
||||
MLX_MTL_CONST short tgp_mem_size = BN * BM * TN;
|
||||
|
||||
static METAL_FUNC void run(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
const device T* bias [[buffer(2)]],
|
||||
device T* out_vec [[buffer(3)]],
|
||||
const constant int& in_vec_size [[buffer(4)]],
|
||||
const constant int& out_vec_size [[buffer(5)]],
|
||||
const constant int& marix_ld [[buffer(6)]],
|
||||
const constant float& alpha [[buffer(7)]],
|
||||
const constant float& beta [[buffer(8)]],
|
||||
const constant int& bias_stride [[buffer(14)]],
|
||||
threadgroup T* tgp_memory [[threadgroup(0)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
// Appease compiler
|
||||
(void)simd_gid;
|
||||
(void)simd_lid;
|
||||
|
||||
// Thread local accumulation results
|
||||
T result[TN] = {0};
|
||||
T inter[TN];
|
||||
T v_coeff[TM];
|
||||
|
||||
// Threadgroup accumulation results
|
||||
threadgroup T* tgp_results = tgp_memory + lid.x * BM * TN;
|
||||
const int thrM = SN != 32 ? simd_lid / SN : 0;
|
||||
const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
|
||||
|
||||
int out_col = (tid.x * BN + lid.x) * TN;
|
||||
int in_row = lid.y * TM;
|
||||
const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid);
|
||||
const int sgN = BN != 1 ? (simd_gid % BN) : 0;
|
||||
|
||||
const int simdM = SM * sgM;
|
||||
const int simdN = SN * sgN;
|
||||
|
||||
int cm = (simdM + thrM);
|
||||
int cn = (simdN + thrN);
|
||||
|
||||
int bm = cm * TM;
|
||||
int bn = cn * TN;
|
||||
|
||||
int out_col = tid.x * blockN + bn;
|
||||
|
||||
constexpr const uniform<int> loop_stride = make_uniform(blockM);
|
||||
const uniform<int> in_size = make_uniform(in_vec_size);
|
||||
const uniform<int> n_iter = in_size / loop_stride;
|
||||
const uniform<int> last_iter = loop_stride * n_iter;
|
||||
const uniform<int> leftover = in_size - last_iter;
|
||||
|
||||
// Edgecase handling
|
||||
if (out_col < out_vec_size) {
|
||||
out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN;
|
||||
|
||||
// Per thread accumulation main loop
|
||||
int bm = in_row;
|
||||
for (; bm < in_vec_size; bm += BM * TM) {
|
||||
for (int i = 0; i < n_iter; ++i) {
|
||||
// Adding a threadgroup_barrier improves performance slightly
|
||||
// This is possibly because it helps exploit the cache better
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
if (bm + TM <= in_vec_size) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
v_coeff[tm] = in_vec[bm + tm];
|
||||
}
|
||||
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
|
||||
}
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
result[tn] += v_coeff[tm] * inter[tn];
|
||||
}
|
||||
}
|
||||
|
||||
bm += blockM;
|
||||
}
|
||||
|
||||
if (leftover > 0) {
|
||||
for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
|
||||
v_coeff[tm] = in_vec[bm + tm];
|
||||
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
v_coeff[tm] = in_vec[bm + tm];
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
|
||||
}
|
||||
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
|
||||
}
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
result[tn] += v_coeff[tm] * inter[tn];
|
||||
}
|
||||
}
|
||||
|
||||
} else { // Edgecase handling
|
||||
for (int tm = 0; bm + tm < in_vec_size; tm++) {
|
||||
v_coeff[tm] = in_vec[bm + tm];
|
||||
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
|
||||
}
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
result[tn] += v_coeff[tm] * inter[tn];
|
||||
}
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
result[tn] += v_coeff[tm] * inter[tn];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Threadgroup collection
|
||||
|
||||
// Simdgroup accumulations
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int i = 0; i < TN; i++) {
|
||||
tgp_results[lid.y * TN + i] = result[i];
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) {
|
||||
result[tn] += simd_shuffle_down(result[tn], SN * sm);
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Threadgroup accumulation results
|
||||
if (needs_tgp_reduction) {
|
||||
threadgroup T* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
|
||||
if (thrM == 0) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
tgp_results[tn] = result[tn];
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
if (sgM == 0) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int sgm = 1; sgm < BM; sgm++) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
result[tn] += tgp_results[sgm * (blockN + TN) + tn];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Threadgroup accumulation and writing out results
|
||||
if (lid.y == 0 && out_col < out_vec_size) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int i = 1; i < BM; i++) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < TN; j++) {
|
||||
result[j] += tgp_results[i * TN + j];
|
||||
}
|
||||
}
|
||||
|
||||
if (cm == 0 && out_col < out_vec_size) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < TN; j++) {
|
||||
if (kDoAxpby) {
|
||||
@@ -313,13 +418,15 @@ struct GEMVTKernel {
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int BM, /* Threadgroup rows (in simdgroups) */
|
||||
const int BN, /* Threadgroup cols (in simdgroups) */
|
||||
const int SM, /* Simdgroup rows (in threads) */
|
||||
const int SN, /* Simdgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN, /* Thread cols (in elements) */
|
||||
const bool kDoNCBatch, /* Batch ndim > 1 */
|
||||
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
|
||||
[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv(
|
||||
[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
const device T* bias [[buffer(2)]],
|
||||
@@ -339,8 +446,9 @@ template <
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN, kDoAxpby>;
|
||||
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
|
||||
using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
|
||||
threadgroup T tgp_memory
|
||||
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
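// Explanatory note (added, not in the original): tgp_mem_size is 0 when
// BN == 1, since a single simdgroup column needs no threadgroup reduction.
// Metal (like C++) does not allow zero-length arrays, so the array is padded
// to one element here and a nullptr is passed to run() below whenever the
// memory is unused.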
|
||||
|
||||
// Update batch offsets
|
||||
if (kDoNCBatch) {
|
||||
@@ -373,17 +481,19 @@ template <
|
||||
alpha,
|
||||
beta,
|
||||
bias_stride,
|
||||
tgp_memory,
|
||||
gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
|
||||
tid,
|
||||
lid,
|
||||
simd_gid,
|
||||
simd_lid);
|
||||
}
|
||||
|
||||
#define instantiate_gemv_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
|
||||
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn \
|
||||
"_nc" #nc "_axpby" #axpby)]] [[kernel]] void \
|
||||
gemv<itype, bm, bn, tm, tn, nc, axpby>( \
|
||||
#define instantiate_gemv_helper( \
|
||||
name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
|
||||
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \
|
||||
"_tm" #tm "_tn" #tn "_nc" #nc \
|
||||
"_axpby" #axpby)]] [[kernel]] void \
|
||||
gemv<itype, bm, bn, sm, sn, tm, tn, nc, axpby>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
const device itype* bias [[buffer(2)]], \
|
||||
@@ -405,11 +515,11 @@ template <
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 0) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 1) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 0) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1) // clang-format on
|
||||
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
|
||||
instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 0, 0) \
|
||||
instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 0, 1) \
|
||||
instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 1, 0) \
|
||||
instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 1, 1) // clang-format on
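// Illustrative expansion (hypothetical block configuration bm = 4, bn = 32,
// tm = 1, tn = 4, not necessarily one used by the library):
//   instantiate_gemv(float32, float, 4, 32, 1, 4)
// produces four [[host_name(...)]] instantiations named
//   gemv_float32_bm4_bn1_sm1_sn32_tm1_tn4_nc{0,1}_axpby{0,1}
// i.e. the wrapper fixes BN = 1 and SM = 1 and reuses its bn argument as the
// simdgroup width SN.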
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemv_blocks(name, itype) \
|
||||
@@ -423,11 +533,13 @@ instantiate_gemv_blocks(bfloat16, bfloat16_t);
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int BM, /* Threadgroup rows (in simdgroups) */
|
||||
const int BN, /* Threadgroup cols (in simdgroups) */
|
||||
const int SM, /* Simdgroup rows (in threads) */
|
||||
const int SN, /* Simdgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN> /* Thread cols (in elements) */
|
||||
[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv_bs(
|
||||
[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_gather(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
const device T* bias [[buffer(2)]],
|
||||
@@ -452,8 +564,9 @@ template <
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN, false>;
|
||||
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
|
||||
using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, false>;
|
||||
threadgroup T tgp_memory
|
||||
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
|
||||
|
||||
uint32_t indx_vec;
|
||||
uint32_t indx_mat;
|
||||
@@ -501,47 +614,47 @@ template <
|
||||
alpha,
|
||||
beta,
|
||||
batch_ndim, // Not used
|
||||
tgp_memory,
|
||||
gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
|
||||
tid,
|
||||
lid,
|
||||
simd_gid,
|
||||
simd_lid);
|
||||
}
|
||||
|
||||
#define instantiate_gemv_bs_helper(nm, itype, bm, bn, tm, tn) \
|
||||
template [[host_name("gemv_bs_" #nm "_bm" #bm "_bn" #bn "_tm" #tm \
|
||||
"_tn" #tn)]] [[kernel]] void \
|
||||
gemv_bs<itype, bm, bn, tm, tn>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
const device itype* bias [[buffer(2)]], \
|
||||
device itype* out_vec [[buffer(3)]], \
|
||||
const constant int& in_vec_size [[buffer(4)]], \
|
||||
const constant int& out_vec_size [[buffer(5)]], \
|
||||
const constant int& marix_ld [[buffer(6)]], \
|
||||
const constant float& alpha [[buffer(7)]], \
|
||||
const constant float& beta [[buffer(8)]], \
|
||||
const constant int& batch_ndim [[buffer(9)]], \
|
||||
const constant int* batch_shape [[buffer(10)]], \
|
||||
const constant size_t* index_batch_strides [[buffer(11)]], \
|
||||
const constant int& vector_batch_ndim [[buffer(12)]], \
|
||||
const constant int* vector_batch_shape [[buffer(13)]], \
|
||||
const constant size_t* vector_batch_stride [[buffer(14)]], \
|
||||
const constant int& matrix_batch_ndim [[buffer(15)]], \
|
||||
const constant int* matrix_batch_shape [[buffer(16)]], \
|
||||
const constant size_t* matrix_batch_stride [[buffer(17)]], \
|
||||
const constant uint32_t* vec_indices [[buffer(18)]], \
|
||||
const constant uint32_t* mat_indices [[buffer(19)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
#define instantiate_gemv_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \
|
||||
template [[host_name("gemv_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
|
||||
"_sn" #sn "_tm" #tm "_tn" #tn)]] [[kernel]] void \
|
||||
gemv_gather<itype, bm, bn, sm, sn, tm, tn>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
const device itype* bias [[buffer(2)]], \
|
||||
device itype* out_vec [[buffer(3)]], \
|
||||
const constant int& in_vec_size [[buffer(4)]], \
|
||||
const constant int& out_vec_size [[buffer(5)]], \
|
||||
const constant int& marix_ld [[buffer(6)]], \
|
||||
const constant float& alpha [[buffer(7)]], \
|
||||
const constant float& beta [[buffer(8)]], \
|
||||
const constant int& batch_ndim [[buffer(9)]], \
|
||||
const constant int* batch_shape [[buffer(10)]], \
|
||||
const constant size_t* index_batch_strides [[buffer(11)]], \
|
||||
const constant int& vector_batch_ndim [[buffer(12)]], \
|
||||
const constant int* vector_batch_shape [[buffer(13)]], \
|
||||
const constant size_t* vector_batch_stride [[buffer(14)]], \
|
||||
const constant int& matrix_batch_ndim [[buffer(15)]], \
|
||||
const constant int* matrix_batch_shape [[buffer(16)]], \
|
||||
const constant size_t* matrix_batch_stride [[buffer(17)]], \
|
||||
const constant uint32_t* vec_indices [[buffer(18)]], \
|
||||
const constant uint32_t* mat_indices [[buffer(19)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemv_bs_blocks(name, itype) \
|
||||
instantiate_gemv_bs_helper(name, itype, 4, 32, 1, 4) \
|
||||
instantiate_gemv_bs_helper(name, itype, 4, 32, 4, 4) \
|
||||
instantiate_gemv_bs_helper(name, itype, 8, 32, 4, 4) // clang-format on
|
||||
instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 1, 4) \
|
||||
instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 4, 4) \
|
||||
instantiate_gemv_bs_helper(name, itype, 8, 1, 1, 32, 4, 4) // clang-format on
|
||||
|
||||
instantiate_gemv_bs_blocks(float32, float);
|
||||
instantiate_gemv_bs_blocks(float16, half);
|
||||
@@ -553,13 +666,15 @@ instantiate_gemv_bs_blocks(bfloat16, bfloat16_t);
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int BM, /* Threadgroup rows (in simdgroups) */
|
||||
const int BN, /* Threadgroup cols (in simdgroups) */
|
||||
const int SM, /* Simdgroup rows (in threads) */
|
||||
const int SN, /* Simdgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN, /* Thread cols (in elements) */
|
||||
const bool kDoNCBatch, /* Batch ndim > 1 */
|
||||
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
|
||||
[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv_t(
|
||||
[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_t(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
const device T* bias [[buffer(2)]],
|
||||
@@ -579,8 +694,9 @@ template <
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN, kDoAxpby>;
|
||||
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
|
||||
using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
|
||||
threadgroup T tgp_memory
|
||||
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
|
||||
|
||||
// Update batch offsets
|
||||
if (kDoNCBatch) {
|
||||
@@ -613,17 +729,19 @@ template <
|
||||
alpha,
|
||||
beta,
|
||||
bias_stride,
|
||||
tgp_memory,
|
||||
gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
|
||||
tid,
|
||||
lid,
|
||||
simd_gid,
|
||||
simd_lid);
|
||||
}
|
||||
|
||||
#define instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
|
||||
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn \
|
||||
"_nc" #nc "_axpby" #axpby)]] [[kernel]] void \
|
||||
gemv_t<itype, bm, bn, tm, tn, nc, axpby>( \
|
||||
#define instantiate_gemv_t_helper( \
|
||||
name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
|
||||
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \
|
||||
"_tm" #tm "_tn" #tn "_nc" #nc \
|
||||
"_axpby" #axpby)]] [[kernel]] void \
|
||||
gemv_t<itype, bm, bn, sm, sn, tm, tn, nc, axpby>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
const device itype* bias [[buffer(2)]], \
|
||||
@@ -645,20 +763,19 @@ template <
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \
|
||||
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 0) \
|
||||
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 1) \
|
||||
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 0) \
|
||||
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 1) // clang-format on
|
||||
#define instantiate_gemv_t(name, itype, bm, bn, sm, sn, tm, tn) \
|
||||
instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \
|
||||
instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \
|
||||
instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 0) \
|
||||
instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 1) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemv_t_blocks(name, itype) \
|
||||
instantiate_gemv_t(name, itype, 8, 8, 4, 1) \
|
||||
instantiate_gemv_t(name, itype, 8, 8, 4, 4) \
|
||||
instantiate_gemv_t(name, itype, 8, 16, 4, 4) \
|
||||
instantiate_gemv_t(name, itype, 8, 32, 4, 4) \
|
||||
instantiate_gemv_t(name, itype, 8, 64, 4, 4) \
|
||||
instantiate_gemv_t(name, itype, 8, 128, 4, 4) // clang-format on
|
||||
instantiate_gemv_t(name, itype, 1, 2, 8, 4, 4, 1) \
|
||||
instantiate_gemv_t(name, itype, 1, 2, 8, 4, 4, 4) \
|
||||
instantiate_gemv_t(name, itype, 1, 4, 8, 4, 4, 4) \
|
||||
instantiate_gemv_t(name, itype, 1, 16, 8, 4, 4, 4) \
|
||||
instantiate_gemv_t(name, itype, 1, 16, 4, 8, 4, 4) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
instantiate_gemv_t_blocks(float32, float);
|
||||
@@ -667,11 +784,13 @@ instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int BM, /* Threadgroup rows (in simdgroups) */
|
||||
const int BN, /* Threadgroup cols (in simdgroups) */
|
||||
const int SM, /* Simdgroup rows (in threads) */
|
||||
const int SN, /* Simdgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN> /* Thread cols (in elements) */
|
||||
[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv_t_bs(
|
||||
[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_t_gather(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
const device T* bias [[buffer(2)]],
|
||||
@@ -696,8 +815,9 @@ template <
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN, false>;
|
||||
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
|
||||
using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, false>;
|
||||
threadgroup T tgp_memory
|
||||
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
|
||||
|
||||
uint32_t indx_vec;
|
||||
uint32_t indx_mat;
|
||||
@@ -745,50 +865,49 @@ template <
|
||||
alpha,
|
||||
beta,
|
||||
batch_ndim, // Not used,
|
||||
tgp_memory,
|
||||
gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
|
||||
tid,
|
||||
lid,
|
||||
simd_gid,
|
||||
simd_lid);
|
||||
}
|
||||
|
||||
#define instantiate_gemv_t_bs_helper(nm, itype, bm, bn, tm, tn) \
|
||||
template [[host_name("gemv_t_bs_" #nm "_bm" #bm "_bn" #bn "_tm" #tm \
|
||||
"_tn" #tn)]] [[kernel]] void \
|
||||
gemv_t_bs<itype, bm, bn, tm, tn>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
const device itype* bias [[buffer(2)]], \
|
||||
device itype* out_vec [[buffer(3)]], \
|
||||
const constant int& in_vec_size [[buffer(4)]], \
|
||||
const constant int& out_vec_size [[buffer(5)]], \
|
||||
const constant int& marix_ld [[buffer(6)]], \
|
||||
const constant float& alpha [[buffer(7)]], \
|
||||
const constant float& beta [[buffer(8)]], \
|
||||
const constant int& batch_ndim [[buffer(9)]], \
|
||||
const constant int* batch_shape [[buffer(10)]], \
|
||||
const constant size_t* index_batch_strides [[buffer(11)]], \
|
||||
const constant int& vector_batch_ndim [[buffer(12)]], \
|
||||
const constant int* vector_batch_shape [[buffer(13)]], \
|
||||
const constant size_t* vector_batch_stride [[buffer(14)]], \
|
||||
const constant int& matrix_batch_ndim [[buffer(15)]], \
|
||||
const constant int* matrix_batch_shape [[buffer(16)]], \
|
||||
const constant size_t* matrix_batch_stride [[buffer(17)]], \
|
||||
const constant uint32_t* vec_indices [[buffer(18)]], \
|
||||
const constant uint32_t* mat_indices [[buffer(19)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
#define instantiate_gemv_t_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \
|
||||
template [[host_name("gemv_t_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
|
||||
"_sn" #sn "_tm" #tm "_tn" #tn)]] [[kernel]] void \
|
||||
gemv_t_gather<itype, bm, bn, sm, sn, tm, tn>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
const device itype* bias [[buffer(2)]], \
|
||||
device itype* out_vec [[buffer(3)]], \
|
||||
const constant int& in_vec_size [[buffer(4)]], \
|
||||
const constant int& out_vec_size [[buffer(5)]], \
|
||||
const constant int& marix_ld [[buffer(6)]], \
|
||||
const constant float& alpha [[buffer(7)]], \
|
||||
const constant float& beta [[buffer(8)]], \
|
||||
const constant int& batch_ndim [[buffer(9)]], \
|
||||
const constant int* batch_shape [[buffer(10)]], \
|
||||
const constant size_t* index_batch_strides [[buffer(11)]], \
|
||||
const constant int& vector_batch_ndim [[buffer(12)]], \
|
||||
const constant int* vector_batch_shape [[buffer(13)]], \
|
||||
const constant size_t* vector_batch_stride [[buffer(14)]], \
|
||||
const constant int& matrix_batch_ndim [[buffer(15)]], \
|
||||
const constant int* matrix_batch_shape [[buffer(16)]], \
|
||||
const constant size_t* matrix_batch_stride [[buffer(17)]], \
|
||||
const constant uint32_t* vec_indices [[buffer(18)]], \
|
||||
const constant uint32_t* mat_indices [[buffer(19)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemv_t_bs_blocks(name, itype) \
|
||||
instantiate_gemv_t_bs_helper(name, itype, 8, 8, 4, 1) \
|
||||
instantiate_gemv_t_bs_helper(name, itype, 8, 8, 4, 4) \
|
||||
instantiate_gemv_t_bs_helper(name, itype, 8, 16, 4, 4) \
|
||||
instantiate_gemv_t_bs_helper(name, itype, 8, 32, 4, 4) \
|
||||
instantiate_gemv_t_bs_helper(name, itype, 8, 64, 4, 4) \
|
||||
instantiate_gemv_t_bs_helper(name, itype, 8, 128, 4, 4) // clang-format on
|
||||
#define instantiate_gemv_t_bs_blocks(name, itype) \
|
||||
instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 1) \
|
||||
instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 4) \
|
||||
instantiate_gemv_t_bs_helper(name, itype, 1, 4, 8, 4, 4, 4) \
|
||||
instantiate_gemv_t_bs_helper(name, itype, 1, 16, 8, 4, 4, 4) \
|
||||
instantiate_gemv_t_bs_helper(name, itype, 1, 16, 4, 8, 4, 4) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
instantiate_gemv_t_bs_blocks(float32, float);
|
||||
|
939 mlx/backend/metal/kernels/gemv_masked.metal Normal file
@@ -0,0 +1,939 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <metal_simdgroup>
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/// Matrix vector multiplication
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define MLX_MTL_CONST static constant constexpr const
|
||||
|
||||
struct _NoMask {
|
||||
char x;
|
||||
|
||||
constexpr METAL_FUNC operator bool() {
|
||||
return true;
|
||||
}
|
||||
constexpr METAL_FUNC operator bool() const threadgroup {
|
||||
return true;
|
||||
}
|
||||
constexpr METAL_FUNC operator bool() const device {
|
||||
return true;
|
||||
}
|
||||
constexpr METAL_FUNC operator bool() const constant {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
typedef struct _NoMask nomask_t;
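// Explanatory note (added, not in the original): nomask_t is a one-byte
// placeholder whose bool conversions always return true. Instantiating a
// kernel with out_mask_t or op_mask_t set to nomask_t makes the corresponding
// has_*_mask flag (defined in the kernels below) compile-time false, so the
// masking branches reduce to constants and are compiled out. A bool mask gives
// binary block masking, while an arithmetic mask type (e.g. float)
// additionally scales the surviving blocks.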
|
||||
|
||||
template <typename OutT, typename InT = OutT>
|
||||
struct ScaleOp {
|
||||
OutT scale;
|
||||
|
||||
METAL_FUNC OutT apply(InT x) const {
|
||||
return static_cast<OutT>(x) * scale;
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename out_mask_t,
|
||||
typename op_mask_t,
|
||||
const int BM, /* Threadgroup rows (in simdgroups) */
|
||||
const int BN, /* Threadgroup cols (in simdgroups) */
|
||||
const int SM, /* Simdgroup rows (in threads) */
|
||||
const int SN, /* Simdgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN> /* Thread cols (in elements) */
|
||||
struct GEMVKernel {
|
||||
MLX_MTL_CONST int threadsM = BM * SM;
|
||||
MLX_MTL_CONST int threadsN = BN * SN;
|
||||
|
||||
MLX_MTL_CONST int blockM = threadsM * TM;
|
||||
MLX_MTL_CONST int blockN = threadsN * TN;
|
||||
|
||||
static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
|
||||
|
||||
static_assert(
|
||||
SN == 8 || SN == 16 || SN == 32,
|
||||
"gemv block must have a width of 8, 16, or 32");
|
||||
|
||||
static_assert(blockN >= blockM, "Masked gemv must have blockN >= blockM");
|
||||
|
||||
MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
|
||||
MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
|
||||
|
||||
MLX_MTL_CONST bool has_mul_operand_mask =
|
||||
has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
|
||||
MLX_MTL_CONST bool has_mul_output_mask =
|
||||
has_output_mask && !metal::is_same_v<out_mask_t, bool>;
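// Illustrative note (hypothetical instantiations, not part of the kernel):
// the mask behaviour is chosen entirely by the mask template types, e.g.
//   GEMVKernel<float, nomask_t, nomask_t, ...>  no masking (plain gemv)
//   GEMVKernel<float, bool,     nomask_t, ...>  binary output-block mask only
//   GEMVKernel<float, nomask_t, bool,     ...>  binary operand-block mask only
//   GEMVKernel<float, float,    float,    ...>  multiplicative masks; blocks
//                                               are scaled by the mask values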
|
||||
|
||||
// - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up
|
||||
// into blocks of (blockM, blockN) divided among threadgroups
|
||||
// - Every thread works on a block of (TM, TN)
|
||||
// - We assume each threadgroup has (threadsN, threadsM, 1) threads
|
||||
//
|
||||
// 1. A thread loads TN elements each from mat along TM rows
|
||||
// and the corresponding scalar from the vector
|
||||
// 2. The thread then multiplies and adds to accumulate its local result for
|
||||
// the block
|
||||
// 3. At the end, each thread has accumulated results over all blocks across
|
||||
// the rows. These are then summed up across the threadgroup
|
||||
// 4. Each threadgroup writes its accumulated blockM outputs
|
||||
//
|
||||
// Edge case handling:
|
||||
// - The threadgroup with the largest tid has blocks that exceed the matrix
|
||||
// * The blocks that start outside the matrix are never read (thread results
|
||||
// remain zero)
|
||||
// * The last thread that partially overlaps with the matrix is shifted
|
||||
// inwards such that the thread block fits exactly in the matrix
|
||||
|
||||
MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0;
|
||||
MLX_MTL_CONST bool needs_tgp_reduction = BN > 1;
|
||||
|
||||
static METAL_FUNC void
|
||||
load_unsafe(const device T* src, thread T dst[TN], const int src_offset = 0) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
dst[tn] = src[src_offset + tn];
|
||||
}
|
||||
}
|
||||
|
||||
static METAL_FUNC void load_safe(
|
||||
const device T* src,
|
||||
thread T dst[TN],
|
||||
const int src_offset = 0,
|
||||
const int src_size = TN) {
|
||||
if (src_offset + TN <= src_size) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
dst[tn] = src[src_offset + tn];
|
||||
}
|
||||
} else { // Edgecase
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
dst[tn] = src_offset + tn < src_size ? src[src_offset + tn] : 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static METAL_FUNC void run(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
device T* out_vec [[buffer(3)]],
|
||||
const constant int& in_vec_size [[buffer(4)]],
|
||||
const constant int& out_vec_size [[buffer(5)]],
|
||||
const constant int& matrix_ld [[buffer(6)]],
|
||||
const device out_mask_t* out_mask [[buffer(20)]],
|
||||
const device op_mask_t* mat_mask [[buffer(21)]],
|
||||
const device op_mask_t* vec_mask [[buffer(22)]],
|
||||
const constant int* mask_strides [[buffer(23)]],
|
||||
threadgroup T* tgp_memory [[threadgroup(0)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
// Appease compiler
|
||||
(void)lid;
|
||||
|
||||
// Thread local accumulation results
|
||||
thread T result[TM] = {0};
|
||||
thread T inter[TN];
|
||||
thread T v_coeff[TN];
|
||||
|
||||
const int thrM = SN != 32 ? simd_lid / SN : 0;
|
||||
const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
|
||||
|
||||
const int sgN = BN != 1 ? (simd_gid % BN) : 0;
|
||||
|
||||
const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid);
|
||||
const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0;
|
||||
|
||||
int bm = (simdM + thrM) * TM;
|
||||
int bn = (simdN + thrN) * TN;
|
||||
|
||||
// Block position
|
||||
int out_row = tid.x * blockM + bm;
|
||||
|
||||
// Exit simdgroup if rows out of bound
|
||||
if (out_row >= out_vec_size)
|
||||
return;
|
||||
|
||||
// Adjust tail simdgroup to ensure in bound reads
|
||||
out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
|
||||
|
||||
// Prepare mask offsets
|
||||
const constant int* out_mask_strides = mask_strides;
|
||||
const constant int* mat_mask_strides =
|
||||
mask_strides + (has_output_mask ? 2 : 0);
|
||||
const constant int* vec_mask_strides =
|
||||
mat_mask_strides + (has_operand_mask ? 2 : 0);
|
||||
|
||||
const int m_block_idx = blockN > blockM ? out_row / blockN : int(tid.x);
|
||||
|
||||
const int out_mask_offset =
|
||||
!has_output_mask ? 0 : m_block_idx * out_mask_strides[1];
|
||||
|
||||
int mat_mask_offset =
|
||||
!has_operand_mask ? 0 : m_block_idx * mat_mask_strides[1];
|
||||
int vec_mask_offset = 0;
|
||||
const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[0];
|
||||
const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[1];
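// Explanatory note (added, as read from the offsets above): mask_strides
// packs two strides per mask that is actually present, in the order
// [out_mask..., mat_mask..., vec_mask...]; with both an output and an operand
// mask it therefore holds six ints:
//   { out[0], out[1], mat[0], mat[1], vec[0], vec[1] }
// The row-block index m_block_idx advances the out/mat offsets via stride [1],
// while the k-block loop below steps mat by stride [0] and vec by stride [1]
// each iteration.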
|
||||
|
||||
T out_scale{1};
|
||||
|
||||
// Check output mask
|
||||
if (has_output_mask) {
|
||||
auto mask_out = out_mask[out_mask_offset];
|
||||
|
||||
// Write zeros and return if mask is 0
|
||||
if (!mask_out) {
|
||||
if (simdN == 0 && thrN == 0) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
out_vec[out_row + tm] = T(0.);
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// Store scalar if multiplicative mask
|
||||
if (has_mul_output_mask) {
|
||||
out_scale = T(mask_out);
|
||||
}
|
||||
}
|
||||
|
||||
// Advance matrix
|
||||
mat += out_row * matrix_ld;
|
||||
|
||||
// Prepare for loop
|
||||
constexpr const uniform<int> loop_stride = make_uniform(blockN);
|
||||
const uniform<int> in_size = make_uniform(in_vec_size);
|
||||
const uniform<int> n_iter = in_size / loop_stride;
|
||||
const uniform<int> last_iter = loop_stride * n_iter;
|
||||
const uniform<int> leftover = in_size - last_iter;
|
||||
|
||||
// Loop over in_vec in blocks of blockN
|
||||
for (int i = 0; i < n_iter; ++i) {
|
||||
if (!has_operand_mask ||
|
||||
(bool(mat_mask[mat_mask_offset]) &&
|
||||
bool(vec_mask[vec_mask_offset]))) {
|
||||
T block_scale{1};
|
||||
if (has_mul_operand_mask) {
|
||||
block_scale =
|
||||
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
|
||||
}
|
||||
|
||||
load_unsafe(in_vec, v_coeff, bn);
|
||||
|
||||
// Apply scale
|
||||
if (has_mul_operand_mask) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
v_coeff[tn] *= block_scale;
|
||||
}
|
||||
}
|
||||
|
||||
// Per thread work loop
|
||||
int mat_offset = 0;
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
// Load for the row
|
||||
load_unsafe(mat, inter, mat_offset + bn);
|
||||
|
||||
// Accumulate results
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
result[tm] += inter[tn] * v_coeff[tn];
|
||||
}
|
||||
|
||||
mat_offset += matrix_ld;
|
||||
}
|
||||
}
|
||||
|
||||
bn += blockN;
|
||||
mat_mask_offset += mat_mask_step;
|
||||
vec_mask_offset += vec_mask_step;
|
||||
}
|
||||
|
||||
if (leftover > 0 &&
|
||||
(!has_operand_mask ||
|
||||
(bool(mat_mask[mat_mask_offset]) &&
|
||||
bool(vec_mask[vec_mask_offset])))) {
|
||||
T block_scale{1};
|
||||
if (has_mul_operand_mask) {
|
||||
block_scale =
|
||||
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
|
||||
}
|
||||
|
||||
load_safe(in_vec, v_coeff, bn, in_size);
|
||||
|
||||
// Apply scale
|
||||
if (has_mul_operand_mask) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
v_coeff[tn] *= block_scale;
|
||||
}
|
||||
}
|
||||
|
||||
// Per thread work loop
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
// Load for the row
|
||||
load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
|
||||
|
||||
// Accumulate results
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
result[tm] += inter[tn] * v_coeff[tn];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply out scale
|
||||
if (has_mul_output_mask) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
result[tm] *= out_scale;
|
||||
}
|
||||
}
|
||||
|
||||
// Simdgroup accumulations
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) {
|
||||
result[tm] += simd_shuffle_down(result[tm], sn);
|
||||
}
|
||||
}
|
||||
|
||||
// Threadgroup accumulation results
|
||||
if (needs_tgp_reduction) {
|
||||
threadgroup T* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
|
||||
if (thrN == 0) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
tgp_results[tm] = result[tm];
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
if (sgN == 0) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int sgn = 1; sgn < BN; sgn++) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
result[tm] += tgp_results[sgn * (blockM + TM) + tm];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write outputs
|
||||
if (simdN == 0 && thrN == 0) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
out_vec[out_row + tm] = result[tm];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/// Vector matrix multiplication
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename out_mask_t,
|
||||
typename op_mask_t,
|
||||
const int BM, /* Threadgroup rows (in simdgroups) */
|
||||
const int BN, /* Threadgroup cols (in simdgroups) */
|
||||
const int SM, /* Simdgroup rows (in threads) */
|
||||
const int SN, /* Simdgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN> /* Thread cols (in elements) */
|
||||
struct GEMVTKernel {
|
||||
MLX_MTL_CONST int threadsM = BM * SM;
|
||||
MLX_MTL_CONST int threadsN = BN * SN;
|
||||
|
||||
MLX_MTL_CONST int blockM = threadsM * TM;
|
||||
MLX_MTL_CONST int blockN = threadsN * TN;
|
||||
|
||||
static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
|
||||
|
||||
MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
|
||||
MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
|
||||
|
||||
MLX_MTL_CONST bool has_mul_operand_mask =
|
||||
has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
|
||||
MLX_MTL_CONST bool has_mul_output_mask =
|
||||
has_output_mask && !metal::is_same_v<out_mask_t, bool>;
|
||||
|
||||
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
|
||||
// into blocks of (blockM, blockN) divided among threadgroups
|
||||
// - Every thread works on a block of (TM, TN)
|
||||
// - We assume each threadgroup has (threadsN, threadsM, 1) threads
|
||||
//
|
||||
// 1. A thread loads TN elements each from mat along TM contiguous rows
|
||||
// and the corresponding scalar from the vector
|
||||
// 2. The thread then accumulates its local result for the block
|
||||
// 3. At the end, each thread has accumulated results over all blocks across
|
||||
// the rows. These are then summed up across the threadgroup
|
||||
// 4. Each threadgroup writes its accumulated BN * TN outputs
|
||||
//
|
||||
// Edge case handling:
|
||||
// - The threadgroup with the largest tid has blocks that exceed the matrix
|
||||
// * The blocks that start outside the matrix are never read (thread results
|
||||
// remain zero)
|
||||
// * The last thread that partially overlaps with the matrix is shifted
|
||||
// inwards such that the thread block fits exactly in the matrix
|
||||
|
||||
MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0;
|
||||
MLX_MTL_CONST bool needs_tgp_reduction = BM > 1;
|
||||
|
||||
static METAL_FUNC void run(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
device T* out_vec [[buffer(3)]],
|
||||
const constant int& in_vec_size [[buffer(4)]],
|
||||
const constant int& out_vec_size [[buffer(5)]],
|
||||
const constant int& marix_ld [[buffer(6)]],
|
||||
const device out_mask_t* out_mask [[buffer(20)]],
|
||||
const device op_mask_t* mat_mask [[buffer(21)]],
|
||||
const device op_mask_t* vec_mask [[buffer(22)]],
|
||||
const constant int* mask_strides [[buffer(23)]],
|
||||
threadgroup T* tgp_memory [[threadgroup(0)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
// Appease compiler
|
||||
(void)lid;
|
||||
|
||||
// Thread local accumulation results
|
||||
T result[TN] = {0};
|
||||
T inter[TN];
|
||||
T v_coeff[TM];
|
||||
|
||||
const int thrM = SN != 32 ? simd_lid / SN : 0;
|
||||
const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
|
||||
|
||||
const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid);
|
||||
const int sgN = BN != 1 ? (simd_gid % BN) : 0;
|
||||
|
||||
const int simdM = SM * sgM;
|
||||
const int simdN = SN * sgN;
|
||||
|
||||
int cm = (simdM + thrM);
|
||||
int cn = (simdN + thrN);
|
||||
|
||||
int bm = cm * TM;
|
||||
int bn = cn * TN;

int out_col = tid.x * blockN + bn;

// Prepare mask offsets
const constant int* out_mask_strides = mask_strides;
const constant int* mat_mask_strides =
out_mask_strides + (has_output_mask ? 2 : 0);
const constant int* vec_mask_strides =
mat_mask_strides + (has_operand_mask ? 2 : 0);

const int n_block_idx = blockM > blockN ? out_col / blockM : int(tid.x);

const int out_mask_offset =
!has_output_mask ? 0 : n_block_idx; // * out_mask_strides[0];

int mat_mask_offset =
!has_operand_mask ? 0 : n_block_idx * mat_mask_strides[0];
int vec_mask_offset = 0;
const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[1];
const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[0];

T out_scale{1};

// Check output mask
if (has_output_mask) {
auto mask_out = out_mask[out_mask_offset];

// Write zeros and return if mask is 0
if (!mask_out) {
if (cm == 0 && out_col < out_vec_size) {
if (out_col + TN <= out_vec_size) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
out_vec[out_col + tn] = T(0.);
}
} else {
for (int tn = 0; tn < TN && (out_col + tn) < out_vec_size; tn++) {
out_vec[out_col + tn] = T(0.);
}
}
}

return;
}

// Store scalar if multiplicative mask
if (has_mul_output_mask) {
out_scale = T(mask_out);
}
}

// Prepare for loop
constexpr const uniform<int> loop_stride = make_uniform(blockM);
const uniform<int> in_size = make_uniform(in_vec_size);
const uniform<int> n_iter = in_size / loop_stride;
const uniform<int> last_iter = loop_stride * n_iter;
const uniform<int> leftover = in_size - last_iter;

// Edge case handling
if (out_col < out_vec_size) {
out_col = (out_col + TN) <= out_vec_size ? out_col : out_vec_size - TN;

// Per thread accumulation main loop
for (int i = 0; i < n_iter; ++i) {
// Adding a threadgroup_barrier improves performance slightly
// Possibly because it may help exploit the cache better
threadgroup_barrier(mem_flags::mem_none);

if (!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset]))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}

MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
v_coeff[tm] = in_vec[bm + tm];
}

// Apply scale
if (has_mul_operand_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
v_coeff[tm] *= block_scale;
}
}

MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
}
for (int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
}

bm += blockM;
mat_mask_offset += mat_mask_step;
vec_mask_offset += vec_mask_step;
}
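
// Note (descriptive comment added for clarity): leftover equals
// in_vec_size % blockM, i.e. the rows remaining after the last full
// blockM-sized step above; they are handled next with per-row bounds checks.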

if (leftover > 0 &&
(!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset])))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}

for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
v_coeff[tm] = in_vec[bm + tm];

if (has_mul_operand_mask) {
v_coeff[tm] *= block_scale;
}

MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
}

MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
}
}

// Apply out scale
if (has_mul_output_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tn] *= out_scale;
}
}

// Simdgroup accumulations
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
MLX_MTL_PRAGMA_UNROLL
for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) {
result[tn] += simd_shuffle_down(result[tn], SN * sm);
}
}
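
// For instance (illustrative, assuming SM = 8, SN = 4): the loop above folds
// the 8 partial sums held by the threads that share a column of the
// simdgroup in log2(8) = 3 shuffle steps, using lane offsets 16, 8 and 4.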

// Threadgroup accumulation results
if (needs_tgp_reduction) {
threadgroup T* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
if (thrM == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
tgp_results[tn] = result[tn];
}

threadgroup_barrier(mem_flags::mem_none);

if (sgM == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int sgm = 1; sgm < BM; sgm++) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tn] += tgp_results[sgm * (blockN + TN) + tn];
}
}
}
}
}

// Threadgroup accumulation and writing out results
if (cm == 0 && out_col < out_vec_size) {
MLX_MTL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
out_vec[out_col + j] = result[j];
}
}
}
};

///////////////////////////////////////////////////////////////////////////////
/// Matrix vector multiplication
///////////////////////////////////////////////////////////////////////////////

template <
typename T,
typename out_mask_t,
typename op_mask_t,
const int BM, /* Threadgroup rows (in simdgroups) */
const int BN, /* Threadgroup cols (in simdgroups) */
const int SM, /* Simdgroup rows (in threads) */
const int SN, /* Simdgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoNCBatch> /* Batch ndim > 1 */
[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_masked(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
const constant size_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel =
GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
threadgroup T tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];

constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;

// Update batch offsets
if (kDoNCBatch) {
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);

if (has_output_mask) {
out_mask +=
elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
mask_batch_strides += batch_ndim;
}

if (has_operand_mask) {
const constant size_t* mask_strides_mat = mask_batch_strides;
const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;

ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);

mat_mask += batch_offsets.x;
vec_mask += batch_offsets.y;
}

} else {
in_vec += tid.z * vector_batch_stride[0];
mat += tid.z * matrix_batch_stride[0];

if (has_output_mask) {
out_mask += tid.z * mask_batch_strides[0];
mask_batch_strides += batch_ndim;
}

if (has_operand_mask) {
mat_mask += tid.z * mask_batch_strides[0];
vec_mask += tid.z * mask_batch_strides[batch_ndim];
}
}

out_vec += tid.z * out_vec_size;

gemv_kernel::run(
mat,
in_vec,
out_vec,
in_vec_size,
out_vec_size,
marix_ld,
out_mask,
mat_mask,
vec_mask,
mask_strides,
gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
tid,
lid,
simd_gid,
simd_lid);
}

#define instantiate_gemv_helper( \
outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
template [[host_name("gemv_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
"_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
"_tn" #tn "_nc" #nc)]] [[kernel]] void \
gemv_masked<itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const device outm_t* out_mask [[buffer(20)]], \
const device opm_t* mat_mask [[buffer(21)]], \
const device opm_t* vec_mask [[buffer(22)]], \
const constant int* mask_strides [[buffer(23)]], \
const constant size_t* mask_batch_strides [[buffer(24)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);

// clang-format off
#define instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_helper(name, itype, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_helper(bool_, bool, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_helper(name, itype, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_helper(nomask, nomask_t, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_helper(nomask, nomask_t, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_helper(bool_, bool, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) // clang-format on

// clang-format off
#define instantiate_gemv(name, itype, bm, bn, sm, sn, tm, tn) \
instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, 0) \
instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, 1) // clang-format on

// clang-format off
#define instantiate_gemv_blocks(name, itype) \
instantiate_gemv(name, itype, 2, 1, 4, 8, 1, 4) \
instantiate_gemv(name, itype, 2, 1, 4, 8, 4, 4) \
instantiate_gemv(name, itype, 2, 1, 2, 16, 1, 4) \
instantiate_gemv(name, itype, 2, 1, 2, 16, 4, 4) \
instantiate_gemv(name, itype, 4, 1, 2, 16, 4, 4) // clang-format on

instantiate_gemv_blocks(float32, float);
instantiate_gemv_blocks(float16, half);
instantiate_gemv_blocks(bfloat16, bfloat16_t);
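
// For reference (comment added for illustration): the first instantiation
// above, with bool output and operand masks and float32 data, should produce
// the host name
// "gemv_outmask_bool__opmask_bool__float32_bm2_bn1_sm4_sn8_tm1_tn4_nc0".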

///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////

template <
typename T,
typename out_mask_t,
typename op_mask_t,
const int BM, /* Threadgroup rows (in simdgroups) */
const int BN, /* Threadgroup cols (in simdgroups) */
const int SM, /* Simdgroup rows (in threads) */
const int SN, /* Simdgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoNCBatch> /* Batch ndim > 1 */
[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_t_masked(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
const constant size_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel =
GEMVTKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
threadgroup T tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];

constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;

// Update batch offsets
if (kDoNCBatch) {
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);

if (has_output_mask) {
out_mask +=
elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
mask_batch_strides += batch_ndim;
}

if (has_operand_mask) {
const constant size_t* mask_strides_mat = mask_batch_strides;
const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;

ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);

mat_mask += batch_offsets.x;
vec_mask += batch_offsets.y;
}

} else {
in_vec += tid.z * vector_batch_stride[0];
mat += tid.z * matrix_batch_stride[0];

if (has_output_mask) {
out_mask += tid.z * mask_batch_strides[0];
mask_batch_strides += batch_ndim;
}

if (has_operand_mask) {
mat_mask += tid.z * mask_batch_strides[0];
vec_mask += tid.z * mask_batch_strides[batch_ndim];
}
}

out_vec += tid.z * out_vec_size;

gemv_kernel::run(
mat,
in_vec,
out_vec,
in_vec_size,
out_vec_size,
marix_ld,
out_mask,
mat_mask,
vec_mask,
mask_strides,
gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
tid,
lid,
simd_gid,
simd_lid);
}

#define instantiate_gemv_t_helper( \
outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
template [[host_name("gemv_t_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
"_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
"_tn" #tn "_nc" #nc)]] [[kernel]] void \
gemv_t_masked<itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const device outm_t* out_mask [[buffer(20)]], \
const device opm_t* mat_mask [[buffer(21)]], \
const device opm_t* vec_mask [[buffer(22)]], \
const constant int* mask_strides [[buffer(23)]], \
const constant size_t* mask_batch_strides [[buffer(24)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);

// clang-format off
#define instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_t_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_t_helper(name, itype, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_t_helper(bool_, bool, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_t_helper(name, itype, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_t_helper(nomask, nomask_t, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_t_helper(nomask, nomask_t, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_t_helper(bool_, bool, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_t_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) // clang-format on

// clang-format off
#define instantiate_gemv_t(name, itype, bm, bn, sm, sn, tm, tn) \
instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, 0) \
instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, 1) // clang-format on

// clang-format off
#define instantiate_gemv_t_blocks(name, itype) \
instantiate_gemv_t(name, itype, 1, 1, 8, 4, 4, 1) \
instantiate_gemv_t(name, itype, 1, 2, 8, 4, 4, 4) \
instantiate_gemv_t(name, itype, 1, 1, 8, 4, 8, 1) \
instantiate_gemv_t(name, itype, 1, 1, 8, 4, 8, 4) \
instantiate_gemv_t(name, itype, 1, 2, 8, 4, 8, 4) \
instantiate_gemv_t(name, itype, 1, 4, 8, 4, 8, 4) // clang-format on

// clang-format off
instantiate_gemv_t_blocks(float32, float);
instantiate_gemv_t_blocks(float16, half);
instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on
@ -786,38 +786,47 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {

// Determine dispatch kernel
int tm = 4, tn = 4;
int bm, bn, n_out_per_tgp;
int sm = 1, sn = 32;
int bm = 1, bn = 1;
int n_out_per_tgp;
std::ostringstream kname;

if (transpose_mat) {
bm = 8;
bn = 8;
if (out_vector_len >= 24576) {
bn = 128;
} else if (out_vector_len >= 16384) {
bn = 64;
} else if (out_vector_len >= 8192) {
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
sm = 4;
sn = 8;
} else {
sm = 8;
sn = 4;
}

if (out_vector_len >= 2048) {
bn = 16;
} else if (out_vector_len >= 512) {
bn = 4;
} else {
bn = 2;
}

// Specialized kernel for very small outputs
tn = out_vector_len < tn ? 1 : tn;

n_out_per_tgp = bn * tn;
n_out_per_tgp = bn * sn * tn;
kname << "gemv_t_" << type_to_name(out);

} else {
bm = out_vector_len >= 4096 ? 8 : 4;
bn = 32;
sn = 32;

// Specialized kernel for very small outputs
tm = out_vector_len < tm ? 1 : tm;

n_out_per_tgp = bm * tm;
n_out_per_tgp = bm * sm * tm;
kname << "gemv_" << type_to_name(out);
}

kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
<< tm << "_tn" << tn;
kname << "_nc" << !contiguous_kernel << "_axpby0";

// Encode and dispatch kernel
@ -826,7 +835,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setComputePipelineState(kernel);

int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(bn, bm, 1);
MTL::Size group_dims = MTL::Size(32, bn, bm);
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
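
// Worked example (hypothetical sizes, for illustration only): a
// non-transposed gemv with out_vector_len = 4096 takes bm = 8, sm = 1,
// sn = 32, tm = 4, so n_out_per_tgp = 8 * 1 * 4 = 32,
// n_tgp = ceil(4096 / 32) = 128, and each threadgroup launches
// 32 x 1 x 8 = 256 threads.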

compute_encoder.set_input_array(mat, 0);
@ -838,11 +847,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&mat_ld, sizeof(int), 6);

compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
compute_encoder->setBytes(batch_shape.data(), batch_ndim * sizeof(int), 10);
compute_encoder->setBytes(
batch_strides_vec.data(), batch_ndim * sizeof(size_t), 11);
compute_encoder->setBytes(
batch_strides_mat.data(), batch_ndim * sizeof(size_t), 12);
set_vector_bytes(compute_encoder, batch_shape, 10);
set_vector_bytes(compute_encoder, batch_strides_vec, 11);
set_vector_bytes(compute_encoder, batch_strides_mat, 12);

compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

@ -910,15 +917,19 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep

int M = a_pre.shape(-2);
int N = b_pre.shape(-1);
int K = a_pre.shape(-1);

// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
auto check_transpose = [&copies, &s](const array& arr) {
auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (sty == 1) {
if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1) {
} else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
@ -929,12 +940,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
}
};

auto [transpose_a, a_cols, a] = check_transpose(a_pre);
auto [transpose_b, b_cols, b] = check_transpose(b_pre);

int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);
auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1);
auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1);

array c = c_pre;
int ldc = c.strides()[c.ndim() - 2];
@ -997,38 +1004,47 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {

// Determine dispatch kernel
int tm = 4, tn = 4;
int bm, bn, n_out_per_tgp;
int sm = 1, sn = 32;
int bm = 1, bn = 1;
int n_out_per_tgp;
std::ostringstream kname;

if (transpose_mat) {
bm = 8;
bn = 8;
if (out_vector_len >= 24576) {
bn = 128;
} else if (out_vector_len >= 16384) {
bn = 64;
} else if (out_vector_len >= 8192) {
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
sm = 4;
sn = 8;
} else {
sm = 8;
sn = 4;
}

if (out_vector_len >= 2048) {
bn = 16;
} else if (out_vector_len >= 512) {
bn = 4;
} else {
bn = 2;
}

// Specialized kernel for very small outputs
tn = out_vector_len < tn ? 1 : tn;

n_out_per_tgp = bn * tn;
n_out_per_tgp = bn * sn * tn;
kname << "gemv_t_" << type_to_name(out);

} else {
bm = out_vector_len >= 4096 ? 8 : 4;
bn = 32;
sn = 32;

// Specialized kernel for very small outputs
tm = out_vector_len < tm ? 1 : tm;

n_out_per_tgp = bm * tm;
n_out_per_tgp = bm * sm * tm;
kname << "gemv_" << type_to_name(out);
}

kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
<< tm << "_tn" << tn;
kname << "_nc" << !contiguous_kernel << "_axpby1";

// Encode and dispatch kernel
@ -1037,7 +1053,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setComputePipelineState(kernel);

int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(bn, bm, 1);
MTL::Size group_dims = MTL::Size(32, bn, bm);
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);

compute_encoder.set_input_array(mat, 0);
@ -1344,15 +1360,19 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep

int M = a_pre.shape(-2);
int N = b_pre.shape(-1);
int K = a_pre.shape(-1);

// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
auto check_transpose = [&copies, &s](const array& arr) {
auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (sty == 1) {
if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1) {
} else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
@ -1363,33 +1383,38 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
}
};

auto [transpose_a, a_cols, a] = check_transpose(a_pre);
auto [transpose_b, b_cols, b] = check_transpose(b_pre);
auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1);
auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1);

int lda = a_cols;
int ldb = b_cols;

int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);

/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions

bool has_op_mask = inputs.size() > 3;
bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;

// Prepare kernel name
std::string out_mask_nm = has_out_mask ? type_to_name(inputs[2]) : "nomask";
std::string op_mask_nm = has_op_mask ? type_to_name(inputs.back()) : "nomask";

auto get_batch_dims = [](const auto& v) {
return decltype(v){v.begin(), v.end() - 2};
};

std::vector<int> batch_shape{1};
std::vector<size_t> A_batch_stride{0};
std::vector<size_t> B_batch_stride{0};
std::vector<size_t> outmask_bstride{0};
std::vector<size_t> Amask_bstride{0};
std::vector<size_t> Bmask_bstride{0};
size_t A_batch_str = 0;
size_t B_batch_str = 0;

std::vector<size_t> batch_strides;

if (out.ndim() > 2) {
auto get_batch_dims = [](const auto& v) {
return decltype(v){v.begin(), v.end() - 2};
};

std::vector<int> bshape{out.shape().begin(), out.shape().end() - 2};
std::vector<std::vector<size_t>> bstrides;

@ -1397,14 +1422,26 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
bstrides.emplace_back(arr.strides().begin(), arr.strides().end() - 2);
}

auto [bshape_c, bstrides_c] = collapse_contiguous_dims(bshape, bstrides);
batch_shape = bshape_c;
A_batch_str = bstrides_c[0].back();
B_batch_str = bstrides_c[1].back();
// auto [bshape_c, bstrides_c] = collapse_contiguous_dims(bshape, bstrides);
batch_shape = bshape;
A_batch_str = bstrides[0].back();
B_batch_str = bstrides[1].back();

for (auto& bstr : bstrides_c) {
for (auto& bstr : bstrides) {
batch_strides.insert(batch_strides.end(), bstr.begin(), bstr.end());
}

A_batch_stride = bstrides[0];
B_batch_stride = bstrides[1];

if (has_out_mask) {
outmask_bstride = bstrides[2];
}
if (has_op_mask) {
Amask_bstride = bstrides[has_out_mask + 2];
Bmask_bstride = bstrides[has_out_mask + 3];
}

} else {
batch_strides = std::vector<size_t>(inputs.size(), 0);
}
@ -1412,6 +1449,174 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
size_t matrix_stride_out = size_t(M) * N;
size_t batch_size_out = out.size() / (matrix_stride_out);

/////////////////////////////////////////////////////////////////////////////
// Gemv specialization

// Route to gemv if needed
if (std::min(M, N) == 1) {
// Collect problem info
bool is_b_matrix = N != 1;

auto& mat = is_b_matrix ? b : a;
auto& vec = is_b_matrix ? a : b;
bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
int in_vector_len = K;
int out_vector_len = is_b_matrix ? N : M;

int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
int mat_ld = is_b_matrix ? b_cols : a_cols;

auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;
auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride;

auto mask_bstrides_mat = is_b_matrix ? Bmask_bstride : Amask_bstride;
auto mask_bstrides_vec = is_b_matrix ? Amask_bstride : Bmask_bstride;

auto mat_mask_idx = int(has_out_mask) + (is_b_matrix ? 3 : 2);
auto vec_mask_idx = int(has_out_mask) + (is_b_matrix ? 2 : 3);

// Determine if inputs have simple batching / broadcasting
bool contiguous_kernel = (batch_shape.size() == 1);
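
// Note (added for clarity): contiguous_kernel selects the "_nc0" kernel
// variant, which indexes each batch with a single stride (tid.z * stride[0]);
// the "_nc1" variant goes through the general elem_to_loc path for
// multi-dimensional batch shapes.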

int batch_ndim = batch_shape.size();

// Determine dispatch kernel
int tm = 4, tn = 4;
int sm = 1, sn = 32;
int bm = 1, bn = 1;
int n_out_per_tgp;
std::ostringstream kname;

if (transpose_mat) {
sm = 8;
sn = 4;
bm = 1;
bn = (block_size_ == 64 && out_vector_len >= 2048) ? 4 : 2;
tm = block_size_ == 32 ? 4 : 8;
tn = 4;

// Specialized kernel for very small outputs
tn = out_vector_len < tn ? 1 : tn;

n_out_per_tgp = bn * sn * tn;
kname << "gemv_t";

} else {
if (block_size_ == 32) {
sm = 4;
sn = 8;
bm = 2;
} else {
sm = 2;
sn = 16;
bm = out_vector_len >= 512 ? 4 : 2;
}

// Specialized kernel for very small outputs
tm = out_vector_len < tm ? 1 : tm;

n_out_per_tgp = bm * sm * tm;
kname << "gemv";
}

kname << "_outmask_" << out_mask_nm;
kname << "_opmask_" << op_mask_nm;
kname << "_" << type_to_name(out);
kname << "_bm" << bm << "_bn" << bn;
kname << "_sm" << sm << "_sn" << sn;
kname << "_tm" << tm << "_tn" << tn;
kname << "_nc" << !contiguous_kernel;

// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);

int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm);
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);

// Get mask params
std::vector<int> mask_strides;
std::vector<size_t> mask_batch_strides;
if (has_out_mask) {
auto& out_mask = inputs[2];

if (transpose_mat) {
mask_strides.push_back(out_mask.strides(out.shape(-2) == 1 ? -1 : -2));
mask_strides.push_back(out_mask.strides(out.shape(-2) == 1 ? -2 : -1));
} else {
mask_strides.push_back(out_mask.strides(out.shape(-1) == 1 ? -1 : -2));
mask_strides.push_back(out_mask.strides(out.shape(-1) == 1 ? -2 : -1));
}

mask_batch_strides.insert(
mask_batch_strides.end(),
outmask_bstride.begin(),
outmask_bstride.end());

compute_encoder.set_input_array(out_mask, 20);
}

if (has_op_mask) {
auto& mat_mask = inputs[mat_mask_idx];

if (transpose_mat) {
mask_strides.push_back(mat_mask.strides(!is_b_matrix ? -2 : -1));
mask_strides.push_back(mat_mask.strides(!is_b_matrix ? -1 : -2));
} else {
mask_strides.push_back(mat_mask.strides(is_b_matrix ? -2 : -1));
mask_strides.push_back(mat_mask.strides(is_b_matrix ? -1 : -2));
}

mask_batch_strides.insert(
mask_batch_strides.end(),
mask_bstrides_mat.begin(),
mask_bstrides_mat.end());

compute_encoder.set_input_array(mat_mask, 21);

auto& vec_mask = inputs[vec_mask_idx];
if (transpose_mat) {
mask_strides.push_back(vec_mask.strides(vec.shape(-2) == 1 ? -1 : -2));
mask_strides.push_back(vec_mask.strides(vec.shape(-2) == 1 ? -2 : -1));
} else {
mask_strides.push_back(vec_mask.strides(vec.shape(-1) == 1 ? -1 : -2));
mask_strides.push_back(vec_mask.strides(vec.shape(-1) == 1 ? -2 : -1));
}

mask_batch_strides.insert(
mask_batch_strides.end(),
mask_bstrides_vec.begin(),
mask_bstrides_vec.end());

compute_encoder.set_input_array(vec_mask, 22);
}

// Get gemv params
compute_encoder.set_input_array(mat, 0);
compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3);

compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
set_vector_bytes(compute_encoder, batch_shape, 10);
set_vector_bytes(compute_encoder, batch_strides_vec, 11);
set_vector_bytes(compute_encoder, batch_strides_mat, 12);

set_vector_bytes(compute_encoder, mask_strides, 23);
set_vector_bytes(compute_encoder, mask_batch_strides, 24);

compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
}

/////////////////////////////////////////////////////////////////////////////
// Regular kernel dispatch

@ -1421,10 +1626,6 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
bool mn_aligned = M % bm == 0 && N % bn == 0;
bool k_aligned = K % bk == 0;

// Prepare kernel name
std::string out_mask_nm = has_out_mask ? type_to_name(inputs[2]) : "nomask";
std::string op_mask_nm = has_op_mask ? type_to_name(inputs.back()) : "nomask";

std::ostringstream kname;
kname << "steel_gemm_block_outmask_" << out_mask_nm << "_opmask_"
<< op_mask_nm << "_" << (transpose_a ? 't' : 'n')
@ -1554,15 +1755,19 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep

int M = a_pre.shape(-2);
int N = b_pre.shape(-1);
int K = a_pre.shape(-1);

// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
auto check_transpose = [&copies, &s](const array& arr) {
auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (sty == 1) {
if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1) {
} else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
@ -1573,16 +1778,12 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
}
};

auto [transpose_a, a_cols, a] = check_transpose(a_pre);
auto [transpose_b, b_cols, b] = check_transpose(b_pre);
auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1);
auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1);

int lda = a_cols;
int ldb = b_cols;

int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);

/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions

@ -1673,38 +1874,47 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {

// Determine dispatch kernel
int tm = 4, tn = 4;
int bm, bn, n_out_per_tgp;
int sm = 1, sn = 32;
int bm = 1, bn = 1;
int n_out_per_tgp;
std::ostringstream kname;

if (transpose_mat) {
bm = 8;
bn = 8;
if (out_vector_len >= 24576) {
bn = 128;
} else if (out_vector_len >= 16384) {
bn = 64;
} else if (out_vector_len >= 8192) {
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
sm = 4;
sn = 8;
} else {
sm = 8;
sn = 4;
}

if (out_vector_len >= 2048) {
bn = 16;
} else if (out_vector_len >= 512) {
bn = 4;
} else {
bn = 2;
}

// Specialized kernel for very small outputs
tn = out_vector_len < tn ? 1 : tn;

n_out_per_tgp = bn * tn;
kname << "gemv_t_bs_" << type_to_name(out);
n_out_per_tgp = bn * sn * tn;
kname << "gemv_t_gather_" << type_to_name(out);

} else {
bm = out_vector_len >= 4096 ? 8 : 4;
bn = 32;
sn = 32;

// Specialized kernel for very small outputs
tm = out_vector_len < tm ? 1 : tm;

n_out_per_tgp = bm * tm;
kname << "gemv_bs_" << type_to_name(out);
n_out_per_tgp = bm * sm * tm;
kname << "gemv_gather_" << type_to_name(out);
}

kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
<< tm << "_tn" << tn;

// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
@ -1712,7 +1922,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setComputePipelineState(kernel);

int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(bn, bm, 1);
MTL::Size group_dims = MTL::Size(32, bn, bm);
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);

compute_encoder.set_input_array(mat, 0);
@ -3634,16 +3634,19 @@ std::vector<array> BlockMaskedMM::vjp(
};

// Prepare for padding if needed
int M = cotan.shape(-2);
int N = cotan.shape(-1);
int K = primals[0].shape(-1);
int align_M = (M % block_size_);
int align_N = (N % block_size_);
int align_K = (K % block_size_);
const int M = cotan.shape(-2);
const int N = cotan.shape(-1);
const int K = primals[0].shape(-1);
const int tm = (M + block_size_ - 1) / block_size_;
const int tn = (N + block_size_ - 1) / block_size_;
const int tk = (K + block_size_ - 1) / block_size_;
const int align_M = tm * block_size_ - M;
const int align_N = tn * block_size_ - N;
const int align_K = tk * block_size_ - K;
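
// Worked example (illustrative): with M = 1 and block_size_ = 32,
// tm = (1 + 31) / 32 = 1 and align_M = 1 * 32 - 1 = 31, i.e. the number of
// rows needed to round M up to a whole block.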

// Potential intermediates
auto unmasked_lhs_grad = primals[0];
auto unmasked_rhs_grad = primals[1];
array unmasked_lhs_grad = primals[0];
array unmasked_rhs_grad = primals[1];

bool unmasked_lhs_grad_calculated = false;
bool unmasked_rhs_grad_calculated = false;
@ -263,7 +263,9 @@ class TestBlas(mlx_tests.MLXTestCase):
mlx_mat_f=lambda x: x,
mlx_vec_f=lambda x: x,
):
with self.subTest(shape=shape_mat):
with self.subTest(
shape_mat=shape_mat, shape_vec=shape_vec, mat_first=mat_first
):
np.random.seed(42)
scale = max(np.sum(shape_mat), 32)
mat_npy = np.random.normal(0.0, 1.0 / scale, shape_mat).astype(np_dtype)
@ -794,10 +796,12 @@ class TestBlas(mlx_tests.MLXTestCase):
out_ref, dout_ref = mx.vjp(f_ref, [a, b], [cotan])
out_test, dout_test = mx.vjp(f_test, [a, b], [cotan])

mx.eval((out_ref, dout_ref, out_test, dout_test))

self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-5).item())

for r, t in zip(dout_ref, dout_test):
self.assertEqual(r.shape, t.shape)
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())

def run_test_mask_vjp(a, b, block_size, out_mask, a_mask, b_mask, cotan):
def f_ref(a_, b_, a_mask_, b_mask_):
return ref_block_masked_mm(
@ -896,6 +900,8 @@ class TestBlas(mlx_tests.MLXTestCase):
(64, 64, 16, 32),
(128, 128, 128, 32),
(256, 256, 128, 64),
(1, 128, 128, 32),
(256, 1, 128, 64),
)

for M, N, K, block_size in shapes:
@ -903,21 +909,51 @@ class TestBlas(mlx_tests.MLXTestCase):

# Test broadcasting
test_shape(64, 64, 64, 32, batch_A=(1, 2), batch_B=(2, 2))
test_shape(1, 128, 128, 32, batch_A=(1, 2), batch_B=(2, 2))
test_shape(128, 1, 128, 32, batch_A=(1, 2), batch_B=(2, 2))

# Test gemv
a_np = np.random.normal(size=(64, 64)).astype(np.float32)
b_np = np.random.normal(size=(64,)).astype(np.float32)
mask_np = np.array([True, False]).astype(np.bool_)
a_np = np.ones((128, 256)).astype(np.float32)
b_np = np.ones((128, 1)).astype(np.float32)
d_np = np.ones((1, 256)).astype(np.float32)
a_mask_np = np.random.normal(size=(4, 8)).astype(np.float32)
b_mask_np = np.ones((4, 1)).astype(np.bool_)
d_mask_np = np.ones((1, 8)).astype(np.bool_)
c_mask_np = np.random.normal(size=(8, 1)).astype(np.float32)
e_mask_np = np.random.normal(size=(1, 4)).astype(np.float32)

a_mask_np[a_mask_np < 0.0] = 0.0
e_mask_np[e_mask_np < 0.0] = 0.0
c_mask_np[c_mask_np < 0.0] = 0.0

a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
mask_mx = mx.array(mask_np)
d_mx = mx.array(d_np)
a_mask_mx = mx.array(a_mask_np)
b_mask_mx = mx.array(b_mask_np)
d_mask_mx = mx.array(d_mask_np)
e_mask_mx = mx.array(e_mask_np)
c_mask_mx = mx.array(c_mask_np)

c_mx = mx.block_masked_mm(a_mx, b_mx, 32, mask_mx)
c_np = a_np @ b_np
c_np[32:] = 0.0
c_mx = mx.block_masked_mm(a_mx.T, b_mx, 32, c_mask_mx, a_mask_mx.T, b_mask_mx)
e_mx = mx.block_masked_mm(d_mx, a_mx.T, 32, e_mask_mx, d_mask_mx, a_mask_mx.T)
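
# Note (added for clarity): the NumPy references below expand each block mask
# from block granularity to element granularity (32 x 32 blocks) before
# applying it, so they should match what block_masked_mm computes on the GPU.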

a_mask_np = np.broadcast_to(np.expand_dims(a_mask_np, (-3, -1)), (4, 32, 8, 32))
a_mask_np = a_mask_np.reshape((128, 256))
a_np *= a_mask_np

c_np = a_np.T @ b_np
e_np = d_np @ a_np.T

c_mask_np = np.broadcast_to(np.expand_dims(c_mask_np, (-2)), (8, 32, 1))
c_mask_np = c_mask_np.reshape((256, 1))
c_np *= c_mask_np

e_mask_np = np.broadcast_to(np.expand_dims(e_mask_np, (-1)), (1, 4, 32))
e_mask_np = e_mask_np.reshape((1, 128))
e_np *= e_mask_np

self.assertTrue(np.allclose(c_mx, c_np, atol=1e-5))
self.assertTrue(np.allclose(e_mx, e_np, atol=1e-5))

def test_gather_matmul(self):
def np_gather_mm(a, b, lhs_indices=None, rhs_indices=None):