An initial quantized matmul implementation (#205)

* Add quantized matvec
* Add quantized matrix matrix with 2nd matrix transposed
* Add quantized matmul tests
* Add a slow cpu quantized matmul
* Add a slightly faster vectorized cpu version
This commit is contained in:
Angelos Katharopoulos
2023-12-18 23:18:57 -08:00
committed by GitHub
parent e6872a4149
commit dfa9f4bc58
18 changed files with 1029 additions and 10 deletions

View File

@@ -10,6 +10,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp

View File

@@ -18,6 +18,7 @@ set(
"copy"
"gemm"
"gemv"
"quantized"
"random"
"reduce"
"scan"

View File

@@ -0,0 +1,287 @@
// Copyright © 2023 Apple Inc.
#include <metal_stdlib>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/gemm/gemm.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
#define MLX_MTL_CONST static constant constexpr const
MLX_MTL_CONST int SIMD_SIZE = 32;
template <typename T, const int BM, const int BN, const int groups, const int width>
[[kernel]] void qmv(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
static_assert(BN == SIMD_SIZE, "qmv expects BN to be equal to SIMD_SIZE");
constexpr int bitmask = (1 << width) - 1;
constexpr int el_per_thread = 32 / width;
constexpr int colgroup = BN * el_per_thread;
constexpr int groups_per_block = colgroup / groups;
constexpr int simdgroups_fetching_vec = colgroup / SIMD_SIZE;
threadgroup T scales_block[BM * groups_per_block];
threadgroup T biases_block[BM * groups_per_block];
threadgroup T x_block[colgroup];
thread uint32_t w_local;
thread T result = 0;
thread T scale = 1;
thread T bias = 0;
thread T x_thread[el_per_thread];
// Adjust positions
const int in_vec_size_w = in_vec_size / el_per_thread;
const int in_vec_size_g = in_vec_size / groups;
int out_row = tid.y * BM + simd_gid;
w += out_row * in_vec_size_w;
scales += out_row * in_vec_size_g;
biases += out_row * in_vec_size_g;
x += tid.z * in_vec_size;
y += tid.z * out_vec_size;
// Loop over in_vec in blocks of colgroup
for (int i=0; i<in_vec_size; i+=colgroup) {
// Load the vec to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_gid < simdgroups_fetching_vec) {
x_block[lid] = x[lid + i];
}
if (simd_lid == 0) {
#pragma clang loop unroll(full)
for (int j=0; j<groups_per_block; j++) {
scales_block[simd_gid * groups_per_block + j] = scales[i / groups + j];
}
#pragma clang loop unroll(full)
for (int j=0; j<groups_per_block; j++) {
biases_block[simd_gid * groups_per_block + j] = biases[i / groups + j];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load in_vec, scale, bias to registers
#pragma clang loop unroll(full)
for (int j=0; j<el_per_thread; j++) {
x_thread[j] = x_block[simd_lid*el_per_thread + j];
}
scale = scales_block[simd_gid * groups_per_block + simd_lid * el_per_thread / groups];
bias = biases_block[simd_gid * groups_per_block + simd_lid * el_per_thread / groups];
// Load the matrix elements
w_local = w[i / el_per_thread + simd_lid];
// Do all the work.
#pragma clang loop unroll(full)
for (int k=0; k<el_per_thread; k++) {
result += (scale * static_cast<T>(w_local & bitmask) + bias) * x_thread[k];
w_local >>= width;
}
}
// Accumulate in the simdgroup
result = simd_sum(result);
// Store the result
if (simd_lid == 0) {
y[out_row] = result;
}
}
template <typename T, const int BM, const int BK, const int BN, const int groups, const int width>
[[kernel]] void qmm_t(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& M [[buffer(5)]],
const constant int& N [[buffer(6)]],
const constant int& K [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
const uint lidy = lid / SIMD_SIZE;
constexpr int WM = 2;
constexpr int WN = 2;
constexpr int bitmask = (1 << width) - 1;
constexpr int el_per_int = 32 / width;
constexpr int ints_per_block = BK / el_per_int;
constexpr int groups_per_block = (BK / groups > 0) ? (BK / groups) : 1;
constexpr int groups_per_simd = BN / (WM * WN);
constexpr int w_els_per_thread = (BN * BK / el_per_int) / (SIMD_SIZE * WM * WN);
// Using the kernel just as a type to instantiate the appropriate BlockMMA
// and constexpr size calculations
using mma_t = BlockMMA<T, BM, BN, BK, WM, WN, false, true>;
using loader_x_t = BlockLoader<T, BM, BK, BK, 4, WM * WN * SIMD_SIZE, false, true, 0>;
threadgroup T scales_block[BN * groups_per_block];
threadgroup T biases_block[BN * groups_per_block];
threadgroup T Xs[BM * BK];
threadgroup T Ws[BN * BK];
// Set the block
const int K_w = K / el_per_int;
const int K_g = K / groups;
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
x += y_row * K;
w += y_col * K_w;
scales += y_col * K_g;
biases += y_col * K_g;
y += y_row * N + y_col;
// Make the x loader and mma operation
const short num_els = min(BM, M - y_row);
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
mma_t mma_op(simd_gid, simd_lid);
for (int k=0; k<K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load the x tile
if (num_els < BM) {
loader_x.load_safe(short2(BK, num_els));
} else {
loader_x.load_unsafe();
}
// Load the scale and bias
if (simd_lid == 0) {
threadgroup T *scales_block_local = scales_block + lidy * groups_per_block * groups_per_simd;
threadgroup T *biases_block_local = biases_block + lidy * groups_per_block * groups_per_simd;
const device T *scales_local = scales + lidy * groups_per_simd * K_g + k / groups;
const device T *biases_local = biases + lidy * groups_per_simd * K_g + k / groups;
#pragma clang loop unroll(full)
for (int gs=0; gs<groups_per_simd; gs++) {
#pragma clang loop unroll(full)
for (int gc=0; gc<groups_per_block; gc++) {
scales_block_local[gc] = scales_local[gc];
biases_block_local[gc] = biases_local[gc];
}
scales_block_local += groups_per_block;
scales_local += K_g;
biases_block_local += groups_per_block;
biases_local += K_g;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load the w tile
{
for (int wo=0; wo<w_els_per_thread; wo++) {
int offset = lid * w_els_per_thread + wo;
int offset_row = offset / (BK / el_per_int);
int offset_col = offset % (BK / el_per_int);
const device uint32_t * w_local = w + offset_row * K_w + offset_col;
threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
uint32_t wi = *w_local;
T scale = scales_block[offset_row * groups_per_block + offset_col / (groups / el_per_int)];
T bias = biases_block[offset_row * groups_per_block + offset_col / (groups / el_per_int)];
#pragma clang loop unroll(full)
for (int t=0; t<el_per_int; t++) {
Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
wi >>= width;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(Xs, Ws);
// Prepare for next iteration
loader_x.next();
w += ints_per_block;
// scales and biases cannot be advanced because they would have to be
// advanced every other iteration or sth.
}
// Store results to device memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (num_els < BM) {
mma_op.store_result_safe(y, N, short2(BN, num_els));
} else {
mma_op.store_result(y, N);
}
}
#define instantiate_qmv(name, itype, groups, width) \
template [[host_name("qmv_n_" #name "_groups_" #groups "_width_" #width)]] \
[[kernel]] void qmv<itype, 32, 32, groups, width>( \
const device uint32_t* w [[buffer(0)]], \
const device itype* scales [[buffer(1)]], \
const device itype* biases [[buffer(2)]], \
const device itype* x [[buffer(3)]], \
device itype* y [[buffer(4)]], \
const constant int& in_vec_size [[buffer(5)]], \
const constant int& out_vec_size [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint lid [[thread_index_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_qmv_types(groups, width) \
instantiate_qmv(float32, float, groups, width) \
instantiate_qmv(float16, half, groups, width) \
instantiate_qmv(bfloat16, bfloat16_t, groups, width)
instantiate_qmv_types(128, 2)
instantiate_qmv_types(128, 4)
instantiate_qmv_types(128, 8)
instantiate_qmv_types( 64, 2)
instantiate_qmv_types( 64, 4)
instantiate_qmv_types( 64, 8)
#define instantiate_qmm_t(name, itype, groups, width) \
template [[host_name("qmm_t_" #name "_groups_" #groups "_width_" #width)]] \
[[kernel]] void qmm_t<itype, 32, 64, 32, groups, width>( \
const device itype* x [[buffer(0)]], \
const device uint32_t* w [[buffer(1)]], \
const device itype* scales [[buffer(2)]], \
const device itype* biases [[buffer(3)]], \
device itype* y [[buffer(4)]], \
const constant int& M [[buffer(5)]], \
const constant int& N [[buffer(6)]], \
const constant int& K [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint lid [[thread_index_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_qmm_t_types(groups, width) \
instantiate_qmm_t(float32, float, groups, width) \
instantiate_qmm_t(float16, half, groups, width) \
instantiate_qmm_t(bfloat16, bfloat16_t, groups, width)
instantiate_qmm_t_types(128, 2)
instantiate_qmm_t_types(128, 4)
instantiate_qmm_t_types(128, 8)
instantiate_qmm_t_types( 64, 2)
instantiate_qmm_t_types( 64, 4)
instantiate_qmm_t_types( 64, 8)

View File

@@ -0,0 +1,123 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <iostream>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
std::vector<array> copies;
auto check_transpose = [&copies, &s](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
auto [x_transposed, x_cols, x] = check_transpose(x_pre);
auto [w_transposed, w_cols, w] = check_transpose(w_pre);
auto [scales_transposed, scales_cols, scales] = check_transpose(scales_pre);
auto [biases_transposed, biases_cols, biases] = check_transpose(biases_pre);
if (!w_transposed) {
throw std::runtime_error("The quantized weight should be transposed.");
}
if (x_transposed || scales_transposed || biases_transposed) {
throw std::runtime_error("x, scales and biases should be row contiguous.");
}
int D = x.shape(-1);
int B = x.size() / D;
// Route to the qmv kernel
if (B == 1) {
std::ostringstream kname;
kname << "qmv_" << (w_transposed ? "n_" : "t_") << type_to_name(out)
<< "_groups_" << groups_ << "_width_" << width_;
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int O = w.size() / w_cols;
int bo = 32;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, bo, 1);
MTL::Size grid_dims = MTL::Size(1, O / bo, B);
set_array_buffer(compute_encoder, w, 0);
set_array_buffer(compute_encoder, scales, 1);
set_array_buffer(compute_encoder, biases, 2);
set_array_buffer(compute_encoder, x, 3);
set_array_buffer(compute_encoder, out, 4);
compute_encoder->setBytes(&D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
// Route to the qmm kernel
else {
std::ostringstream kname;
kname << "qmm_" << (w_transposed ? "t_" : "n_") << type_to_name(out)
<< "_groups_" << groups_ << "_width_" << width_;
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int O = w.size() / w_cols;
int wn = 2;
int wm = 2;
int bm = 32;
int bn = 32;
int bk = 64;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, w, 1);
set_array_buffer(compute_encoder, scales, 2);
set_array_buffer(compute_encoder, biases, 3);
set_array_buffer(compute_encoder, out, 4);
compute_encoder->setBytes(&B, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->setBytes(&D, sizeof(int), 7);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
} // namespace mlx::core