mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-25 01:41:17 +08:00)

Improved quantized matrix vector product (#786)

This commit is contained in:
parent cbcf44a4ca
commit 14b4e51a7c
@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include <metal_stdlib>
#include <metal_simdgroup>
@@ -23,7 +23,143 @@ template <> struct AccT<bfloat16_t> {
  typedef float acc_t;
};

template <typename T, const int BM, const int BN, const int group_size, const int bits>

template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector(const device T *x, thread U *x_thread) {
  static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");

  U sum = 0;

  if (bits == 2) {
    for (int i = 0; i < values_per_thread; i += 4) {
      sum += x[i] + x[i+1] + x[i+2] + x[i+3];
      x_thread[i] = x[i];
      x_thread[i+1] = x[i+1] / 4.0f;
      x_thread[i+2] = x[i+2] / 16.0f;
      x_thread[i+3] = x[i+3] / 64.0f;
    }
  }

  else if (bits == 4) {
    for (int i = 0; i < values_per_thread; i += 4) {
      sum += x[i] + x[i+1] + x[i+2] + x[i+3];
      x_thread[i] = x[i];
      x_thread[i+1] = x[i+1] / 16.0f;
      x_thread[i+2] = x[i+2] / 256.0f;
      x_thread[i+3] = x[i+3] / 4096.0f;
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      sum += x[i];
      x_thread[i] = x[i];
    }
  }

  return sum;
}

template <typename U, int values_per_thread, int bits>
inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias, U sum) {
  static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");

  U accum = 0;

  if (bits == 2) {
    for (int i = 0; i < (values_per_thread / 4); i++) {
      accum += (
          x_thread[4*i] * (w[i] & 0x03)
          + x_thread[4*i+1] * (w[i] & 0x0c)
          + x_thread[4*i+2] * (w[i] & 0x30)
          + x_thread[4*i+3] * (w[i] & 0xc0));
    }
  }

  else if (bits == 4) {
    const device uint16_t* ws = (const device uint16_t*)w;
    for (int i = 0; i < (values_per_thread / 4); i++) {
      accum += (
          x_thread[4*i] * (ws[i] & 0x000f)
          + x_thread[4*i+1] * (ws[i] & 0x00f0)
          + x_thread[4*i+2] * (ws[i] & 0x0f00)
          + x_thread[4*i+3] * (ws[i] & 0xf000));
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      accum += x_thread[i] * w[i];
    }
  }

  return scale * accum + sum * bias;
}

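The two helpers above implement a rearrangement of the affine dequantized dot product: within one quantization group, sum_i x_i * (scale * q_i + bias) = scale * sum_i x_i * q_i + bias * sum_i x_i, so load_vector returns the running sum of x while qdot folds scale and bias in exactly once. Pre-dividing x by 4^k (2-bit) or 16^k (4-bit) in load_vector is what lets qdot use a plain mask instead of a shift, e.g. (x / 16) * (w & 0x00f0) == x * ((w >> 4) & 0xf). The standalone C++ sketch below is my own check of the 4-bit case against a naive dequantize-then-dot reference; it is not part of the commit and assumes a little-endian host so the uint16_t view matches the Metal code.

// Host-side check of the 4-bit mask/pre-divide identity used by load_vector + qdot.
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  // Eight 4-bit quantized values packed low-nibble first into one uint32_t.
  uint8_t q[8] = {3, 15, 0, 7, 9, 1, 12, 4};
  uint32_t w = 0;
  for (int i = 0; i < 8; i++) {
    w |= uint32_t(q[i]) << (4 * i);
  }
  float x[8] = {0.5f, -1.0f, 2.0f, 0.25f, -0.75f, 1.5f, -2.0f, 0.125f};
  float scale = 0.02f, bias = -0.1f;

  // Reference: dequantize each value, then take the dot product.
  float ref = 0.0f;
  for (int i = 0; i < 8; i++) {
    ref += x[i] * (scale * q[i] + bias);
  }

  // Kernel-style: pre-scale x and accumulate sum(x), then mask 16-bit halves
  // of w without shifting and fold in scale and bias once at the end.
  float x_thread[8];
  float sum = 0.0f;
  for (int i = 0; i < 8; i += 4) {
    sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
    x_thread[i] = x[i];
    x_thread[i + 1] = x[i + 1] / 16.0f;
    x_thread[i + 2] = x[i + 2] / 256.0f;
    x_thread[i + 3] = x[i + 3] / 4096.0f;
  }
  uint16_t ws[2];
  std::memcpy(ws, &w, sizeof(w));  // little-endian host assumed
  float accum = 0.0f;
  for (int i = 0; i < 2; i++) {  // values_per_thread / 4 = 2
    accum += x_thread[4 * i] * (ws[i] & 0x000f) +
             x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
             x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
             x_thread[4 * i + 3] * (ws[i] & 0xf000);
  }
  float out = scale * accum + sum * bias;

  std::printf("reference = %f, rearranged = %f\n", ref, out);
  assert(std::fabs(ref - out) < 1e-4f);
  return 0;
}

The qmv_fast kernel below simply strings these two helpers together across each block of the input vector.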
template <typename T, int group_size, int bits, int packs_per_thread>
[[kernel]] void qmv_fast(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {

  constexpr int num_simdgroups = 2;
  constexpr int results_per_simdgroup = 4;
  constexpr int pack_factor = 32 / bits;
  constexpr int values_per_thread = pack_factor * packs_per_thread;
  constexpr int block_size = values_per_thread * SIMD_SIZE;
  constexpr int scale_step_per_thread = group_size / values_per_thread;

  typedef float U;

  thread U x_thread[values_per_thread];
  thread U result[results_per_simdgroup] = {0};

  // Adjust positions
  const int in_vec_size_w = in_vec_size / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup;
  w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
  scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
  biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
  x += tid.z * in_vec_size + simd_lid * values_per_thread;
  y += tid.z * out_vec_size + out_row;

  for (int k = 0; k < in_vec_size; k += block_size) {
    U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

    for (int row = 0; row < results_per_simdgroup; row++) {
      const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
      const device T* sl = scales + row * in_vec_size_g;
      const device T* bl = biases + row * in_vec_size_g;

      U s = sl[0];
      U b = bl[0];
      result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
    }

    w += block_size / pack_factor;
    scales += block_size / group_size;
    biases += block_size / group_size;
    x += block_size;
  }

  for (int row = 0; row < results_per_simdgroup; row++) {
    result[row] = simd_sum(result[row]);
    if (simd_lid == 0) {
      y[row] = static_cast<T>(result[row]);
    }
  }
}

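For a concrete feel of the constants in qmv_fast, take bits = 4, group_size = 64 and packs_per_thread = 2 with a 32-wide simdgroup: each thread unpacks two uint32 words into 16 values, a simdgroup sweeps 512 input elements per iteration, and the scale/bias pointers advance once per 4 lanes. The small sketch below is my own illustration, not part of the commit, and assumes SIMD_SIZE = 32 as on Apple GPUs.

// Per-thread sizing arithmetic behind qmv_fast, reproduced on the host.
#include <cstdio>

int main() {
  constexpr int SIMD_SIZE = 32;
  constexpr int bits = 4;
  constexpr int group_size = 64;
  constexpr int packs_per_thread = 2;

  constexpr int pack_factor = 32 / bits;                             // 8 values per uint32_t
  constexpr int values_per_thread = pack_factor * packs_per_thread;  // 16
  constexpr int block_size = values_per_thread * SIMD_SIZE;          // 512
  constexpr int scale_step_per_thread = group_size / values_per_thread;  // 4

  static_assert(block_size % group_size == 0,
                "each block covers a whole number of quantization groups");

  std::printf("pack_factor=%d values_per_thread=%d block_size=%d scale_step=%d\n",
              pack_factor, values_per_thread, block_size, scale_step_per_thread);
  // A threadgroup of (32, 2, 1) threads holds 2 simdgroups; each simdgroup
  // produces 4 output rows, so every threadgroup covers 8 rows of y. That is
  // where the O % 8 == 0 requirement in the host dispatch further down comes from.
  return 0;
}

The qmv kernel below keeps the same inner loop but adds the bounds handling that the fast path is allowed to skip.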
template <typename T, const int group_size, const int bits>
[[kernel]] void qmv(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
@@ -33,91 +169,101 @@ template <typename T, const int BM, const int BN, const int group_size, const in
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {

  static_assert(BN == SIMD_SIZE, "qmv expects BN to be equal to SIMD_SIZE");
  constexpr int num_simdgroups = 2;
  constexpr int results_per_simdgroup = 4;
  constexpr int packs_per_thread = 1;
  constexpr int pack_factor = 32 / bits;
  constexpr int values_per_thread = pack_factor * packs_per_thread;
  constexpr int block_size = values_per_thread * SIMD_SIZE;
  constexpr int scale_step_per_thread = group_size / values_per_thread;

  (void)lid;
  typedef float U;

  constexpr int bitmask = (1 << bits) - 1;
  constexpr int el_per_thread = 32 / bits;
  constexpr int colgroup = BN * el_per_thread;
  constexpr int groups_per_block = colgroup / group_size;

  typedef typename AccT<T>::acc_t U;
  threadgroup U scales_block[BM * groups_per_block];
  threadgroup U biases_block[BM * groups_per_block];
  threadgroup U x_block[colgroup];

  thread uint32_t w_local;
  thread U result = 0;
  thread U scale = 1;
  thread U bias = 0;
  thread U x_thread[el_per_thread];
  thread U x_thread[values_per_thread];
  thread U result[results_per_simdgroup] = {0};

  // Adjust positions
  const int in_vec_size_w = in_vec_size / el_per_thread;
  const int in_vec_size_w = in_vec_size / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
  int out_row = tid.y * BM + simd_gid;
  w += out_row * in_vec_size_w;
  scales += out_row * in_vec_size_g;
  biases += out_row * in_vec_size_g;
  x += tid.z * in_vec_size;
  y += tid.z * out_vec_size;
  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup;
  const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);

  if (out_row >= out_vec_size) {
    return;
  }

  // Loop over in_vec in blocks of colgroup
  for (int i=0; i<in_vec_size; i+=colgroup) {
    // Load the vec to shared memory
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (simd_gid == 0) {
      #pragma clang loop unroll(full)
      for (int j=0; j<el_per_thread; j++) {
        x_block[simd_lid * el_per_thread + j] = x[i + simd_lid * el_per_thread + j];
      }
    }
    if (simd_lid == 0) {
      #pragma clang loop unroll(full)
      for (int j=0; j<groups_per_block; j++) {
        scales_block[simd_gid * groups_per_block + j] = scales[i / group_size + j];
      }
      #pragma clang loop unroll(full)
      for (int j=0; j<groups_per_block; j++) {
        biases_block[simd_gid * groups_per_block + j] = biases[i / group_size + j];
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
  // In this case we need to properly guard all our reads because there isn't
  // even 1 tile in the matrix
  if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
    w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
    scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    x += tid.z * in_vec_size + simd_lid * values_per_thread;
    y += tid.z * out_vec_size + out_row;

    // Load in_vec, scale, bias to registers
    #pragma clang loop unroll(full)
    for (int j=0; j<el_per_thread; j++) {
      x_thread[j] = x_block[simd_lid*el_per_thread + j];
    for (int k = 0; k < in_vec_size; k += block_size) {
      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

      for (int row = 0; out_row + row < out_vec_size; row++) {
        const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
        const device T* sl = scales + row * in_vec_size_g;
        const device T* bl = biases + row * in_vec_size_g;

        U s = sl[0];
        U b = bl[0];
        result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
      }

      w += block_size / pack_factor;
      scales += block_size / group_size;
      biases += block_size / group_size;
      x += block_size;
    }
    scale = scales_block[simd_gid * groups_per_block + simd_lid * el_per_thread / group_size];
    bias = biases_block[simd_gid * groups_per_block + simd_lid * el_per_thread / group_size];

    // Load the matrix elements
    w_local = w[i / el_per_thread + simd_lid];

    // Do all the work.
    #pragma clang loop unroll(full)
    for (int k=0; k<el_per_thread; k++) {
      result += (scale * static_cast<U>(w_local & bitmask) + bias) * x_thread[k];
      w_local >>= bits;
    for (int row = 0; out_row + row < out_vec_size; row++) {
      result[row] = simd_sum(result[row]);
      if (simd_lid == 0) {
        y[row] = static_cast<T>(result[row]);
      }
    }
  }

  // Accumulate in the simdgroup
  result = simd_sum(result);
  // In this case the last tile is moved back to redo some output values
  else {
    w += used_out_row * in_vec_size_w + simd_lid * packs_per_thread;
    scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    x += tid.z * in_vec_size + simd_lid * values_per_thread;
    y += tid.z * out_vec_size + used_out_row;

  // Store the result
  if (simd_lid == 0) {
    y[out_row] = static_cast<T>(result);
    for (int k = 0; k < in_vec_size; k += block_size) {
      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

      for (int row = 0; row < results_per_simdgroup; row++) {
        const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
        const device T* sl = scales + row * in_vec_size_g;
        const device T* bl = biases + row * in_vec_size_g;

        U s = sl[0];
        U b = bl[0];
        result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
      }

      w += block_size / pack_factor;
      scales += block_size / group_size;
      biases += block_size / group_size;
      x += block_size;
    }

    for (int row = 0; row < results_per_simdgroup; row++) {
      result[row] = simd_sum(result[row]);
      if (simd_lid == 0) {
        y[row] = static_cast<T>(result[row]);
      }
    }
  }
}

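The replacement qmv handles ragged output sizes without per-element bounds checks in the common case: when there are fewer than num_simdgroups * results_per_simdgroup output rows in total, every row access is guarded, and otherwise the last tile is shifted back via used_out_row so a few rows are recomputed rather than read or written out of bounds. Below is a small host-side sketch of that index arithmetic; the values are my own example and it is not part of the commit.

// How used_out_row moves the tail tile back instead of writing past y.
#include <algorithm>
#include <cstdio>

int main() {
  constexpr int results_per_simdgroup = 4;
  const int out_vec_size = 10;  // not a multiple of the 8 rows per threadgroup
  for (int out_row = 0; out_row < 16; out_row += results_per_simdgroup) {
    if (out_row >= out_vec_size) {
      std::printf("out_row %2d: simdgroup returns early\n", out_row);
      continue;
    }
    int used_out_row = std::min(out_vec_size - results_per_simdgroup, out_row);
    std::printf("out_row %2d -> used_out_row %2d (writes rows %d..%d)\n",
                out_row, used_out_row, used_out_row,
                used_out_row + results_per_simdgroup - 1);
  }
  return 0;
}

With out_vec_size = 10 the simdgroup starting at row 8 is pulled back to rows 6..9, so rows 6 and 7 are written twice with the same values, which is harmless. The host_name instantiation macros below expose both kernels to the dispatch code.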
@@ -532,9 +678,38 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
}

#define instantiate_qmv(name, itype, group_size, bits) \
template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits)]] \
[[kernel]] void qmv<itype, 32, 32, group_size, bits>( \
#define instantiate_qmv_fast(name, itype, group_size, bits, packs_per_thread) \
template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits "_fast")]] \
[[kernel]] void qmv_fast<itype, group_size, bits, packs_per_thread>( \
    const device uint32_t* w [[buffer(0)]], \
    const device itype* scales [[buffer(1)]], \
    const device itype* biases [[buffer(2)]], \
    const device itype* x [[buffer(3)]], \
    device itype* y [[buffer(4)]], \
    const constant int& in_vec_size [[buffer(5)]], \
    const constant int& out_vec_size [[buffer(6)]], \
    uint3 tid [[threadgroup_position_in_grid]], \
    uint simd_gid [[simdgroup_index_in_threadgroup]], \
    uint simd_lid [[thread_index_in_simdgroup]]);

#define instantiate_qmv_fast_types(group_size, bits, packs_per_thread) \
  instantiate_qmv_fast(float32, float, group_size, bits, packs_per_thread) \
  instantiate_qmv_fast(float16, half, group_size, bits, packs_per_thread) \
  instantiate_qmv_fast(bfloat16, bfloat16_t, group_size, bits, packs_per_thread)

instantiate_qmv_fast_types(128, 2, 1)
instantiate_qmv_fast_types(128, 4, 2)
instantiate_qmv_fast_types(128, 8, 2)
instantiate_qmv_fast_types( 64, 2, 1)
instantiate_qmv_fast_types( 64, 4, 2)
instantiate_qmv_fast_types( 64, 8, 2)
instantiate_qmv_fast_types( 32, 2, 1)
instantiate_qmv_fast_types( 32, 4, 2)
instantiate_qmv_fast_types( 32, 8, 2)

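Each instantiate_qmv_fast_types line above expands to three kernels, one per input type, whose host-visible names follow the qmv_<type>_gs_<group_size>_b_<bits>_fast pattern that the dispatch code constructs. The throwaway C++ loop below is my own expansion of those names, not generated output from the macros.

// Enumerate the host_name strings the fast instantiations above produce.
#include <cstdio>

int main() {
  const char* types[] = {"float32", "float16", "bfloat16"};
  const int configs[][3] = {  // {group_size, bits, packs_per_thread}
      {128, 2, 1}, {128, 4, 2}, {128, 8, 2},
      { 64, 2, 1}, { 64, 4, 2}, { 64, 8, 2},
      { 32, 2, 1}, { 32, 4, 2}, { 32, 8, 2},
  };
  for (const auto& c : configs) {
    for (const char* t : types) {
      // packs_per_thread only changes the template arguments, not the name.
      std::printf("qmv_%s_gs_%d_b_%d_fast\n", t, c[0], c[1]);
    }
  }
  return 0;
}

Note that packs_per_thread tunes the template rather than the name: the 2-bit kernels read one packed word per thread per iteration, the 4- and 8-bit kernels read two. The non-fast qmv instantiation follows.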
#define instantiate_qmv(name, itype, group_size, bits) \
template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits)]] \
[[kernel]] void qmv<itype, group_size, bits>( \
    const device uint32_t* w [[buffer(0)]], \
    const device itype* scales [[buffer(1)]], \
    const device itype* biases [[buffer(2)]], \
@@ -543,7 +718,6 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
    const constant int& in_vec_size [[buffer(5)]], \
    const constant int& out_vec_size [[buffer(6)]], \
    uint3 tid [[threadgroup_position_in_grid]], \
    uint lid [[thread_index_in_threadgroup]], \
    uint simd_gid [[simdgroup_index_in_threadgroup]], \
    uint simd_lid [[thread_index_in_simdgroup]]);

@@ -41,8 +41,35 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  int B = x.size() / D;
  int O = out.shape(-1);
  if (transpose_) {
    // Route to the fast qmv kernel that has no bounds checking
    if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
      std::ostringstream kname;
      kname << "qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
            << bits_ << "_fast";

      // Encode and dispatch kernel
      auto compute_encoder = d.get_command_encoder(s.index);
      auto kernel = d.get_kernel(kname.str());
      compute_encoder->setComputePipelineState(kernel);

      int bo = 8;
      int bd = 32;
      MTL::Size group_dims = MTL::Size(bd, 2, 1);
      MTL::Size grid_dims = MTL::Size(1, O / bo, B);

      set_array_buffer(compute_encoder, w, 0);
      set_array_buffer(compute_encoder, scales, 1);
      set_array_buffer(compute_encoder, biases, 2);
      set_array_buffer(compute_encoder, x, 3);
      set_array_buffer(compute_encoder, out, 4);
      compute_encoder->setBytes(&D, sizeof(int), 5);
      compute_encoder->setBytes(&O, sizeof(int), 6);

      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
    }

    // Route to the qmv kernel
    if (B < 6) {
    else if (B < 6) {
      std::ostringstream kname;
      kname << "qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
            << bits_;
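On the host side, the new fast path is taken only when the shapes let qmv_fast skip bounds checks entirely: a small batch (B < 6), an output row count divisible by the 8 rows each threadgroup produces, and an input length that is a non-zero multiple of 512, which safely covers the per-iteration block of every instantiated configuration. The sketch below restates that routing in plain C++; pick_qmv_kernel is a hypothetical helper for illustration, not an MLX function.

// Mirror of the kernel-name routing added in this hunk (illustrative only).
#include <cstdio>
#include <sstream>
#include <string>

std::string pick_qmv_kernel(int B, int O, int D, int group_size, int bits,
                            const std::string& type_name) {
  std::ostringstream kname;
  if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
    // Friendly shapes: the bounds-check-free fast kernel.
    kname << "qmv_" << type_name << "_gs_" << group_size << "_b_" << bits
          << "_fast";
  } else if (B < 6) {
    // Small batch but ragged shapes: the guarded qmv kernel.
    kname << "qmv_" << type_name << "_gs_" << group_size << "_b_" << bits;
  }
  // Larger batches take the matrix-matrix path, which this hunk does not touch.
  return kname.str();
}

int main() {
  // e.g. a single-token decode step through a 4096 x 4096 4-bit layer.
  std::printf("%s\n", pick_qmv_kernel(1, 4096, 4096, 64, 4, "float16").c_str());
  // prints: qmv_float16_gs_64_b_4_fast
  return 0;
}

The next hunk updates the launch geometry of the guarded qmv path to match the new kernel.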
@@ -52,9 +79,9 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
      auto kernel = d.get_kernel(kname.str());
      compute_encoder->setComputePipelineState(kernel);

      int bo = std::min(32, O);
      int bo = 8;
      int bd = 32;
      MTL::Size group_dims = MTL::Size(bd, bo, 1);
      MTL::Size group_dims = MTL::Size(bd, 2, 1);
      MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);

      set_array_buffer(compute_encoder, w, 0);
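The old launch used up to 32 simdgroups per threadgroup with one output row each; the new one fixes two simdgroups per threadgroup at four rows each, so bo drops from std::min(32, O) to 8 and the grid rounds O up to a multiple of 8. The last threadgroup can therefore start past the end of y, and the used_out_row logic shown earlier is what makes that safe. A quick illustration with my own example values:

// Launch-coverage arithmetic for the guarded qmv path (illustrative only).
#include <cstdio>

int main() {
  int O = 1001;                      // output rows, not a multiple of 8
  int bo = 8;                        // rows produced per threadgroup
  int groups_y = (O + bo - 1) / bo;  // grid_dims.y above
  std::printf("threadgroups along y: %d, rows nominally covered: %d (O = %d)\n",
              groups_y, groups_y * bo, O);
  // The last threadgroup starts at row 1000. Its first simdgroup is shifted
  // back to rows 997..1000 by used_out_row = min(O - 4, 1000) = 997, and its
  // second simdgroup (rows 1004..1007) returns early, so nothing is written
  // past row 1000.
  return 0;
}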
@@ -3032,6 +3032,15 @@ array quantized_matmul(
  }

  auto dtype = result_type({x, scales, biases});
  if (!is_floating_point(dtype) || is_complex(dtype)) {
    std::ostringstream msg;
    msg << "[quantized_matmul] Only real floating types are supported but "
        << "the passed types were x.dtype() == " << x.dtype()
        << ", scales.dtype() == " << scales.dtype()
        << " and biases.dtype() == " << biases.dtype();
    throw std::invalid_argument(msg.str());
  }

  auto out = array(
      {x.shape(0), w_outer_dims},
      dtype,