Use same accumulation precision in gemv as gemm (#1962)

* use same accumulation precision in gemv as gemm

* faster

* fix compile
Awni Hannun 2025-03-16 07:13:24 -07:00 committed by GitHub
parent 2770a10240
commit c6ea2ba329
3 changed files with 79 additions and 53 deletions
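For context: the patch templates both GEMV kernels on an accumulator type AccT, defaulting to float, so half and bfloat16 inputs now accumulate partial sums in float32 just as the gemm kernels do. A minimal sketch of the accumulation drift this removes, in plain numpy with float16 standing in for bfloat16 (an illustration, not MLX code):

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(256).astype(np.float16)
y = rng.standard_normal(256).astype(np.float16)

# Accumulate in the storage type, as the old gemv kernels did.
acc_low = np.float16(0)
for a, b in zip(x, y):
    acc_low = np.float16(acc_low + a * b)

# Accumulate in float32, as gemm does and gemv does after this change.
acc_f32 = np.float32(0)
for a, b in zip(x, y):
    acc_f32 += np.float32(a) * np.float32(b)

print(acc_low, acc_f32)  # the low-precision accumulator typically drifts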

File 1 of 3

@@ -23,7 +23,8 @@ template <
const int SN, /* Simdgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
const bool kDoAxpby, /* Do out = alpha * out + beta * bias */
typename AccT = float>
struct GEMVKernel {
MLX_MTL_CONST int threadsM = BM * SM;
MLX_MTL_CONST int threadsN = BN * SN;
@@ -60,28 +61,32 @@ struct GEMVKernel {
MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0;
MLX_MTL_CONST bool needs_tgp_reduction = BN > 1;
template <typename U = T>
static METAL_FUNC void
load_unsafe(const device T* src, thread T dst[TN], const int src_offset = 0) {
load_unsafe(const device T* src, thread U dst[TN], const int src_offset = 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
dst[tn] = src[src_offset + tn];
dst[tn] = static_cast<U>(src[src_offset + tn]);
}
}
template <typename U = T>
static METAL_FUNC void load_safe(
const device T* src,
thread T dst[TN],
thread U dst[TN],
const int src_offset = 0,
const int src_size = TN) {
if (src_offset + TN <= src_size) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
dst[tn] = src[src_offset + tn];
dst[tn] = static_cast<U>(src[src_offset + tn]);
}
} else { // Edgecase
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
dst[tn] = src_offset + tn < src_size ? src[src_offset + tn] : 0;
dst[tn] = src_offset + tn < src_size
? static_cast<U>(src[src_offset + tn])
: U(0);
}
}
}
@@ -97,7 +102,7 @@ struct GEMVKernel {
const constant float& alpha [[buffer(7)]],
const constant float& beta [[buffer(8)]],
const constant int& bias_stride [[buffer(14)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
threadgroup AccT* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -106,9 +111,9 @@ struct GEMVKernel {
(void)lid;
// Thread local accumulation results
thread T result[TM] = {0};
thread AccT result[TM] = {0};
thread T inter[TN];
thread T v_coeff[TN];
thread AccT v_coeff[TN];
const int thrM = SN != 32 ? simd_lid / SN : 0;
const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
@@ -142,7 +147,7 @@ struct GEMVKernel {
// Loop over in_vec in blocks of blockN
for (int i = 0; i < n_iter; ++i) {
load_unsafe(in_vec, v_coeff, bn);
load_unsafe<AccT>(in_vec, v_coeff, bn);
// Per thread work loop
int mat_offset = 0;
@@ -164,7 +169,7 @@ struct GEMVKernel {
}
if (leftover > 0) {
load_safe(in_vec, v_coeff, bn, in_size);
load_safe<AccT>(in_vec, v_coeff, bn, in_size);
// Per thread work loop
MLX_MTL_PRAGMA_UNROLL
@@ -191,7 +196,7 @@ struct GEMVKernel {
// Threadgroup accumulation results
if (needs_tgp_reduction) {
threadgroup T* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
threadgroup AccT* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
if (thrN == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
@@ -217,10 +222,11 @@ struct GEMVKernel {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
if (kDoAxpby) {
out_vec[out_row + tm] = static_cast<T>(alpha) * result[tm] +
out_vec[out_row + tm] =
static_cast<T>(alpha) * static_cast<T>(result[tm]) +
static_cast<T>(beta) * bias[(out_row + tm) * bias_stride];
} else {
out_vec[out_row + tm] = result[tm];
out_vec[out_row + tm] = static_cast<T>(result[tm]);
}
}
}
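Note the casting order in the kDoAxpby epilogue just above: the float32 accumulator is cast back to T before the alpha/beta scaling, so the axpby step itself runs in the storage precision. A small numpy restatement of that ordering (hypothetical values, float16 standing in for T):

import numpy as np

alpha, beta = 2.0, 0.5
acc = np.float32(1.00048828125)  # float32 accumulator from the main loop
bias = np.float16(0.25)
# Cast down first, then scale -- mirroring static_cast<T>(alpha) * static_cast<T>(result[tm]).
out = np.float16(alpha) * acc.astype(np.float16) + np.float16(beta) * bias
print(out)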
@@ -239,7 +245,8 @@ template <
const int SN, /* Simdgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
const bool kDoAxpby, /* Do out = alpha * out + beta * bias */
typename AccT = float>
struct GEMVTKernel {
MLX_MTL_CONST int threadsM = BM * SM;
MLX_MTL_CONST int threadsN = BN * SN;
@@ -282,7 +289,7 @@ struct GEMVTKernel {
const constant float& alpha [[buffer(7)]],
const constant float& beta [[buffer(8)]],
const constant int& bias_stride [[buffer(14)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
threadgroup AccT* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -291,10 +298,9 @@ struct GEMVTKernel {
(void)lid;
// Thread local accumulation results
T result[TN] = {0};
AccT result[TN] = {0};
T inter[TN];
T v_coeff[TM];
AccT v_coeff[TM];
const int thrM = SN != 32 ? simd_lid / SN : 0;
const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
@@ -330,16 +336,17 @@ struct GEMVTKernel {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
v_coeff[tm] = in_vec[bm + tm];
v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
}
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
auto vc = float(v_coeff[tm]);
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
}
for (int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
result[tn] += vc * inter[tn];
}
}
@@ -348,7 +355,7 @@ struct GEMVTKernel {
if (leftover > 0) {
for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
v_coeff[tm] = in_vec[bm + tm];
v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
@@ -374,7 +381,7 @@ struct GEMVTKernel {
// Threadgroup accumulation results
if (needs_tgp_reduction) {
threadgroup T* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
threadgroup AccT* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
if (thrM == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
@@ -400,10 +407,11 @@ struct GEMVTKernel {
MLX_MTL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
if (kDoAxpby) {
out_vec[out_col + j] = static_cast<T>(alpha) * result[j] +
out_vec[out_col + j] =
static_cast<T>(alpha) * static_cast<T>(result[j]) +
static_cast<T>(beta) * bias[(out_col + j) * bias_stride];
} else {
out_vec[out_col + j] = result[j];
out_vec[out_col + j] = static_cast<T>(result[j]);
}
}
}
@@ -445,7 +453,7 @@ template <
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
threadgroup T tgp_memory
threadgroup float tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
// Update batch offsets
@@ -553,7 +561,7 @@ template <
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, false>;
threadgroup T tgp_memory
threadgroup float tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
uint32_t indx_vec;
@@ -660,7 +668,7 @@ template <
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
threadgroup T tgp_memory
threadgroup float tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
// Update batch offsets
@@ -761,8 +769,8 @@ template <
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, false>;
threadgroup T tgp_memory
using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, false, float>;
threadgroup float tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
uint32_t indx_vec;

File 2 of 3

@@ -44,7 +44,8 @@ template <
const int SM, /* Simdgroup rows (in threads) */
const int SN, /* Simdgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
const int TN, /* Thread cols (in elements) */
typename AccT = float>
struct GEMVKernel {
MLX_MTL_CONST int threadsM = BM * SM;
MLX_MTL_CONST int threadsN = BN * SN;
@@ -91,28 +92,32 @@ struct GEMVKernel {
MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0;
MLX_MTL_CONST bool needs_tgp_reduction = BN > 1;
template <typename U = T>
static METAL_FUNC void
load_unsafe(const device T* src, thread T dst[TN], const int src_offset = 0) {
load_unsafe(const device T* src, thread U dst[TN], const int src_offset = 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
dst[tn] = src[src_offset + tn];
dst[tn] = static_cast<U>(src[src_offset + tn]);
}
}
template <typename U = T>
static METAL_FUNC void load_safe(
const device T* src,
thread T dst[TN],
thread U dst[TN],
const int src_offset = 0,
const int src_size = TN) {
if (src_offset + TN <= src_size) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
dst[tn] = src[src_offset + tn];
dst[tn] = static_cast<U>(src[src_offset + tn]);
}
} else { // Edgecase
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
dst[tn] = src_offset + tn < src_size ? src[src_offset + tn] : 0;
dst[tn] = src_offset + tn < src_size
? static_cast<U>(src[src_offset + tn])
: U(0);
}
}
}
@@ -128,7 +133,7 @@ struct GEMVKernel {
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
threadgroup AccT* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -137,9 +142,9 @@ struct GEMVKernel {
(void)lid;
// Thread local accumulation results
thread T result[TM] = {0};
thread AccT result[TM] = {0};
thread T inter[TN];
thread T v_coeff[TN];
thread AccT v_coeff[TN];
const int thrM = SN != 32 ? simd_lid / SN : 0;
const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
@@ -225,7 +230,7 @@ struct GEMVKernel {
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
load_unsafe(in_vec, v_coeff, bn);
load_unsafe<AccT>(in_vec, v_coeff, bn);
// Apply scale
if (has_mul_operand_mask) {
@@ -267,7 +272,7 @@ struct GEMVKernel {
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
load_safe(in_vec, v_coeff, bn, in_size);
load_safe<AccT>(in_vec, v_coeff, bn, in_size);
// Apply scale
if (has_mul_operand_mask) {
@@ -310,7 +315,7 @@ struct GEMVKernel {
// Threadgroup accumulation results
if (needs_tgp_reduction) {
threadgroup T* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
threadgroup AccT* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
if (thrN == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
@@ -335,7 +340,7 @@ struct GEMVKernel {
if (simdN == 0 && thrN == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
out_vec[out_row + tm] = result[tm];
out_vec[out_row + tm] = static_cast<T>(result[tm]);
}
}
}
@@ -354,7 +359,8 @@ template <
const int SM, /* Simdgroup rows (in threads) */
const int SN, /* Simdgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
const int TN, /* Thread cols (in elements) */
typename AccT = float>
struct GEMVTKernel {
MLX_MTL_CONST int threadsM = BM * SM;
MLX_MTL_CONST int threadsN = BN * SN;
@@ -405,7 +411,7 @@ struct GEMVTKernel {
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
threadgroup AccT* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -414,9 +420,9 @@ struct GEMVTKernel {
(void)lid;
// Thread local accumulation results
T result[TN] = {0};
AccT result[TN] = {0};
T inter[TN];
T v_coeff[TM];
AccT v_coeff[TM];
const int thrM = SN != 32 ? simd_lid / SN : 0;
const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
@@ -511,7 +517,7 @@ struct GEMVTKernel {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
v_coeff[tm] = in_vec[bm + tm];
v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
}
// Apply scale
@@ -549,7 +555,7 @@ struct GEMVTKernel {
}
for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
v_coeff[tm] = in_vec[bm + tm];
v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
if (has_mul_operand_mask) {
v_coeff[tm] *= block_scale;
@@ -587,7 +593,7 @@ struct GEMVTKernel {
// Threadgroup accumulation results
if (needs_tgp_reduction) {
threadgroup T* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
threadgroup AccT* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
if (thrM == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
@@ -612,7 +618,7 @@ struct GEMVTKernel {
if (cm == 0 && out_col < out_vec_size) {
MLX_MTL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
out_vec[out_col + j] = result[j];
out_vec[out_col + j] = static_cast<T>(result[j]);
}
}
}
@@ -655,7 +661,7 @@ template <
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel =
GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
threadgroup T tgp_memory
threadgroup float tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
@@ -755,7 +761,7 @@ template <
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel =
GEMVTKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
threadgroup T tgp_memory
threadgroup float tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
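The masked variants above feed block-masked matrix products. For orientation, a hedged Python-side example of the operation these kernels serve, assuming mx.block_masked_mm keeps its documented signature (block_size, mask_out, mask_lhs, mask_rhs):

import mlx.core as mx

a = mx.random.normal([1, 128]).astype(mx.bfloat16)  # single row: presumably the masked gemv path
b = mx.random.normal([128, 128]).astype(mx.bfloat16)
# Block mask over b, one entry per (K // block_size, N // block_size) tile.
mask_rhs = mx.ones([2, 2], dtype=mx.bool_)
out = mx.block_masked_mm(a, b, block_size=64, mask_rhs=mask_rhs)
print(out.shape)  # (1, 128)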

File 3 of 3

@@ -1146,6 +1146,18 @@ class TestBlas(mlx_tests.MLXTestCase):
        self.assertEqual(r.shape, t.shape)
        self.assertTrue(mx.allclose(r, t, atol=1e-4).item())

    def test_gemv_gemm_same_precision(self):
        mx.random.seed(0)
        N = 256
        if mx.metal.is_available():
            t = mx.bfloat16
            a = mx.random.normal([1, N]).astype(t)
            b = mx.concatenate([a, a], axis=0).astype(t)
            c = mx.random.normal([N, 64]).astype(t)
            out_gemv = a @ c
            out_gemm = (b @ c)[0]
            self.assertTrue(mx.allclose(out_gemv, out_gemm))

if __name__ == "__main__":
    unittest.main()
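The test leans on kernel dispatch: a [1, N] @ [N, 64] product takes the gemv path, while stacking the same row twice gives a [2, N] input that presumably routes through gemm, so row 0 of the two results agrees within allclose tolerance only when both paths accumulate alike. A standalone version of the same check (assumes MLX with Metal available):

import mlx.core as mx

mx.random.seed(0)
a = mx.random.normal([1, 256]).astype(mx.bfloat16)  # single row: gemv kernel
c = mx.random.normal([256, 64]).astype(mx.bfloat16)
b = mx.concatenate([a, a], axis=0)                  # two rows: gemm kernel
print(mx.allclose(a @ c, (b @ c)[0]).item())        # True after this change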