diff --git a/mlx/backend/metal/kernels/gemv.metal b/mlx/backend/metal/kernels/gemv.metal
index 28cadd50a..f21c35d97 100644
--- a/mlx/backend/metal/kernels/gemv.metal
+++ b/mlx/backend/metal/kernels/gemv.metal
@@ -23,7 +23,8 @@ template <
     const int SN, /* Simdgroup cols (in threads) */
     const int TM, /* Thread rows (in elements) */
     const int TN, /* Thread cols (in elements) */
-    const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
+    const bool kDoAxpby, /* Do out = alpha * out + beta * bias */
+    typename AccT = float>
 struct GEMVKernel {
   MLX_MTL_CONST int threadsM = BM * SM;
   MLX_MTL_CONST int threadsN = BN * SN;
@@ -60,28 +61,32 @@ struct GEMVKernel {
   MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0;
   MLX_MTL_CONST bool needs_tgp_reduction = BN > 1;
 
+  template <typename U = T>
   static METAL_FUNC void
-  load_unsafe(const device T* src, thread T dst[TN], const int src_offset = 0) {
+  load_unsafe(const device T* src, thread U dst[TN], const int src_offset = 0) {
     MLX_MTL_PRAGMA_UNROLL
     for (int tn = 0; tn < TN; tn++) {
-      dst[tn] = src[src_offset + tn];
+      dst[tn] = static_cast<U>(src[src_offset + tn]);
     }
   }
 
+  template <typename U = T>
   static METAL_FUNC void load_safe(
       const device T* src,
-      thread T dst[TN],
+      thread U dst[TN],
       const int src_offset = 0,
       const int src_size = TN) {
     if (src_offset + TN <= src_size) {
       MLX_MTL_PRAGMA_UNROLL
       for (int tn = 0; tn < TN; tn++) {
-        dst[tn] = src[src_offset + tn];
+        dst[tn] = static_cast<U>(src[src_offset + tn]);
       }
     } else { // Edgecase
       MLX_MTL_PRAGMA_UNROLL
       for (int tn = 0; tn < TN; tn++) {
-        dst[tn] = src_offset + tn < src_size ? src[src_offset + tn] : 0;
+        dst[tn] = src_offset + tn < src_size
+            ? static_cast<U>(src[src_offset + tn])
+            : U(0);
       }
     }
   }
@@ -97,7 +102,7 @@ struct GEMVKernel {
       const constant float& alpha [[buffer(7)]],
       const constant float& beta [[buffer(8)]],
       const constant int& bias_stride [[buffer(14)]],
-      threadgroup T* tgp_memory [[threadgroup(0)]],
+      threadgroup AccT* tgp_memory [[threadgroup(0)]],
       uint3 tid [[threadgroup_position_in_grid]],
       uint3 lid [[thread_position_in_threadgroup]],
       uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -106,9 +111,9 @@ struct GEMVKernel {
     (void)lid;
 
     // Thread local accumulation results
-    thread T result[TM] = {0};
+    thread AccT result[TM] = {0};
     thread T inter[TN];
-    thread T v_coeff[TN];
+    thread AccT v_coeff[TN];
 
     const int thrM = SN != 32 ? simd_lid / SN : 0;
     const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
@@ -142,7 +147,7 @@ struct GEMVKernel {
 
     // Loop over in_vec in blocks of blockN
     for (int i = 0; i < n_iter; ++i) {
-      load_unsafe(in_vec, v_coeff, bn);
+      load_unsafe<AccT>(in_vec, v_coeff, bn);
 
       // Per thread work loop
       int mat_offset = 0;
@@ -164,7 +169,7 @@ struct GEMVKernel {
     }
 
     if (leftover > 0) {
-      load_safe(in_vec, v_coeff, bn, in_size);
+      load_safe<AccT>(in_vec, v_coeff, bn, in_size);
 
       // Per thread work loop
       MLX_MTL_PRAGMA_UNROLL
@@ -191,7 +196,7 @@ struct GEMVKernel {
 
     // Threadgroup accumulation results
    if (needs_tgp_reduction) {
-      threadgroup T* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
+      threadgroup AccT* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
       if (thrN == 0) {
         MLX_MTL_PRAGMA_UNROLL
         for (int tm = 0; tm < TM; tm++) {
@@ -217,10 +222,11 @@ struct GEMVKernel {
       MLX_MTL_PRAGMA_UNROLL
       for (int tm = 0; tm < TM; tm++) {
         if (kDoAxpby) {
-          out_vec[out_row + tm] = static_cast<T>(alpha) * result[tm] +
+          out_vec[out_row + tm] =
+              static_cast<T>(alpha) * static_cast<T>(result[tm]) +
               static_cast<T>(beta) * bias[(out_row + tm) * bias_stride];
         } else {
-          out_vec[out_row + tm] = result[tm];
+          out_vec[out_row + tm] = static_cast<T>(result[tm]);
         }
       }
     }
@@ -239,7 +245,8 @@ template <
     const int SN, /* Simdgroup cols (in threads) */
     const int TM, /* Thread rows (in elements) */
     const int TN, /* Thread cols (in elements) */
-    const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
+    const bool kDoAxpby, /* Do out = alpha * out + beta * bias */
+    typename AccT = float>
 struct GEMVTKernel {
   MLX_MTL_CONST int threadsM = BM * SM;
   MLX_MTL_CONST int threadsN = BN * SN;
@@ -282,7 +289,7 @@ struct GEMVTKernel {
       const constant float& alpha [[buffer(7)]],
       const constant float& beta [[buffer(8)]],
       const constant int& bias_stride [[buffer(14)]],
-      threadgroup T* tgp_memory [[threadgroup(0)]],
+      threadgroup AccT* tgp_memory [[threadgroup(0)]],
       uint3 tid [[threadgroup_position_in_grid]],
       uint3 lid [[thread_position_in_threadgroup]],
       uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -291,10 +298,9 @@ struct GEMVTKernel {
     (void)lid;
 
     // Thread local accumulation results
-    T result[TN] = {0};
+    AccT result[TN] = {0};
     T inter[TN];
-    T v_coeff[TM];
-
+    AccT v_coeff[TM];
     const int thrM = SN != 32 ? simd_lid / SN : 0;
     const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
@@ -330,16 +336,17 @@ struct GEMVTKernel {
 
       MLX_MTL_PRAGMA_UNROLL
       for (int tm = 0; tm < TM; tm++) {
-        v_coeff[tm] = in_vec[bm + tm];
+        v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
       }
 
       MLX_MTL_PRAGMA_UNROLL
       for (int tm = 0; tm < TM; tm++) {
+        auto vc = float(v_coeff[tm]);
         for (int tn = 0; tn < TN; tn++) {
           inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
         }
         for (int tn = 0; tn < TN; tn++) {
-          result[tn] += v_coeff[tm] * inter[tn];
+          result[tn] += vc * inter[tn];
         }
       }
@@ -348,7 +355,7 @@ struct GEMVTKernel {
 
     if (leftover > 0) {
       for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
-        v_coeff[tm] = in_vec[bm + tm];
+        v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
 
         MLX_MTL_PRAGMA_UNROLL
         for (int tn = 0; tn < TN; tn++) {
@@ -374,7 +381,7 @@ struct GEMVTKernel {
 
     // Threadgroup accumulation results
     if (needs_tgp_reduction) {
-      threadgroup T* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
+      threadgroup AccT* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
       if (thrM == 0) {
         MLX_MTL_PRAGMA_UNROLL
         for (int tn = 0; tn < TN; tn++) {
@@ -400,10 +407,11 @@ struct GEMVTKernel {
       MLX_MTL_PRAGMA_UNROLL
       for (int j = 0; j < TN; j++) {
         if (kDoAxpby) {
-          out_vec[out_col + j] = static_cast<T>(alpha) * result[j] +
+          out_vec[out_col + j] =
+              static_cast<T>(alpha) * static_cast<T>(result[j]) +
               static_cast<T>(beta) * bias[(out_col + j) * bias_stride];
         } else {
-          out_vec[out_col + j] = result[j];
+          out_vec[out_col + j] = static_cast<T>(result[j]);
         }
       }
     }
@@ -445,7 +453,7 @@ template <
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
   using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
-  threadgroup T tgp_memory
+  threadgroup float tgp_memory
       [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
 
   // Update batch offsets
@@ -553,7 +561,7 @@ template <
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
   using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, false>;
-  threadgroup T tgp_memory
+  threadgroup float tgp_memory
       [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
 
   uint32_t indx_vec;
@@ -660,7 +668,7 @@ template <
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
   using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
-  threadgroup T tgp_memory
+  threadgroup float tgp_memory
       [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
 
   // Update batch offsets
@@ -761,8 +769,8 @@ template <
     uint3 lid [[thread_position_in_threadgroup]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-  using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, false>;
-  threadgroup T tgp_memory
+  using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, false, float>;
+  threadgroup float tgp_memory
       [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
 
   uint32_t indx_vec;
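A note on the change above: previously both GEMV kernels accumulated partial sums in the input type T, so half and bfloat16 inputs were rounded after every multiply-add. The patch threads an accumulator type AccT (defaulting to float) through the kernels: inputs are cast up on load, thread-local results and the threadgroup reduction buffer are kept in AccT, and values are cast back to T only on the final store. The following minimal Python sketch (not part of the patch; the bf16 helper is illustrative, not an MLX API) emulates bfloat16 by truncating a float32 bit pattern and shows why per-step rounding drifts:

    import struct

    def bf16(x):
        # Truncate a float32 bit pattern to bfloat16 precision (top 16 bits).
        bits = struct.unpack("<I", struct.pack("<f", x))[0]
        return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

    vec = [bf16(0.01 * (i % 7 + 1)) for i in range(256)]
    row = [bf16(0.1)] * 256

    acc_bf16, acc_f32 = 0.0, 0.0
    for x, w in zip(vec, row):
        acc_bf16 = bf16(acc_bf16 + bf16(x * w))  # round every step: old behavior
        acc_f32 += x * w                         # round once at the end: AccT = float

    print(acc_bf16, bf16(acc_f32))  # the per-step-rounded sum drifts noticeably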
diff --git a/mlx/backend/metal/kernels/gemv_masked.h b/mlx/backend/metal/kernels/gemv_masked.h
index 48acf1d61..75bc7354c 100644
--- a/mlx/backend/metal/kernels/gemv_masked.h
+++ b/mlx/backend/metal/kernels/gemv_masked.h
@@ -44,7 +44,8 @@ template <
     const int SM, /* Simdgroup rows (in threads) */
     const int SN, /* Simdgroup cols (in threads) */
     const int TM, /* Thread rows (in elements) */
-    const int TN> /* Thread cols (in elements) */
+    const int TN, /* Thread cols (in elements) */
+    typename AccT = float>
 struct GEMVKernel {
   MLX_MTL_CONST int threadsM = BM * SM;
   MLX_MTL_CONST int threadsN = BN * SN;
@@ -91,28 +92,32 @@ struct GEMVKernel {
   MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0;
   MLX_MTL_CONST bool needs_tgp_reduction = BN > 1;
 
+  template <typename U = T>
   static METAL_FUNC void
-  load_unsafe(const device T* src, thread T dst[TN], const int src_offset = 0) {
+  load_unsafe(const device T* src, thread U dst[TN], const int src_offset = 0) {
     MLX_MTL_PRAGMA_UNROLL
     for (int tn = 0; tn < TN; tn++) {
-      dst[tn] = src[src_offset + tn];
+      dst[tn] = static_cast<U>(src[src_offset + tn]);
     }
   }
 
+  template <typename U = T>
   static METAL_FUNC void load_safe(
       const device T* src,
-      thread T dst[TN],
+      thread U dst[TN],
       const int src_offset = 0,
       const int src_size = TN) {
     if (src_offset + TN <= src_size) {
       MLX_MTL_PRAGMA_UNROLL
       for (int tn = 0; tn < TN; tn++) {
-        dst[tn] = src[src_offset + tn];
+        dst[tn] = static_cast<U>(src[src_offset + tn]);
       }
     } else { // Edgecase
       MLX_MTL_PRAGMA_UNROLL
       for (int tn = 0; tn < TN; tn++) {
-        dst[tn] = src_offset + tn < src_size ? src[src_offset + tn] : 0;
+        dst[tn] = src_offset + tn < src_size
+            ? static_cast<U>(src[src_offset + tn])
+            : U(0);
       }
     }
   }
@@ -128,7 +133,7 @@ struct GEMVKernel {
       const device op_mask_t* mat_mask [[buffer(21)]],
       const device op_mask_t* vec_mask [[buffer(22)]],
       const constant int* mask_strides [[buffer(23)]],
-      threadgroup T* tgp_memory [[threadgroup(0)]],
+      threadgroup AccT* tgp_memory [[threadgroup(0)]],
       uint3 tid [[threadgroup_position_in_grid]],
       uint3 lid [[thread_position_in_threadgroup]],
       uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -137,9 +142,9 @@ struct GEMVKernel {
     (void)lid;
 
     // Thread local accumulation results
-    thread T result[TM] = {0};
+    thread AccT result[TM] = {0};
     thread T inter[TN];
-    thread T v_coeff[TN];
+    thread AccT v_coeff[TN];
 
     const int thrM = SN != 32 ? simd_lid / SN : 0;
     const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
@@ -225,7 +230,7 @@ struct GEMVKernel {
               T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
         }
 
-        load_unsafe(in_vec, v_coeff, bn);
+        load_unsafe<AccT>(in_vec, v_coeff, bn);
 
         // Apply scale
         if (has_mul_operand_mask) {
@@ -267,7 +272,7 @@ struct GEMVKernel {
               T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
         }
 
-        load_safe(in_vec, v_coeff, bn, in_size);
+        load_safe<AccT>(in_vec, v_coeff, bn, in_size);
 
         // Apply scale
         if (has_mul_operand_mask) {
@@ -310,7 +315,7 @@ struct GEMVKernel {
 
     // Threadgroup accumulation results
     if (needs_tgp_reduction) {
-      threadgroup T* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
+      threadgroup AccT* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
       if (thrN == 0) {
         MLX_MTL_PRAGMA_UNROLL
         for (int tm = 0; tm < TM; tm++) {
@@ -335,7 +340,7 @@ struct GEMVKernel {
     if (simdN == 0 && thrN == 0) {
       MLX_MTL_PRAGMA_UNROLL
       for (int tm = 0; tm < TM; tm++) {
-        out_vec[out_row + tm] = result[tm];
+        out_vec[out_row + tm] = static_cast<T>(result[tm]);
       }
     }
   }
@@ -354,7 +359,8 @@ template <
     const int SM, /* Simdgroup rows (in threads) */
     const int SN, /* Simdgroup cols (in threads) */
     const int TM, /* Thread rows (in elements) */
-    const int TN> /* Thread cols (in elements) */
+    const int TN, /* Thread cols (in elements) */
+    typename AccT = float>
 struct GEMVTKernel {
   MLX_MTL_CONST int threadsM = BM * SM;
   MLX_MTL_CONST int threadsN = BN * SN;
@@ -405,7 +411,7 @@ struct GEMVTKernel {
       const device op_mask_t* mat_mask [[buffer(21)]],
       const device op_mask_t* vec_mask [[buffer(22)]],
       const constant int* mask_strides [[buffer(23)]],
-      threadgroup T* tgp_memory [[threadgroup(0)]],
+      threadgroup AccT* tgp_memory [[threadgroup(0)]],
       uint3 tid [[threadgroup_position_in_grid]],
       uint3 lid [[thread_position_in_threadgroup]],
       uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -414,9 +420,9 @@ struct GEMVTKernel {
     (void)lid;
 
     // Thread local accumulation results
-    T result[TN] = {0};
+    AccT result[TN] = {0};
     T inter[TN];
-    T v_coeff[TM];
+    AccT v_coeff[TM];
 
     const int thrM = SN != 32 ? simd_lid / SN : 0;
     const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
@@ -511,7 +517,7 @@ struct GEMVTKernel {
 
       MLX_MTL_PRAGMA_UNROLL
       for (int tm = 0; tm < TM; tm++) {
-        v_coeff[tm] = in_vec[bm + tm];
+        v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
       }
 
       // Apply scale
@@ -549,7 +555,7 @@ struct GEMVTKernel {
       }
 
       for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
-        v_coeff[tm] = in_vec[bm + tm];
+        v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
 
         if (has_mul_operand_mask) {
           v_coeff[tm] *= block_scale;
@@ -587,7 +593,7 @@ struct GEMVTKernel {
 
     // Threadgroup accumulation results
     if (needs_tgp_reduction) {
-      threadgroup T* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
+      threadgroup AccT* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
       if (thrM == 0) {
         MLX_MTL_PRAGMA_UNROLL
         for (int tn = 0; tn < TN; tn++) {
@@ -612,7 +618,7 @@ struct GEMVTKernel {
     if (cm == 0 && out_col < out_vec_size) {
       MLX_MTL_PRAGMA_UNROLL
      for (int j = 0; j < TN; j++) {
-        out_vec[out_col + j] = result[j];
+        out_vec[out_col + j] = static_cast<T>(result[j]);
       }
     }
   }
@@ -655,7 +661,7 @@ template <
     uint simd_lid [[thread_index_in_simdgroup]]) {
   using gemv_kernel =
       GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
-  threadgroup T tgp_memory
+  threadgroup float tgp_memory
       [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
 
   constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
@@ -755,7 +761,7 @@ template <
     uint simd_lid [[thread_index_in_simdgroup]]) {
   using gemv_kernel =
       GEMVTKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
-  threadgroup T tgp_memory
+  threadgroup float tgp_memory
       [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
 
   constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
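The masked variants receive the same treatment: loads go through the templated load_unsafe<AccT>/load_safe<AccT>, the threadgroup reduction buffer becomes AccT, and block_scale (computed in T from the operand masks) multiplies the now-float v_coeff, which promotes cleanly. As a rough pure-Python model of the masked GEMV contract these kernels implement (names are illustrative, not MLX API; accumulation in Python float mirrors AccT = float):

    def masked_gemv(mat, vec, mat_mask, vec_mask, block=32):
        # mat: M x N list of lists; masks hold one scale per block of size `block`.
        m, n = len(mat), len(mat[0])
        out = [0.0] * m  # accumulate in float regardless of storage dtype
        for i in range(m):
            for b0 in range(0, n, block):
                # One scale per (row-block, col-block); zero skips the block.
                scale = mat_mask[i // block][b0 // block] * vec_mask[b0 // block]
                if scale == 0:
                    continue
                for j in range(b0, min(b0 + block, n)):
                    out[i] += scale * mat[i][j] * vec[j]
        return out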
diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py
index fdeaea98a..985ca5ffb 100644
--- a/python/tests/test_blas.py
+++ b/python/tests/test_blas.py
@@ -1146,6 +1146,18 @@ class TestBlas(mlx_tests.MLXTestCase):
             self.assertEqual(r.shape, t.shape)
             self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
 
+    def test_gemv_gemm_same_precision(self):
+        mx.random.seed(0)
+        N = 256
+        if mx.metal.is_available():
+            t = mx.bfloat16
+            a = mx.random.normal([1, N]).astype(t)
+            b = mx.concatenate([a, a], axis=0).astype(t)
+            c = mx.random.normal([N, 64]).astype(t)
+            out_gemv = a @ c
+            out_gemm = (b @ c)[0]
+            self.assertTrue(mx.allclose(out_gemv, out_gemm))
+
 
 if __name__ == "__main__":
     unittest.main()
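The new test pins both matmul paths to the same accumulation precision: the single-row operand dispatches to the GEMV kernel, while the stacked two-row operand takes the GEMM path, which already accumulated in float32, so before this patch the two results could diverge in bfloat16. A standalone repro, assuming a machine with MLX's Metal backend (shapes and dtype follow the test above):

    import mlx.core as mx

    mx.random.seed(0)
    a = mx.random.normal([1, 256]).astype(mx.bfloat16)  # one row: GEMV path
    b = mx.concatenate([a, a], axis=0)                  # two rows: GEMM path
    c = mx.random.normal([256, 64]).astype(mx.bfloat16)
    print(mx.allclose(a @ c, (b @ c)[0]).item())        # True with this patch

The test can be run in isolation with `python python/tests/test_blas.py TestBlas.test_gemv_gemm_same_precision`.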