From fb7be036af0d5f8ce9aa63112afbbb3c77850d39 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 16 Dec 2024 21:49:14 -0800 Subject: [PATCH] Add packed_affine_qmm_t --- ...acked_qmv_bench.py => packed_qmm_bench.py} | 30 +- mlx/backend/metal/kernels/quantized.h | 350 +++++++++++++++++- mlx/backend/metal/kernels/quantized.metal | 8 +- mlx/backend/metal/quantized.cpp | 88 ++++- 4 files changed, 455 insertions(+), 21 deletions(-) rename benchmarks/python/{packed_qmv_bench.py => packed_qmm_bench.py} (73%) diff --git a/benchmarks/python/packed_qmv_bench.py b/benchmarks/python/packed_qmm_bench.py similarity index 73% rename from benchmarks/python/packed_qmv_bench.py rename to benchmarks/python/packed_qmm_bench.py index f6c6a4724..63fd95fa9 100644 --- a/benchmarks/python/packed_qmv_bench.py +++ b/benchmarks/python/packed_qmm_bench.py @@ -1,19 +1,19 @@ import argparse import math -from functools import partial import mlx.core as mx from time_utils import time_fn +B = 1024 D = 1024 M = 4 * D group_size = 64 bits = 4 dtype = mx.float16 -loops = 100 +loops = 10 -def qmv_(x, wq1, wq2, q_type): +def qmm_(x, wq1, wq2, q_type): for i in range(loops): x = mx.quantized_matmul( x, @@ -32,28 +32,28 @@ def qmv_(x, wq1, wq2, q_type): return x -def affine_qmv(x, wq1, wq2): - return qmv_(x, wq1, wq2, "affine") +def affine_qmm(x, wq1, wq2): + return qmm_(x, wq1, wq2, "affine") -def affine_packed_qmv(x, wq1, wq2): - return qmv_(x, wq1, wq2, "affine-packed") +def affine_packed_qmm(x, wq1, wq2): + return qmm_(x, wq1, wq2, "affine-packed") -def time_qmv(): +def time_qmm(): mx.random.seed(3) - x = mx.random.normal(shape=(1, D)).astype(dtype) + x = mx.random.normal(shape=(B, D)).astype(dtype) w1 = mx.random.normal(shape=(M, D)).astype(dtype) wq1 = mx.quantize(w1, group_size=group_size, bits=bits, quantization_type="affine") w2 = mx.random.normal(shape=(D, M)).astype(dtype) wq2 = mx.quantize(w2, group_size=group_size, bits=bits, quantization_type="affine") mx.eval(x, wq1, wq2) - time_fn(affine_qmv, x, wq1, wq2) + time_fn(affine_qmm, x, wq1, wq2) -def time_packed_qmv(): +def time_packed_qmm(): mx.random.seed(3) - x = mx.random.normal(shape=(1, D)).astype(dtype) + x = mx.random.normal(shape=(B, D)).astype(dtype) w1 = mx.random.normal(shape=(M, D)).astype(dtype) wq1 = mx.quantize( w1, group_size=group_size, bits=bits, quantization_type="affine-packed" @@ -63,12 +63,12 @@ def time_packed_qmv(): w2, group_size=group_size, bits=bits, quantization_type="affine-packed" ) mx.eval(x, wq1, wq2) - time_fn(affine_packed_qmv, x, wq1, wq2) + time_fn(affine_packed_qmm, x, wq1, wq2) if __name__ == "__main__": for b in [2, 4, 8]: bits = b print(f"Bits {bits}:") - time_qmv() - time_packed_qmv() + time_qmm() + time_packed_qmm() diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index 134e1df8b..c88f20923 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -1248,6 +1248,41 @@ METAL_FUNC void adjust_matrix_offsets( y += tid.z * output_stride; } +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device T*& scales, + device T*& y, + int output_stride, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx = tid.z; + uint32_t w_idx = tid.z; + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + } else { + ulong2 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + } + y += tid.z * output_stride; +} + template METAL_FUNC void adjust_matrix_offsets( const device T*& x, @@ -2266,11 +2301,322 @@ template const device vec* scales [[buffer(1)]], const device T* x [[buffer(2)]], device T* y [[buffer(3)]], - const constant int& in_vec_size [[buffer(5)]], - const constant int& out_vec_size [[buffer(6)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { affine_packed_qmv_fast_impl( w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); } + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short group_size, + short bits> +struct AffinePackedQuantizedBlockLoader { + static_assert( + BCOLS <= group_size, + "The group size should be larger than the columns"); + static_assert( + group_size % BCOLS == 0, + "The group size should be divisible by the columns"); + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, + "Template undefined for bits not in {2, 3, 4, 6, 8}"); + + MLX_MTL_CONST short pack_factor = 32 / bits; + MLX_MTL_CONST short row_pack_factor = 4; + MLX_MTL_CONST short BCOLS_PACKED = BCOLS * row_pack_factor / pack_factor; + MLX_MTL_CONST short BROWS_PACKED = BROWS / row_pack_factor; + MLX_MTL_CONST short TOTAL_INTS = BCOLS_PACKED * BROWS_PACKED; + MLX_MTL_CONST short n_reads = + (TOTAL_INTS < tgp_size) ? 1 : TOTAL_INTS / tgp_size; + MLX_MTL_CONST short group_steps = group_size / BCOLS; + + static_assert( + n_reads <= row_pack_factor, + "The loader only supports per thread reads <= row_pack_factor"); + + const int src_ld; + const int tile_stride; + short group_step_cnt; + const int group_stride; + + const short thread_idx; + const short bi; + const short bj; + const short bii; + const short bjj; + + const device uint32_t* src; + const device T* scales; + const device T* biases; + threadgroup T* dst; + + AffinePackedQuantizedBlockLoader( + const device uint32_t* src_, + const device T* scales_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride(reduction_dim ? BCOLS_PACKED : BROWS_PACKED * src_ld), + group_step_cnt(0), + group_stride(BROWS_PACKED * 2 * src_ld / group_size), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(n_reads * thread_idx / BCOLS_PACKED), + bj((n_reads * thread_idx) % BCOLS_PACKED), + bii(bi * row_pack_factor + bj % row_pack_factor), + bjj(bj / row_pack_factor), + src(src_ + bi * src_ld * row_pack_factor / pack_factor + bj), + scales( + scales_ + bi * 2 * src_ld * row_pack_factor / group_size + + bj % row_pack_factor), + biases(scales + row_pack_factor), + dst(dst_ + bii * dst_ld + bjj * pack_factor) {} + + void load_unsafe() const { + if (bits == 2 && BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + for (int i = 0; i < n_reads; i++) { + T scale = scales[i]; + T bias = biases[i]; + dequantize( + (const device uint8_t*)(src + i), scale, bias, dst + i * dst_ld); + } + } + + void load_safe(short2 src_tile_dim) const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + if (reduction_dim == 1 && bii >= src_tile_dim.y) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + if (reduction_dim == 0 && bii >= src_tile_dim.x) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + for (int i = 0; i < n_reads; i++) { + T scale = scales[i]; + T bias = biases[i]; + dequantize( + (const device uint8_t*)(src + i), scale, bias, dst + i * dst_ld); + } + } + + void next() { + src += tile_stride; + if (reduction_dim == 1) { + if (group_steps > 1) { + group_step_cnt++; + if (group_step_cnt == group_steps) { + group_step_cnt = 0; + scales += 8; + biases += 8; + } + } else { + scales += 8; + biases += 8; + } + } else { + scales += group_stride; + biases += group_stride; + } + } +}; + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 32, + const int BK = 32, + const int BN = 32> +METAL_FUNC void affine_packed_qmm_t_impl( + const device uint32_t* w, + const device T* scales, + const device T* x, + device T* y, + threadgroup T* Xs, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int WM = 2; + constexpr int WN = 2; + constexpr int pack_factor = 32 / bits; + constexpr int row_pack_factor = 4; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + // Instantiate the appropriate BlockMMA and Loader + using mma_t = mlx::steel:: + BlockMMA; + using loader_x_t = + mlx::steel::BlockLoader; + using loader_w_t = AffinePackedQuantizedBlockLoader< + T, + BN, + BK, + BK_padded, + 1, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + // Set the block + const int K_w = K * row_pack_factor / pack_factor; + const int K_g = K * 2 * row_pack_factor / group_size; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + const int packed_y_col = tid.x * (BN / row_pack_factor); + + x += y_row * K; + w += packed_y_col * K_w; + scales += packed_y_col * K_g; + y += y_row * N + y_col; + + // Make the x loader and mma operation + const short num_els = min(BM, M - y_row); + const short num_outs = min(BN, N - y_col); + loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); + loader_w_t loader_w(w, scales, K, Ws, simd_gid, simd_lid); + mma_t mma_op(simd_gid, simd_lid); + + if (num_els < BM) { + if (!aligned_N && num_outs < BN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_safe(short2(BK, num_outs)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } else { + if (!aligned_N && num_outs < BN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_safe(short2(BK, num_outs)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + if (num_els < BM || num_outs < BN) { + mma_op.store_result_safe(y, N, short2(num_outs, num_els)); + } else { + mma_op.store_result(y, N); + } +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const bool batched, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void affine_packed_qmm_t( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* x [[buffer(2)]], + device T* y [[buffer(3)]], + const constant int& K [[buffer(4)]], + const constant int& N [[buffer(5)]], + const constant int& M [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant int64_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant int64_t* w_strides [[buffer(12)]], + const constant int64_t* s_strides [[buffer(13)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + affine_packed_qmm_t_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal index 455a55ad2..76dd93adc 100644 --- a/mlx/backend/metal/kernels/quantized.metal +++ b/mlx/backend/metal/kernels/quantized.metal @@ -104,8 +104,12 @@ instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \ instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32) -#define instantiate_quantized_all_affine_packed(type, group_size, bits) \ - instantiate_quantized_affine_packed(affine_packed_qmv_fast, type, group_size, bits) +#define instantiate_quantized_all_affine_packed(type, group_size, bits) \ + instantiate_quantized_affine_packed(affine_packed_qmv_fast, type, group_size, bits) \ + instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, true, true) \ + instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, true, false) \ + instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, false, true) \ + instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, false, false) #define instantiate_quantized_funcs(type, group_size, bits) \ instantiate_quantized_all_single(type, group_size, bits) \ diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index bf60d17d1..925bf8560 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -428,8 +428,91 @@ void affine_packed_qmv( compute_encoder.set_input_array(scales, 1); compute_encoder.set_input_array(x, 2); compute_encoder.set_output_array(out, 3); - compute_encoder.set_bytes(D, 5); - compute_encoder.set_bytes(O, 6); + compute_encoder.set_bytes(D, 4); + compute_encoder.set_bytes(O, 5); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void affine_packed_qmm_t( + const std::vector& inputs, + array& out, + bool batched, + int B, + int D, + int O, + int group_size, + int bits, + const Stream& s) { + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + auto& d = metal::device(s.device); + auto ensure_row_contiguous_last_dims = [&d, &s](const array& arr) { + auto stride_0 = arr.strides()[arr.ndim() - 2]; + auto stride_1 = arr.strides()[arr.ndim() - 1]; + if (stride_0 == arr.shape(-1) && stride_1 == 1) { + return arr; + } else { + array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); + copy_gpu(arr, arr_copy, CopyType::General, s); + d.add_temporary(arr_copy, s.index); + return arr_copy; + } + }; + // TODO: Deal with this in routing towards qmm_n instead of qmm_t + auto x = ensure_row_contiguous_last_dims(inputs[0]); + auto w = ensure_row_contiguous_last_dims(inputs[1]); + auto scales = ensure_row_contiguous_last_dims(inputs[2]); + + int x_batch_ndims = x.ndim() - 2; + auto& x_shape = x.shape(); + auto& x_strides = x.strides(); + int w_batch_ndims = w.ndim() - 2; + auto& w_shape = w.shape(); + auto& w_strides = w.strides(); + auto& s_strides = scales.strides(); + + const int wn = 2; + const int wm = 2; + const int bm = 32; + const int bn = 32; + const int N = (batched) ? out.size() / B / O : 1; + MTL::Size group_dims(32, wn, wm); + MTL::Size grid_dims((O + bn - 1) / bn, (B + bm - 1) / bm, N); + + std::string name; + name.reserve(64); + concatenate( + name, + "affine_packed_qmm_t_", + get_type_string(out.dtype()), + "_gs_", + std::to_string(group_size), + "_b_", + std::to_string(bits), + "_alN_", + ((O % 32) == 0) ? "true" : "false", + "_batch_", + (batched) ? "true" : "false"); + auto kernel = get_quantized_kernel(d, name, ""); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); + compute_encoder.set_input_array(x, 2); + compute_encoder.set_output_array(out, 3); + compute_encoder.set_bytes(D, 4); + compute_encoder.set_bytes(O, 5); + compute_encoder.set_bytes(B, 6); + if (batched) { + compute_encoder.set_bytes(x_batch_ndims, 7); + compute_encoder.set_vector_bytes(x_shape, 8); + compute_encoder.set_vector_bytes(x_strides, 9); + compute_encoder.set_bytes(w_batch_ndims, 10); + compute_encoder.set_vector_bytes(w_shape, 11); + compute_encoder.set_vector_bytes(w_strides, 12); + compute_encoder.set_vector_bytes(s_strides, 13); + } compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } @@ -451,6 +534,7 @@ void affine_packed_qmm_op( if (B < 6) { affine_packed_qmv(inputs, out, B, D, O, group_size, bits, s); } else { + affine_packed_qmm_t(inputs, out, batched, B, D, O, group_size, bits, s); } } else { }