mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-25 12:48:14 +08:00 
			
		
		
		
	Faster bfloat quantized mat-vec and vec-mat (#663)
This commit is contained in:
		| @@ -15,6 +15,14 @@ using namespace metal; | ||||
|  | ||||
| MLX_MTL_CONST int SIMD_SIZE = 32; | ||||
|  | ||||
| template <typename T> struct AccT { | ||||
|   typedef T acc_t; | ||||
| }; | ||||
|  | ||||
| template <> struct AccT<bfloat16_t> { | ||||
|   typedef float acc_t; | ||||
| }; | ||||
|  | ||||
| template <typename T, const int BM, const int BN, const int group_size, const int bits> | ||||
| [[kernel]] void qmv( | ||||
|     const device uint32_t* w [[buffer(0)]], | ||||
| @@ -37,15 +45,16 @@ template <typename T, const int BM, const int BN, const int group_size, const in | ||||
|   constexpr int groups_per_block = colgroup / group_size; | ||||
|   constexpr int simdgroups_fetching_vec = colgroup / SIMD_SIZE; | ||||
|  | ||||
|   threadgroup T scales_block[BM * groups_per_block]; | ||||
|   threadgroup T biases_block[BM * groups_per_block]; | ||||
|   threadgroup T x_block[colgroup]; | ||||
|   typedef typename AccT<T>::acc_t U; | ||||
|   threadgroup U scales_block[BM * groups_per_block]; | ||||
|   threadgroup U biases_block[BM * groups_per_block]; | ||||
|   threadgroup U x_block[colgroup]; | ||||
|  | ||||
|   thread uint32_t w_local; | ||||
|   thread T result = 0; | ||||
|   thread T scale = 1; | ||||
|   thread T bias = 0; | ||||
|   thread T x_thread[el_per_thread]; | ||||
|   thread U result = 0; | ||||
|   thread U scale = 1; | ||||
|   thread U bias = 0; | ||||
|   thread U x_thread[el_per_thread]; | ||||
|  | ||||
|   // Adjust positions | ||||
|   const int in_vec_size_w = in_vec_size / el_per_thread; | ||||
| @@ -90,7 +99,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in | ||||
|     // Do all the work. | ||||
|     #pragma clang loop unroll(full) | ||||
|     for (int k=0; k<el_per_thread; k++) { | ||||
|       result += (scale * static_cast<T>(w_local & bitmask) + bias) * x_thread[k]; | ||||
|       result += (scale * static_cast<U>(w_local & bitmask) + bias) * x_thread[k]; | ||||
|       w_local >>= bits; | ||||
|     } | ||||
|   } | ||||
| @@ -100,7 +109,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in | ||||
|  | ||||
|   // Store the result | ||||
|   if (simd_lid == 0) { | ||||
|     y[out_row] = result; | ||||
|     y[out_row] = static_cast<T>(result); | ||||
|   } | ||||
| } | ||||
|  | ||||
| @@ -129,15 +138,16 @@ template <typename T, const int BM, const int BN, const int group_size, const in | ||||
|   constexpr int colgroup = BN * el_per_int; | ||||
|   constexpr int groups_per_block = colgroup / group_size; | ||||
|  | ||||
|   threadgroup T scales_block[BM * groups_per_block]; | ||||
|   threadgroup T biases_block[BM * groups_per_block]; | ||||
|   threadgroup T x_block[BM]; | ||||
|   typedef typename AccT<T>::acc_t U; | ||||
|   threadgroup U scales_block[BM * groups_per_block]; | ||||
|   threadgroup U biases_block[BM * groups_per_block]; | ||||
|   threadgroup U x_block[BM]; | ||||
|  | ||||
|   thread uint32_t w_local; | ||||
|   thread T result[el_per_int] = {0}; | ||||
|   thread T scale = 1; | ||||
|   thread T bias = 0; | ||||
|   thread T x_local = 0; | ||||
|   thread U result[el_per_int] = {0}; | ||||
|   thread U scale = 1; | ||||
|   thread U bias = 0; | ||||
|   thread U x_local = 0; | ||||
|  | ||||
|   // Adjust positions | ||||
|   const int out_vec_size_w = out_vec_size / el_per_int; | ||||
| @@ -186,7 +196,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in | ||||
|     // Do all the work. | ||||
|     #pragma clang loop unroll(full) | ||||
|     for (int k=0; k<el_per_int; k++) { | ||||
|       result[k] += (scale * static_cast<T>(w_local & bitmask) + bias) * x_local; | ||||
|       result[k] += (scale * static_cast<U>(w_local & bitmask) + bias) * x_local; | ||||
|       w_local >>= bits; | ||||
|     } | ||||
|   } | ||||
| @@ -201,7 +211,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in | ||||
|   if (simd_lid == 0) { | ||||
|     #pragma clang loop unroll(full) | ||||
|     for (int k=0; k<el_per_int; k++) { | ||||
|       y[k] = result[k]; | ||||
|       y[k] = static_cast<T>(result[k]); | ||||
|     } | ||||
|   } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun