NOTE(review): this unified diff was collapsed onto one line and passed through an
HTML-tag stripper, which deleted every non-empty `<...>` span (template parameter
lists, template arguments, static_cast targets) and, where a `<` in a hunk-context
line had no `>` until several lines later, entire context/`-` lines with it.
Reconstruction below: angle-bracket text is restored only where the surviving
`+`/`-` mirrors or the visible bodies pin it down; destroyed spans are marked
`[… lost in extraction …]` rather than guessed, so hunk line counts are approximate.

diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal
index 5bf3142d4..0de84093d 100644
--- a/mlx/backend/metal/kernels/quantized.metal
+++ b/mlx/backend/metal/kernels/quantized.metal
@@ -15,6 +15,14 @@ using namespace metal;
 
 MLX_MTL_CONST int SIMD_SIZE = 32;
 
+template <typename T> struct AccT {
+  typedef T acc_t;
+};
+
+template <> struct AccT<bfloat16_t> {
+  typedef float acc_t;
+};
+
 template <typename T, const int BM, const int BN, const int group_size, const int bits>
 [[kernel]] void qmv(
     const device uint32_t* w [[buffer(0)]],
@@ -37,15 +45,16 @@ template <typename T, const int BM, const int BN, const int group_size, const
 [… leading context lines lost in extraction …]
-  threadgroup T scales_block[BM * groups_per_block];
-  threadgroup T biases_block[BM * groups_per_block];
-  threadgroup T x_block[colgroup];
+  typedef typename AccT<T>::acc_t U;
+  threadgroup U scales_block[BM * groups_per_block];
+  threadgroup U biases_block[BM * groups_per_block];
+  threadgroup U x_block[colgroup];
 
   thread uint32_t w_local;
-  thread T result = 0;
-  thread T scale = 1;
-  thread T bias = 0;
-  thread T x_thread[el_per_thread];
+  thread U result = 0;
+  thread U scale = 1;
+  thread U bias = 0;
+  thread U x_thread[el_per_thread];
 
   // Adjust positions
   const int in_vec_size_w = in_vec_size / el_per_thread;
@@ -90,7 +99,7 @@ template <typename T, const int BM, const int BN, const int group_size, const
 [… leading context lines lost in extraction …]
-      result += (scale * static_cast<T>(w_local & bitmask) + bias) * x_thread[k];
+      result += (scale * static_cast<U>(w_local & bitmask) + bias) * x_thread[k];
       w_local >>= bits;
     }
   }
@@ -100,7 +109,7 @@ template <typename T, const int BM, const int BN, const int group_size, const
 [… leading context lines lost in extraction …]
-      y[…] = result;
+      y[…] = static_cast<T>(result);
   }
 }
@@ -129,15 +138,16 @@ template <typename T, const int BM, const int BN, const int group_size, const
 [… leading context lines lost in extraction …]
-  threadgroup T scales_block[BM * groups_per_block];
-  threadgroup T biases_block[BM * groups_per_block];
-  threadgroup T x_block[BM];
+  typedef typename AccT<T>::acc_t U;
+  threadgroup U scales_block[BM * groups_per_block];
+  threadgroup U biases_block[BM * groups_per_block];
+  threadgroup U x_block[BM];
 
   thread uint32_t w_local;
-  thread T result[el_per_int] = {0};
-  thread T scale = 1;
-  thread T bias = 0;
-  thread T x_local = 0;
+  thread U result[el_per_int] = {0};
+  thread U scale = 1;
+  thread U bias = 0;
+  thread U x_local = 0;
 
   // Adjust positions
   const int out_vec_size_w = out_vec_size / el_per_int;
@@ -186,7 +196,7 @@ template <typename T, const int BM, const int BN, const int group_size, const
 [… leading context lines lost in extraction …]
-      result[k] += (scale * static_cast<T>(w_local & bitmask) + bias) * x_local;
+      result[k] += (scale * static_cast<U>(w_local & bitmask) + bias) * x_local;
       w_local >>= bits;
     }
   }
@@ -201,7 +211,7 @@ template <typename T, const int BM, const int BN, const int group_size, const
 [… leading context lines lost in extraction …]
-      y[…] = result[k];
+      y[…] = static_cast<T>(result[k]);
     }
   }
 }