// Copyright © 2023-2024 Apple Inc.

#include <metal_simdgroup>
#include <metal_stdlib>

#include "mlx/backend/metal/kernels/utils.h"

#include "mlx/backend/metal/kernels/steel/utils.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
/// Matrix vector multiplication
///////////////////////////////////////////////////////////////////////////////

#define MLX_MTL_CONST static constant constexpr const

template <
    typename T,
    const int BM, /* Threadgroup rows (in simdgroups) */
    const int BN, /* Threadgroup cols (in simdgroups) */
    const int SM, /* Simdgroup rows (in threads) */
    const int SN, /* Simdgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN, /* Thread cols (in elements) */
    const bool kDoAxpby, /* Do out = alpha * out + beta * bias */
    typename AccT = float>
struct GEMVKernel {
  MLX_MTL_CONST int threadsM = BM * SM;
  MLX_MTL_CONST int threadsN = BN * SN;

  MLX_MTL_CONST int blockM = threadsM * TM;
  MLX_MTL_CONST int blockN = threadsN * TN;

  static_assert(SM * SN == 32, "simdgroup can only have 32 threads");

  static_assert(
      SN == 8 || SN == 16 || SN == 32,
      "gemv block must have a width of 8, 16, or 32");

  // - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up
  //   into blocks of (blockM, blockN) divided among threadgroups
  // - Every thread works on a block of (TM, TN)
  // - We assume each threadgroup has (threadsN, threadsM, 1) threads
  //
  // 1. A thread loads TN elements each from mat along TM rows
  //    and the corresponding scalar from the vector
  // 2. The thread then multiplies and adds to accumulate its local result for
  //    the block
  // 3. At the end, each thread has accumulated results over all blocks across
  //    the rows. These are then summed up across the threadgroup
  // 4. Each threadgroup writes its accumulated blockM outputs
  //
  // Edge case handling:
  // - The threadgroup with the largest tid has blocks that exceed the matrix
  //   * The blocks that start outside the matrix are never read (thread results
  //     remain zero)
  //   * The last thread that partially overlaps with the matrix is shifted
  //     inwards such that the thread block fits exactly in the matrix
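  //
  // Illustrative example (not part of the kernel logic): with the instantiated
  // configuration BM = 4, BN = 1, SM = 1, SN = 32, TM = 4, TN = 4 used below,
  //   threadsM = BM * SM = 4,         threadsN = BN * SN = 32,
  //   blockM   = threadsM * TM = 16,  blockN   = threadsN * TN = 128,
  // so a threadgroup of 4 * 1 * 32 = 128 threads produces 16 output rows and
  // each pass of the main loop consumes 128 columns of the matrix. Since
  // BN == 1 in that configuration, the simdgroup shuffle reduction alone
  // finishes the row sums and no threadgroup memory is needed
  // (tgp_mem_size == 0).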

  MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN * (blockM + TM) : 0;
  MLX_MTL_CONST bool needs_tgp_reduction = BN > 1;

  template <typename U = T>
  static METAL_FUNC void
  load_unsafe(const device T* src, thread U dst[TN], const int src_offset = 0) {
    MLX_MTL_PRAGMA_UNROLL
    for (int tn = 0; tn < TN; tn++) {
      dst[tn] = static_cast<U>(src[src_offset + tn]);
    }
  }

  template <typename U = T>
  static METAL_FUNC void load_safe(
      const device T* src,
      thread U dst[TN],
      const int src_offset = 0,
      const int src_size = TN) {
    if (src_offset + TN <= src_size) {
      MLX_MTL_PRAGMA_UNROLL
      for (int tn = 0; tn < TN; tn++) {
        dst[tn] = static_cast<U>(src[src_offset + tn]);
      }
    } else { // Edge case
      MLX_MTL_PRAGMA_UNROLL
      for (int tn = 0; tn < TN; tn++) {
        dst[tn] = src_offset + tn < src_size
            ? static_cast<U>(src[src_offset + tn])
            : U(0);
      }
    }
  }

  static METAL_FUNC void run(
      const device T* mat [[buffer(0)]],
      const device T* in_vec [[buffer(1)]],
      const device T* bias [[buffer(2)]],
      device T* out_vec [[buffer(3)]],
      const constant int& in_vec_size [[buffer(4)]],
      const constant int& out_vec_size [[buffer(5)]],
      const constant int& matrix_ld [[buffer(6)]],
      const constant float& alpha [[buffer(7)]],
      const constant float& beta [[buffer(8)]],
      const constant int& bias_stride [[buffer(14)]],
      threadgroup AccT* tgp_memory [[threadgroup(0)]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]],
      uint simd_gid [[simdgroup_index_in_threadgroup]],
      uint simd_lid [[thread_index_in_simdgroup]]) {
    // Appease compiler
    (void)lid;

    // Thread local accumulation results
    thread AccT result[TM] = {0};
    thread T inter[TN];
    thread AccT v_coeff[TN];

    const int thrM = SN != 32 ? simd_lid / SN : 0;
    const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);

    const int sgN = BN != 1 ? (simd_gid % BN) : 0;

    const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid);
    const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0;

    int bm = (simdM + thrM) * TM;
    int bn = (simdN + thrN) * TN;

    // Block position
    int out_row = tid.x * blockM + bm;

    // Exit simdgroup if rows are out of bounds
    if (out_row >= out_vec_size)
      return;

    // Adjust tail simdgroup to ensure in-bounds reads
    out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;

    // Advance matrix
    mat += out_row * matrix_ld;

    constexpr const uniform<int> loop_stride = make_uniform(blockN);
    const uniform<int> in_size = make_uniform(in_vec_size);
    const uniform<int> n_iter = in_size / loop_stride;
    const uniform<int> last_iter = loop_stride * n_iter;
    const uniform<int> leftover = in_size - last_iter;

    // Loop over in_vec in blocks of blockN
    for (int i = 0; i < n_iter; ++i) {
      load_unsafe<AccT>(in_vec, v_coeff, bn);

      // Per thread work loop
      int mat_offset = 0;
      MLX_MTL_PRAGMA_UNROLL
      for (int tm = 0; tm < TM; tm++) {
        // Load for the row
        load_unsafe(mat, inter, mat_offset + bn);

        // Accumulate results
        MLX_MTL_PRAGMA_UNROLL
        for (int tn = 0; tn < TN; tn++) {
          result[tm] += inter[tn] * v_coeff[tn];
        }

        mat_offset += matrix_ld;
      }

      bn += blockN;
    }

    if (leftover > 0) {
      load_safe<AccT>(in_vec, v_coeff, bn, in_size);

      // Per thread work loop
      MLX_MTL_PRAGMA_UNROLL
      for (int tm = 0; tm < TM; tm++) {
        // Load for the row
        load_safe(&mat[tm * matrix_ld], inter, bn, in_size);

        // Accumulate results
        MLX_MTL_PRAGMA_UNROLL
        for (int tn = 0; tn < TN; tn++) {
          result[tm] += inter[tn] * v_coeff[tn];
        }
      }
    }

    // Simdgroup accumulations
    MLX_MTL_PRAGMA_UNROLL
    for (int tm = 0; tm < TM; tm++) {
      MLX_MTL_PRAGMA_UNROLL
      for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) {
        result[tm] += simd_shuffle_down(result[tm], sn);
      }
    }

    // Threadgroup accumulation results
    if (needs_tgp_reduction) {
      threadgroup AccT* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
      if (thrN == 0) {
        MLX_MTL_PRAGMA_UNROLL
        for (int tm = 0; tm < TM; tm++) {
          tgp_results[tm] = result[tm];
        }

        threadgroup_barrier(mem_flags::mem_none);

        if (sgN == 0) {
          MLX_MTL_PRAGMA_UNROLL
          for (int sgn = 1; sgn < BN; sgn++) {
            MLX_MTL_PRAGMA_UNROLL
            for (int tm = 0; tm < TM; tm++) {
              result[tm] += tgp_results[sgn * (blockM + TM) + tm];
            }
          }
        }
      }
    }

    // Write outputs
    if (simdN == 0 && thrN == 0) {
      MLX_MTL_PRAGMA_UNROLL
      for (int tm = 0; tm < TM; tm++) {
        if (kDoAxpby) {
          out_vec[out_row + tm] =
              static_cast<T>(alpha) * static_cast<T>(result[tm]) +
              static_cast<T>(beta) * bias[(out_row + tm) * bias_stride];
        } else {
          out_vec[out_row + tm] = static_cast<T>(result[tm]);
        }
      }
    }
  }
};

///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    const int BM, /* Threadgroup rows (in simdgroups) */
    const int BN, /* Threadgroup cols (in simdgroups) */
    const int SM, /* Simdgroup rows (in threads) */
    const int SN, /* Simdgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN, /* Thread cols (in elements) */
    const bool kDoAxpby, /* Do out = alpha * out + beta * bias */
    typename AccT = float>
struct GEMVTKernel {
  MLX_MTL_CONST int threadsM = BM * SM;
  MLX_MTL_CONST int threadsN = BN * SN;

  MLX_MTL_CONST int blockM = threadsM * TM;
  MLX_MTL_CONST int blockN = threadsN * TN;

  static_assert(SM * SN == 32, "simdgroup can only have 32 threads");

  // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
  //   into blocks of (blockM, blockN) divided among threadgroups
  // - Every thread works on a block of (TM, TN)
  // - We assume each threadgroup has (threadsN, threadsM, 1) threads
  //
  // 1. A thread loads TN elements each from mat along TM contiguous rows
  //    and the corresponding scalar from the vector
  // 2. The thread then accumulates its local result for the block
  // 3. At the end, each thread has accumulated results over all blocks across
  //    the rows. These are then summed up across the threadgroup
  // 4. Each threadgroup writes its accumulated blockN outputs
  //
  // Edge case handling:
  // - The threadgroup with the largest tid has blocks that exceed the matrix
  //   * The blocks that start outside the matrix are never read (thread results
  //     remain zero)
  //   * The last thread that partially overlaps with the matrix is shifted
  //     inwards such that the thread block fits exactly in the matrix
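  //
  // Illustrative example (not part of the kernel logic): with the instantiated
  // configuration BM = 1, BN = 2, SM = 8, SN = 4, TM = 4, TN = 4 used below,
  //   threadsM = BM * SM = 8,         threadsN = BN * SN = 8,
  //   blockM   = threadsM * TM = 32,  blockN   = threadsN * TN = 32,
  // so a threadgroup of 1 * 2 * 32 = 64 threads produces 32 output columns and
  // each pass of the main loop consumes 32 rows of the matrix. Since BM == 1
  // in that configuration, the simdgroup shuffle reduction alone finishes the
  // column sums and no threadgroup memory is needed (tgp_mem_size == 0).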

  MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM * (blockN + TN) : 0;
  MLX_MTL_CONST bool needs_tgp_reduction = BM > 1;

  static METAL_FUNC void run(
      const device T* mat [[buffer(0)]],
      const device T* in_vec [[buffer(1)]],
      const device T* bias [[buffer(2)]],
      device T* out_vec [[buffer(3)]],
      const constant int& in_vec_size [[buffer(4)]],
      const constant int& out_vec_size [[buffer(5)]],
      const constant int& matrix_ld [[buffer(6)]],
      const constant float& alpha [[buffer(7)]],
      const constant float& beta [[buffer(8)]],
      const constant int& bias_stride [[buffer(14)]],
      threadgroup AccT* tgp_memory [[threadgroup(0)]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]],
      uint simd_gid [[simdgroup_index_in_threadgroup]],
      uint simd_lid [[thread_index_in_simdgroup]]) {
    // Appease compiler
    (void)lid;

    // Thread local accumulation results
    AccT result[TN] = {0};
    T inter[TN];
    AccT v_coeff[TM];
    const int thrM = SN != 32 ? simd_lid / SN : 0;
    const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);

    const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid);
    const int sgN = BN != 1 ? (simd_gid % BN) : 0;

    const int simdM = SM * sgM;
    const int simdN = SN * sgN;

    int cm = (simdM + thrM);
    int cn = (simdN + thrN);

    int bm = cm * TM;
    int bn = cn * TN;

    int out_col = tid.x * blockN + bn;

    constexpr const uniform<int> loop_stride = make_uniform(blockM);
    const uniform<int> in_size = make_uniform(in_vec_size);
    const uniform<int> n_iter = in_size / loop_stride;
    const uniform<int> last_iter = loop_stride * n_iter;
    const uniform<int> leftover = in_size - last_iter;

    // Edge case handling
    if (out_col < out_vec_size) {
      out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN;

      // Per thread accumulation main loop
      for (int i = 0; i < n_iter; ++i) {
        // Adding a threadgroup_barrier improves performance slightly,
        // possibly because it helps exploit the cache better
        threadgroup_barrier(mem_flags::mem_none);

        MLX_MTL_PRAGMA_UNROLL
        for (int tm = 0; tm < TM; tm++) {
          v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
        }

        MLX_MTL_PRAGMA_UNROLL
        for (int tm = 0; tm < TM; tm++) {
          auto vc = static_cast<AccT>(v_coeff[tm]);
          for (int tn = 0; tn < TN; tn++) {
            inter[tn] = mat[(bm + tm) * matrix_ld + out_col + tn];
          }
          for (int tn = 0; tn < TN; tn++) {
            result[tn] += vc * inter[tn];
          }
        }

        bm += blockM;
      }

      if (leftover > 0) {
        for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
          v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);

          MLX_MTL_PRAGMA_UNROLL
          for (int tn = 0; tn < TN; tn++) {
            inter[tn] = mat[(bm + tm) * matrix_ld + out_col + tn];
          }

          MLX_MTL_PRAGMA_UNROLL
          for (int tn = 0; tn < TN; tn++) {
            result[tn] += v_coeff[tm] * inter[tn];
          }
        }
      }
    }

    // Simdgroup accumulations
    // (lanes that handle the same output columns are SN apart, hence the
    //  SN * sm shuffle offsets)
    MLX_MTL_PRAGMA_UNROLL
    for (int tn = 0; tn < TN; tn++) {
      MLX_MTL_PRAGMA_UNROLL
      for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) {
        result[tn] += simd_shuffle_down(result[tn], SN * sm);
      }
    }

    // Threadgroup accumulation results
    if (needs_tgp_reduction) {
      threadgroup AccT* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
      if (thrM == 0) {
        MLX_MTL_PRAGMA_UNROLL
        for (int tn = 0; tn < TN; tn++) {
          tgp_results[tn] = result[tn];
        }

        threadgroup_barrier(mem_flags::mem_none);

        if (sgM == 0) {
          MLX_MTL_PRAGMA_UNROLL
          for (int sgm = 1; sgm < BM; sgm++) {
            MLX_MTL_PRAGMA_UNROLL
            for (int tn = 0; tn < TN; tn++) {
              result[tn] += tgp_results[sgm * (blockN + TN) + tn];
            }
          }
        }
      }
    }

    // Write out results
    if (cm == 0 && out_col < out_vec_size) {
      MLX_MTL_PRAGMA_UNROLL
      for (int j = 0; j < TN; j++) {
        if (kDoAxpby) {
          out_vec[out_col + j] =
              static_cast<T>(alpha) * static_cast<T>(result[j]) +
              static_cast<T>(beta) * bias[(out_col + j) * bias_stride];
        } else {
          out_vec[out_col + j] = static_cast<T>(result[j]);
        }
      }
    }
  }
};

///////////////////////////////////////////////////////////////////////////////
/// Matrix vector multiplication
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    const int BM, /* Threadgroup rows (in simdgroups) */
    const int BN, /* Threadgroup cols (in simdgroups) */
    const int SM, /* Simdgroup rows (in threads) */
    const int SN, /* Simdgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN, /* Thread cols (in elements) */
    const bool kDoNCBatch, /* Batch ndim > 1 */
    const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    const device T* bias [[buffer(2)]],
    device T* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& matrix_ld [[buffer(6)]],
    const constant float& alpha [[buffer(7)]],
    const constant float& beta [[buffer(8)]],
    const constant int& batch_ndim [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant int64_t* vector_batch_stride [[buffer(11)]],
    const constant int64_t* matrix_batch_stride [[buffer(12)]],
    const constant int64_t* bias_batch_stride [[buffer(13)]],
    const constant int& bias_stride [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
  threadgroup float tgp_memory
      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];

  // Update batch offsets
  if (kDoNCBatch) {
    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);

    if (kDoAxpby) {
      bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim);
    }

  } else {
    in_vec += tid.z * vector_batch_stride[0];
    mat += tid.z * matrix_batch_stride[0];

    if (kDoAxpby) {
      bias += tid.z * bias_batch_stride[0];
    }
  }

  out_vec += tid.z * out_vec_size;

  gemv_kernel::run(
      mat,
      in_vec,
      bias,
      out_vec,
      in_vec_size,
      out_vec_size,
      matrix_ld,
      alpha,
      beta,
      bias_stride,
      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
      tid,
      lid,
      simd_gid,
      simd_lid);
}

#define instantiate_gemv_helper( \
    name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
  instantiate_kernel( \
      "gemv_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
      "_tn" #tn "_nc" #nc "_axpby" #axpby, \
      gemv, \
      itype, \
      bm, \
      bn, \
      sm, \
      sn, \
      tm, \
      tn, \
      nc, \
      axpby)

// clang-format off
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
  instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 0, 0) \
  instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 0, 1) \
  instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 1, 0) \
  instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 1, 1) // clang-format on

// clang-format off
#define instantiate_gemv_blocks(name, itype) \
  instantiate_gemv(name, itype, 4, 32, 1, 4) \
  instantiate_gemv(name, itype, 4, 32, 4, 4) \
  instantiate_gemv(name, itype, 8, 32, 4, 4) // clang-format on
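
// For reference, each instantiation above registers a kernel whose name
// encodes its tile parameters; for example, instantiate_gemv(float32, float,
// 4, 32, 4, 4) expands (via instantiate_gemv_helper) to kernels such as
//   "gemv_float32_bm4_bn1_sm1_sn32_tm4_tn4_nc0_axpby0"
// with the _nc/_axpby suffixes covering the four batch/axpby variants.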

instantiate_gemv_blocks(float32, float);
instantiate_gemv_blocks(float16, half);
instantiate_gemv_blocks(bfloat16, bfloat16_t);

template <
    typename T,
    const int BM, /* Threadgroup rows (in simdgroups) */
    const int BN, /* Threadgroup cols (in simdgroups) */
    const int SM, /* Simdgroup rows (in threads) */
    const int SN, /* Simdgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN> /* Thread cols (in elements) */
[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_gather(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    const device T* bias [[buffer(2)]],
    device T* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& matrix_ld [[buffer(6)]],
    const constant float& alpha [[buffer(7)]],
    const constant float& beta [[buffer(8)]],
    const constant int& batch_ndim [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant int64_t* index_batch_strides [[buffer(11)]],
    const constant int& vector_batch_ndim [[buffer(12)]],
    const constant int* vector_batch_shape [[buffer(13)]],
    const constant int64_t* vector_batch_stride [[buffer(14)]],
    const constant int& matrix_batch_ndim [[buffer(15)]],
    const constant int* matrix_batch_shape [[buffer(16)]],
    const constant int64_t* matrix_batch_stride [[buffer(17)]],
    const constant uint32_t* vec_indices [[buffer(18)]],
    const constant uint32_t* mat_indices [[buffer(19)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, false>;
  threadgroup float tgp_memory
      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];

  uint32_t indx_vec;
  uint32_t indx_mat;

  // Update batch offsets
  if (batch_ndim > 1) {
    const constant auto* veci_bstrides = index_batch_strides;
    const constant auto* mati_bstrides = index_batch_strides + batch_ndim;

    ulong2 batch_offsets = elem_to_loc_broadcast(
        tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);

    indx_vec = vec_indices[batch_offsets.x];
    indx_mat = mat_indices[batch_offsets.y];

  } else {
    indx_vec = vec_indices[index_batch_strides[0] * tid.z];
    indx_mat = mat_indices[index_batch_strides[batch_ndim] * tid.z];
  }

  if (vector_batch_ndim > 1) {
    in_vec += elem_to_loc(
        indx_vec, vector_batch_shape, vector_batch_stride, vector_batch_ndim);
  } else {
    in_vec += indx_vec * vector_batch_stride[0];
  }

  if (matrix_batch_ndim > 1) {
    mat += elem_to_loc(
        indx_mat, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim);
  } else {
    mat += indx_mat * matrix_batch_stride[0];
  }

  out_vec += tid.z * out_vec_size;

  gemv_kernel::run(
      mat,
      in_vec,
      bias,
      out_vec,
      in_vec_size,
      out_vec_size,
      matrix_ld,
      alpha,
      beta,
      batch_ndim, // Not used
      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
      tid,
      lid,
      simd_gid,
      simd_lid);
}

// clang-format off
#define instantiate_gemv_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \
  instantiate_kernel( \
      "gemv_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
      "_sn" #sn "_tm" #tm "_tn" #tn, \
      gemv_gather, itype, bm, bn, sm, sn, tm, tn)

#define instantiate_gemv_bs_blocks(name, itype) \
  instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 1, 4) \
  instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 4, 4) \
  instantiate_gemv_bs_helper(name, itype, 8, 1, 1, 32, 4, 4) // clang-format on

instantiate_gemv_bs_blocks(float32, float);
instantiate_gemv_bs_blocks(float16, half);
instantiate_gemv_bs_blocks(bfloat16, bfloat16_t);

///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    const int BM, /* Threadgroup rows (in simdgroups) */
    const int BN, /* Threadgroup cols (in simdgroups) */
    const int SM, /* Simdgroup rows (in threads) */
    const int SN, /* Simdgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN, /* Thread cols (in elements) */
    const bool kDoNCBatch, /* Batch ndim > 1 */
    const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_t(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    const device T* bias [[buffer(2)]],
    device T* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& matrix_ld [[buffer(6)]],
    const constant float& alpha [[buffer(7)]],
    const constant float& beta [[buffer(8)]],
    const constant int& batch_ndim [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant int64_t* vector_batch_stride [[buffer(11)]],
    const constant int64_t* matrix_batch_stride [[buffer(12)]],
    const constant int64_t* bias_batch_stride [[buffer(13)]],
    const constant int& bias_stride [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
  threadgroup float tgp_memory
      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];

  // Update batch offsets
  if (kDoNCBatch) {
    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);

    if (kDoAxpby) {
      bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim);
    }

  } else {
    in_vec += tid.z * vector_batch_stride[0];
    mat += tid.z * matrix_batch_stride[0];

    if (kDoAxpby) {
      bias += tid.z * bias_batch_stride[0];
    }
  }

  out_vec += tid.z * out_vec_size;

  gemv_kernel::run(
      mat,
      in_vec,
      bias,
      out_vec,
      in_vec_size,
      out_vec_size,
      matrix_ld,
      alpha,
      beta,
      bias_stride,
      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
      tid,
      lid,
      simd_gid,
      simd_lid);
}

// clang-format off
#define instantiate_gemv_t_helper( \
    name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
  instantiate_kernel( \
      "gemv_t_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \
      "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby, \
      gemv_t, itype, bm, bn, sm, sn, tm, tn, nc, axpby)

#define instantiate_gemv_t(name, itype, bm, bn, sm, sn, tm, tn) \
  instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \
  instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \
  instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 0) \
  instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 1) // clang-format on

// clang-format off
#define instantiate_gemv_t_blocks(name, itype) \
  instantiate_gemv_t(name, itype, 1, 2, 8, 4, 4, 1) \
  instantiate_gemv_t(name, itype, 1, 2, 8, 4, 4, 4) \
  instantiate_gemv_t(name, itype, 1, 4, 8, 4, 4, 4) \
  instantiate_gemv_t(name, itype, 1, 16, 8, 4, 4, 4) \
  instantiate_gemv_t(name, itype, 1, 16, 4, 8, 4, 4) // clang-format on
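
// For reference, instantiate_gemv_t(float32, float, 1, 2, 8, 4, 4, 4) above
// registers kernels named, e.g.,
//   "gemv_t_float32_bm1_bn2_sm8_sn4_tm4_tn4_nc0_axpby0"
// with _nc/_axpby covering the four batch/axpby variants.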

// clang-format off
instantiate_gemv_t_blocks(float32, float);
instantiate_gemv_t_blocks(float16, half);
instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on

template <
    typename T,
    const int BM, /* Threadgroup rows (in simdgroups) */
    const int BN, /* Threadgroup cols (in simdgroups) */
    const int SM, /* Simdgroup rows (in threads) */
    const int SN, /* Simdgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN> /* Thread cols (in elements) */
[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_t_gather(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    const device T* bias [[buffer(2)]],
    device T* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& matrix_ld [[buffer(6)]],
    const constant float& alpha [[buffer(7)]],
    const constant float& beta [[buffer(8)]],
    const constant int& batch_ndim [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant int64_t* index_batch_strides [[buffer(11)]],
    const constant int& vector_batch_ndim [[buffer(12)]],
    const constant int* vector_batch_shape [[buffer(13)]],
    const constant int64_t* vector_batch_stride [[buffer(14)]],
    const constant int& matrix_batch_ndim [[buffer(15)]],
    const constant int* matrix_batch_shape [[buffer(16)]],
    const constant int64_t* matrix_batch_stride [[buffer(17)]],
    const constant uint32_t* vec_indices [[buffer(18)]],
    const constant uint32_t* mat_indices [[buffer(19)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, false, float>;
  threadgroup float tgp_memory
      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];

  uint32_t indx_vec;
  uint32_t indx_mat;

  // Update batch offsets
  if (batch_ndim > 1) {
    const constant auto* veci_bstrides = index_batch_strides;
    const constant auto* mati_bstrides = index_batch_strides + batch_ndim;

    ulong2 batch_offsets = elem_to_loc_broadcast(
        tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);

    indx_vec = vec_indices[batch_offsets.x];
    indx_mat = mat_indices[batch_offsets.y];

  } else {
    indx_vec = vec_indices[index_batch_strides[0] * tid.z];
    indx_mat = mat_indices[index_batch_strides[batch_ndim] * tid.z];
  }

  if (vector_batch_ndim > 1) {
    in_vec += elem_to_loc(
        indx_vec, vector_batch_shape, vector_batch_stride, vector_batch_ndim);
  } else {
    in_vec += indx_vec * vector_batch_stride[0];
  }

  if (matrix_batch_ndim > 1) {
    mat += elem_to_loc(
        indx_mat, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim);
  } else {
    mat += indx_mat * matrix_batch_stride[0];
  }

  out_vec += tid.z * out_vec_size;

  gemv_kernel::run(
      mat,
      in_vec,
      bias,
      out_vec,
      in_vec_size,
      out_vec_size,
      matrix_ld,
      alpha,
      beta,
      batch_ndim, // Not used
      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
      tid,
      lid,
      simd_gid,
      simd_lid);
}

// clang-format off
#define instantiate_gemv_t_bs_helper( \
    nm, itype, bm, bn, sm, sn, tm, tn) \
  instantiate_kernel( \
      "gemv_t_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
      "_sn" #sn "_tm" #tm "_tn" #tn, \
      gemv_t_gather, itype, bm, bn, sm, sn, tm, tn)

#define instantiate_gemv_t_bs_blocks(name, itype) \
  instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 1) \
  instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 4) \
  instantiate_gemv_t_bs_helper(name, itype, 1, 4, 8, 4, 4, 4) \
  instantiate_gemv_t_bs_helper(name, itype, 1, 16, 8, 4, 4, 4) \
  instantiate_gemv_t_bs_helper(name, itype, 1, 16, 4, 8, 4, 4) // clang-format on

// clang-format off
instantiate_gemv_t_bs_blocks(float32, float);
instantiate_gemv_t_bs_blocks(float16, half);
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on