diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal
index 7f4b76725..625d7450a 100644
--- a/mlx/backend/metal/kernels/quantized.metal
+++ b/mlx/backend/metal/kernels/quantized.metal
@@ -201,6 +201,150 @@ inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
   }
 }
 
+template <typename U, int N, int bits>
+inline void dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
+  static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
+
+  if (bits == 2) {
+    U s[4] = {scale, scale / static_cast<U>(4.0f), scale / static_cast<U>(16.0f), scale / static_cast<U>(64.0f)};
+    for (int i = 0; i < (N / 4); i++) {
+      w_local[4*i] = s[0] * (w[i] & 0x03) + bias;
+      w_local[4*i+1] = s[1] * (w[i] & 0x0c) + bias;
+      w_local[4*i+2] = s[2] * (w[i] & 0x30) + bias;
+      w_local[4*i+3] = s[3] * (w[i] & 0xc0) + bias;
+    }
+  }
+
+  else if (bits == 4) {
+    const device uint16_t* ws = (const device uint16_t*)w;
+    U s[4] = {scale, scale / static_cast<U>(16.0f), scale / static_cast<U>(256.0f), scale / static_cast<U>(4096.0f)};
+    for (int i = 0; i < (N / 4); i++) {
+      w_local[4*i] = s[0] * (ws[i] & 0x000f) + bias;
+      w_local[4*i+1] = s[1] * (ws[i] & 0x00f0) + bias;
+      w_local[4*i+2] = s[2] * (ws[i] & 0x0f00) + bias;
+      w_local[4*i+3] = s[3] * (ws[i] & 0xf000) + bias;
+    }
+  }
+
+  else if (bits == 8) {
+    for (int i = 0; i < N; i++) {
+      w_local[i] = scale * w[i] + bias;
+    }
+  }
+}
+
+template <
+    typename T,
+    short BROWS,
+    short BCOLS,
+    short dst_ld,
+    short reduction_dim,
+    short tgp_size,
+    short group_size,
+    short bits>
+struct QuantizedBlockLoader {
+  static_assert(BCOLS <= group_size, "The group size should be larger than the columns");
+  static_assert(group_size % BCOLS == 0, "The group size should be divisible by the columns");
+  static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
+
+  MLX_MTL_CONST short pack_factor = 32 / bits;
+  MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
+  MLX_MTL_CONST short n_reads = (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
+  MLX_MTL_CONST short group_steps = group_size / BCOLS;
+
+  const int src_ld;
+  const int tile_stride;
+  short group_step_cnt;
+  const int group_stride;
+
+  const short thread_idx;
+  const short bi;
+  const short bj;
+
+  threadgroup T* dst;
+  const device uint32_t* src;
+  const device T* scales;
+  const device T* biases;
+
+  QuantizedBlockLoader(
+      const device uint32_t* src_,
+      const device T* scales_,
+      const device T* biases_,
+      const int src_ld_,
+      threadgroup T* dst_,
+      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
+      ushort simd_lane_id [[thread_index_in_simdgroup]])
+      : src_ld(src_ld_),
+        tile_stride(reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor),
+        group_step_cnt(0),
+        group_stride(BROWS * src_ld / group_size),
+        thread_idx(simd_group_id * 32 + simd_lane_id),
+        bi(n_reads * thread_idx / BCOLS_PACKED),
+        bj((n_reads * thread_idx) % BCOLS_PACKED),
+        dst(dst_ + bi * dst_ld + bj * pack_factor),
+        src(src_ + bi * src_ld / pack_factor + bj),
+        scales(scales_ + bi * src_ld / group_size),
+        biases(biases_ + bi * src_ld / group_size) {}
+
+  void load_unsafe() const {
+    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
+      return;
+    }
+
+    T scale = *scales;
+    T bias = *biases;
+    for (int i=0; i<n_reads; i++) {
+      dequantize<T, pack_factor, bits>((device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
+    }
+  }
+
+  void load_safe(short2 src_tile_dim) const {
+    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
+      return;
+    }
+
+    if (reduction_dim == 1 && bi >= src_tile_dim.y) {
+      for (int i=0; i<n_reads * pack_factor; i++) {
+        dst[i] = T(0);
+      }
+      return;
+    }
+
+    if (reduction_dim == 0 && bi >= src_tile_dim.x) {
+      for (int i=0; i<n_reads * pack_factor; i++) {
+        dst[i] = T(0);
+      }
+      return;
+    }
+
+    T scale = *scales;
+    T bias = *biases;
+    for (int i=0; i<n_reads; i++) {
+      dequantize<T, pack_factor, bits>((device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
+    }
+  }
+
+  void next() {
+    src += tile_stride;
+    if (reduction_dim == 1) {
+      if (group_steps > 1) {
+        group_step_cnt++;
+        if (group_step_cnt == group_steps) {
+          group_step_cnt = 0;
+          scales++;
+          biases++;
+        }
+      } else {
+        scales++;
+        biases++;
+      }
+    } else {
+      scales += group_stride;
+      biases += group_stride;
+    }
+  }
+};
+
 template <typename T, const int group_size, const int bits>
 [[kernel]] void qmv_fast(
     const device uint32_t* w [[buffer(0)]],
@@ -495,28 +639,23 @@ template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits, const bool aligned_N>
   static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
   static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
 
-  const uint lidy = lid / SIMD_SIZE;
+  (void)lid;
 
   constexpr int WM = 2;
   constexpr int WN = 2;
-  constexpr int bitmask = (1 << bits) - 1;
-  constexpr int el_per_int = 32 / bits;
-  constexpr int ints_per_block = BK / el_per_int;
-  constexpr int groups_per_block = (BK / group_size > 0) ? (BK / group_size) : 1;
-  constexpr int groups_per_simd = BN / (WM * WN);
-  constexpr int w_els_per_thread = (BN * BK / el_per_int) / (SIMD_SIZE * WM * WN);
+  constexpr int pack_factor = 32 / bits;
+  constexpr int BK_padded = (BK + 16 / sizeof(T));
 
   // Instantiate the appropriate BlockMMA and Loader
-  using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK, BK>;
-  using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK, 1, WM * WN * SIMD_SIZE>;
+  using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
+  using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
+  using loader_w_t = QuantizedBlockLoader<T, BN, BK, BK_padded, 1, WM * WN * SIMD_SIZE, group_size, bits>;
 
-  threadgroup T scales_block[BN * groups_per_block];
-  threadgroup T biases_block[BN * groups_per_block];
-  threadgroup T Xs[BM * BK];
-  threadgroup T Ws[BN * BK];
+  threadgroup T Xs[BM * BK_padded];
+  threadgroup T Ws[BN * BK_padded];
 
   // Set the block
-  const int K_w = K / el_per_int;
+  const int K_w = K / pack_factor;
   const int K_g = K / group_size;
   const int y_row = tid.y * BM;
   const int y_col = tid.x * BN;
@@ -531,98 +670,53 @@ template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits, const bool aligned_N>
-            static_cast<T>(wi & bitmask) + bias;
-            wi >>= bits;
-          }
-        } else {
-          #pragma clang loop unroll(full)
-          for (int t=0; t<el_per_int; t++) {
-            static_cast<T>(wi & bitmask) + bias;
-            wi >>= bits;
-          }
-        }
-      }
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // Multiply and accumulate threadgroup elements
-    mma_op.mma(Xs, Ws);
-
-    // Prepare for next iteration
-    loader_x.next();
-    w += ints_per_block;
-    // scales and biases cannot be advanced because they would have to be
-    // advanced every other iteration or sth.
   }
 
   // Store results to device memory
@@ -653,32 +747,27 @@ template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
   static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
   static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
 
-  const uint lidy = lid / SIMD_SIZE;
+  (void)lid;
 
   constexpr int WM = 2;
   constexpr int WN = 2;
-  constexpr int bitmask = (1 << bits) - 1;
-  constexpr int el_per_int = 32 / bits;
-  constexpr int groups_per_block = (BN / group_size > 0) ? (BN / group_size) : 1;
-  constexpr int groups_per_simd = BK / (WM * WN);
-  constexpr int w_els_per_thread = (BK * BN / el_per_int) / (SIMD_SIZE * WM * WN);
+  constexpr int pack_factor = 32 / bits;
+  constexpr int BK_padded = (BK + 16 / sizeof(T));
+  constexpr int BN_padded = (BN + 16 / sizeof(T));
 
   // Instantiate the appropriate BlockMMA and Loader
-  using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK, BN>;
-  using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK, 1, WM * WN * SIMD_SIZE>;
+  using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
+  using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
+  using loader_w_t = QuantizedBlockLoader<T, BK, BN, BN_padded, 0, WM * WN * SIMD_SIZE, group_size, bits>;
 
-  threadgroup T scales_block[BK * groups_per_block];
-  threadgroup T biases_block[BK * groups_per_block];
-  threadgroup T Xs[BM * BK];
-  threadgroup T Ws[BK * BN];
+  threadgroup T Xs[BM * BK_padded];
+  threadgroup T Ws[BK * BN_padded];
 
   // Set the block
-  const int N_w = N / el_per_int;
-  const int N_g = N / group_size;
   const int y_row = tid.y * BM;
   const int y_col = tid.x * BN;
   x += y_row * K;
-  w += y_col / el_per_int;
+  w += y_col / pack_factor;
   scales += y_col / group_size;
   biases += y_col / group_size;
   y += y_row * N + y_col;
@@ -686,96 +775,67 @@ template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
-            static_cast<T>(wi & bitmask) + bias;
-            wi >>= bits;
-          }
-        } else {
-          #pragma clang loop unroll(full)
-          for (int t=0; t<el_per_int; t++) {
-            static_cast<T>(wi & bitmask) + bias;
-            wi >>= bits;
-          }
-        }
-      }
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // Multiply and accumulate threadgroup elements
-    mma_op.mma(Xs, Ws);
-
-    // Prepare for next iteration
-    loader_x.next();
-    w += BK * N_w;
-    scales += BK * N_g;
-    biases += BK * N_g;
   }
 
   // Store results to device memory
@@ -877,7 +937,7 @@ instantiate_qvm_types( 32, 8)
 
 #define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \
   template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N)]] \
-  [[kernel]] void qmm_t<itype, 32, 64, 32, group_size, bits, aligned_N>( \
+  [[kernel]] void qmm_t<itype, 32, 32, 32, group_size, bits, aligned_N>( \
     const device itype* x [[buffer(0)]], \
     const device uint32_t* w [[buffer(1)]], \
     const device itype* scales [[buffer(2)]], \
@@ -911,7 +971,7 @@ instantiate_qmm_t_types( 32, 8)
 
 #define instantiate_qmm_n(name, itype, group_size, bits) \
   template [[host_name("qmm_n_" #name "_gs_" #group_size "_b_" #bits)]] \
-  [[kernel]] void qmm_n<itype, 32, 32, 64, group_size, bits>( \
+  [[kernel]] void qmm_n<itype, 32, 32, 32, group_size, bits>( \
    const device itype* x [[buffer(0)]], \
    const device uint32_t* w [[buffer(1)]], \
    const device itype* scales [[buffer(2)]], \
diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
index b41ee68e2..eb060e7e9 100644
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ -110,7 +110,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     int wm = 2;
     int bm = 32;
     int bn = 32;
-    int bk = 64;
+    int bk = 32;
     MTL::Size group_dims = MTL::Size(32, wn, wm);
     MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, 1);
@@ -167,7 +167,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     int wn = 2;
     int wm = 2;
     int bm = 32;
-    int bn = 64;
+    int bn = 32;
     int bk = 32;
     MTL::Size group_dims = MTL::Size(32, wn, wm);
     MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);
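
Note on the new dequantize helper above: instead of shifting each packed field down before scaling, it pre-divides the scale by 4, 16 and 64 for 2-bit (or 16, 256 and 4096 for 4-bit), so the masked but unshifted value can be multiplied directly, since (w & 0x0c) * (scale / 4) == ((w >> 2) & 0x03) * scale. The host-side C++ sketch below is not part of the diff; dequantize_2bit and the main driver are illustrative names only, written so the arithmetic of the 2-bit branch can be checked on the CPU.

#include <cstdint>
#include <cstdio>

// Mirrors the 2-bit branch of the Metal dequantize() added in this diff:
// the scale is pre-divided so each masked field is used without a shift.
template <typename U, int N>
void dequantize_2bit(const std::uint8_t* w, U scale, U bias, U* w_local) {
  U s[4] = {scale, scale / U(4), scale / U(16), scale / U(64)};
  for (int i = 0; i < N / 4; i++) {
    w_local[4 * i + 0] = s[0] * (w[i] & 0x03) + bias;
    w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
    w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias;
    w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias;
  }
}

int main() {
  // One packed byte holding the 2-bit values 1, 2, 3, 0 (least significant pair first).
  std::uint8_t packed[1] = {static_cast<std::uint8_t>(1 | (2 << 2) | (3 << 4) | (0 << 6))};
  float out[4];
  dequantize_2bit<float, 4>(packed, 0.5f, -1.0f, out);
  for (float v : out) {
    std::printf("%g\n", v); // prints -0.5, 0, 0.5, -1
  }
  return 0;
}

The per-element shift in the old inner loop (wi >>= bits) is what this formulation removes; only a mask, a multiply and an add remain per output element.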