Allow arbitrary first dimension in quantization kernels. (#458)

* Allow arbitrary first dim on qmm_t and qmv
* Allow arbitrary first dim on qmm and qvm
* Specialize the aligned vs unaligned case
* Add more checks for valid quantizations
Author: Angelos Katharopoulos
Date: 2024-01-16 00:46:21 -08:00
Committed by: GitHub
Parent: f44c132f4a
Commit: c15fe3e61b
6 changed files with 206 additions and 66 deletions
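
At a high level, the kernel changes below all follow one pattern: when a tile extends past the edge of the matrix, out-of-range threads write zeros into the threadgroup tile (so the multiply-accumulate is unaffected), and the partially filled output tile is written back with the bounds-checked store. A minimal sketch of the load guard in plain C++, with illustrative names only (not MLX code):

// Illustrative only: mirrors the guard the kernels below apply when loading
// x or w into threadgroup memory near the edge of the matrix. Out-of-range
// lanes contribute a zero, so the subsequent multiply-accumulate ignores them.
template <typename T>
inline T load_or_zero(const T* buf, int offset, int size) {
  return (offset < size) ? buf[offset] : T(0);
}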

@@ -154,10 +154,13 @@ template <typename T, const int BM, const int BN, const int group_size, const in
// Loop over in_vec in blocks of colgroup
for (int i=0; i<in_vec_size; i+=BM) {
+ int offset = simd_lid + i;
+ bool thread_in_bounds = offset < in_vec_size;
// Load the vec to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_gid == 0) {
- x_block[simd_lid] = x[simd_lid + i];
+ x_block[simd_lid] = (thread_in_bounds) ? x[offset] : 0;
}
// Load the scales and biases to shared memory
@@ -180,7 +183,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
bias = biases_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];
// Load the matrix elements
- w_local = w[(i + simd_lid) * out_vec_size_w];
+ w_local = (thread_in_bounds) ? w[offset * out_vec_size_w] : 0;
// Do all the work.
#pragma clang loop unroll(full)
@@ -206,7 +209,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
}
- template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
+ template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits, const bool aligned_N>
[[kernel]] void qmm_t(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
@@ -257,6 +260,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
// Make the x loader and mma operation
const short num_els = min(BM, M - y_row);
+ const short num_outs = min(BN, N - y_col);
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
mma_t mma_op(simd_gid, simd_lid);
@@ -292,21 +296,48 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
// Load the w tile
{
- for (int wo=0; wo<w_els_per_thread; wo++) {
- int offset = lid * w_els_per_thread + wo;
- int offset_row = offset / (BK / el_per_int);
- int offset_col = offset % (BK / el_per_int);
- const device uint32_t * w_local = w + offset_row * K_w + offset_col;
- threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
+ if (!aligned_N && num_outs < BN) {
+ for (int wo=0; wo<w_els_per_thread; wo++) {
+ int offset = lid * w_els_per_thread + wo;
+ int offset_row = offset / (BK / el_per_int);
+ int offset_col = offset % (BK / el_per_int);
+ const device uint32_t * w_local = w + offset_row * K_w + offset_col;
+ threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
- uint32_t wi = *w_local;
- T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
- T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+ if (y_col + offset_col < N) {
+ uint32_t wi = *w_local;
+ T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+ T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
- #pragma clang loop unroll(full)
- for (int t=0; t<el_per_int; t++) {
- Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
- wi >>= bits;
+ #pragma clang loop unroll(full)
+ for (int t=0; t<el_per_int; t++) {
+ Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
+ wi >>= bits;
+ }
+ } else {
+ #pragma clang loop unroll(full)
+ for (int t=0; t<el_per_int; t++) {
+ Ws_local[t] = 0;
+ }
+ }
+ }
+ } else {
+ for (int wo=0; wo<w_els_per_thread; wo++) {
+ int offset = lid * w_els_per_thread + wo;
+ int offset_row = offset / (BK / el_per_int);
+ int offset_col = offset % (BK / el_per_int);
+ const device uint32_t * w_local = w + offset_row * K_w + offset_col;
+ threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
+ uint32_t wi = *w_local;
+ T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+ T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+ #pragma clang loop unroll(full)
+ for (int t=0; t<el_per_int; t++) {
+ Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
+ wi >>= bits;
+ }
+ }
+ }
}
@@ -324,8 +355,8 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
// Store results to device memory
threadgroup_barrier(mem_flags::mem_threadgroup);
- if (num_els < BM) {
- mma_op.store_result_safe(y, N, short2(BN, num_els));
+ if (num_els < BM || num_outs < BN) {
+ mma_op.store_result_safe(y, N, short2(num_outs, num_els));
} else {
mma_op.store_result(y, N);
}
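
store_result_safe is the pre-existing bounds-checked store; the change above only passes the true valid extent of the tile (num_outs columns by num_els rows) instead of assuming a full BN-wide tile. Conceptually it behaves like the following scalar sketch (illustrative only, not the simdgroup implementation):

// Illustrative scalar picture of a bounds-checked tile store: y points at the
// tile's top-left element of a row-major output with leading dimension ldy,
// and (num_outs, num_els) is the valid (columns, rows) extent of this tile.
template <typename T, int BM, int BN>
void store_tile_safe(T* y, int ldy, const T (&tile)[BM][BN],
                     int num_outs, int num_els) {
  for (int r = 0; r < num_els; r++) {
    for (int c = 0; c < num_outs; c++) {
      y[r * ldy + c] = tile[r][c];  // rows/columns past the edge are skipped
    }
  }
}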
@@ -417,21 +448,48 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
// Load the w tile
{
- for (int wo=0; wo<w_els_per_thread; wo++) {
- int offset = lid * w_els_per_thread + wo;
- int offset_row = offset / (BN / el_per_int);
- int offset_col = offset % (BN / el_per_int);
- const device uint32_t * w_local = w + offset_row * N_w + offset_col;
- threadgroup T * Ws_local = Ws + offset_row * BN + offset_col * el_per_int;
+ if (k + BK >= K) {
+ for (int wo=0; wo<w_els_per_thread; wo++) {
+ int offset = lid * w_els_per_thread + wo;
+ int offset_row = offset / (BN / el_per_int);
+ int offset_col = offset % (BN / el_per_int);
+ const device uint32_t * w_local = w + offset_row * N_w + offset_col;
+ threadgroup T * Ws_local = Ws + offset_row * BN + offset_col * el_per_int;
- uint32_t wi = *w_local;
- T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
- T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+ if (y_row + offset_row < K) {
+ uint32_t wi = *w_local;
+ T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+ T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
- #pragma clang loop unroll(full)
- for (int t=0; t<el_per_int; t++) {
- Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
- wi >>= bits;
+ #pragma clang loop unroll(full)
+ for (int t=0; t<el_per_int; t++) {
+ Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
+ wi >>= bits;
+ }
+ } else {
+ #pragma clang loop unroll(full)
+ for (int t=0; t<el_per_int; t++) {
+ Ws_local[t] = 0;
+ }
+ }
+ }
+ } else {
+ for (int wo=0; wo<w_els_per_thread; wo++) {
+ int offset = lid * w_els_per_thread + wo;
+ int offset_row = offset / (BN / el_per_int);
+ int offset_col = offset % (BN / el_per_int);
+ const device uint32_t * w_local = w + offset_row * N_w + offset_col;
+ threadgroup T * Ws_local = Ws + offset_row * BN + offset_col * el_per_int;
+ uint32_t wi = *w_local;
+ T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+ T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+ #pragma clang loop unroll(full)
+ for (int t=0; t<el_per_int; t++) {
+ Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
+ wi >>= bits;
+ }
+ }
+ }
}
@@ -511,9 +569,9 @@ instantiate_qvm_types( 64, 2)
instantiate_qvm_types( 64, 4)
instantiate_qvm_types( 64, 8)
- #define instantiate_qmm_t(name, itype, group_size, bits) \
- template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits)]] \
- [[kernel]] void qmm_t<itype, 32, 64, 32, group_size, bits>( \
+ #define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \
+ template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N)]] \
+ [[kernel]] void qmm_t<itype, 32, 64, 32, group_size, bits, aligned_N>( \
const device itype* x [[buffer(0)]], \
const device uint32_t* w [[buffer(1)]], \
const device itype* scales [[buffer(2)]], \
@@ -528,9 +586,12 @@ instantiate_qvm_types( 64, 8)
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_qmm_t_types(group_size, bits) \
- instantiate_qmm_t(float32, float, group_size, bits) \
- instantiate_qmm_t(float16, half, group_size, bits) \
- instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits)
+ instantiate_qmm_t(float32, float, group_size, bits, false) \
+ instantiate_qmm_t(float16, half, group_size, bits, false) \
+ instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, false) \
+ instantiate_qmm_t(float32, float, group_size, bits, true) \
+ instantiate_qmm_t(float16, half, group_size, bits, true) \
+ instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, true)
instantiate_qmm_t_types(128, 2)
instantiate_qmm_t_types(128, 4)
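
For reference, the macro above now stamps out one pipeline per (type, group size, bits, aligned_N) combination with host names such as qmm_t_float16_gs_64_b_4_alN_true, and the host-side dispatch in the next file rebuilds exactly that string. A hedged sketch of the name construction, using an illustrative free function rather than the MLX dispatch code:

#include <sstream>
#include <string>

// Illustrative helper (not MLX API): builds the same style of name that the
// instantiation macro above produces and that QuantizedMatmul::eval_gpu below
// reconstructs. O is the output dimension; BN is 32 in the instantiations
// above, so "aligned" means O % 32 == 0.
std::string qmm_t_kernel_name(const std::string& type_name,
                              int group_size, int bits, int O) {
  std::ostringstream kname;
  kname << "qmm_t_" << type_name << "_gs_" << group_size << "_b_" << bits
        << "_alN_" << std::boolalpha << ((O % 32) == 0);
  return kname.str();  // e.g. "qmm_t_float16_gs_64_b_4_alN_true"
}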

@@ -52,7 +52,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
- int bo = 32;
+ int bo = std::min(32, O);
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, bo, 1);
MTL::Size grid_dims = MTL::Size(1, O / bo, B);
@@ -72,7 +72,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
else {
std::ostringstream kname;
kname << "qmm_t_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
- << bits_;
+ << bits_ << "_alN_" << std::boolalpha << ((O % 32) == 0);
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
@@ -85,7 +85,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
int bn = 32;
int bk = 64;
MTL::Size group_dims = MTL::Size(32, wn, wm);
- MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);
+ MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, 1);
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, w, 1);
@@ -110,10 +110,10 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
- int bo = 32;
+ int bo = std::min(32, O);
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, bo, 1);
- MTL::Size grid_dims = MTL::Size(1, (w.shape(1) + bo - 1) / bo, B);
+ MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, w, 1);
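
The dispatch changes above are plain ceiling division plus a clamp: a dimension that is not a multiple of the block size gets one extra, partially filled block (which the kernels above zero-fill and bounds-check), and the threadgroup height is clamped so it never exceeds O. A small sketch of the arithmetic (ceil_div is an illustrative helper, not MLX API):

// Illustrative helper: number of blocks of size b needed to cover n elements.
inline int ceil_div(int n, int b) { return (n + b - 1) / b; }

// Example: qmm_t with O = 100 output columns and bn = 32 launches
// ceil_div(100, 32) == 4 threadgroups along that axis; the last one covers
// only 100 - 3 * 32 = 4 columns, which is exactly the num_outs < BN case the
// unaligned kernel variant handles. Likewise, bo = std::min(32, O) keeps the
// qmv/qvm threadgroup height within O when O < 32.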

@@ -2801,16 +2801,25 @@ std::tuple<array, array, array> quantize(
int group_size /* = 64 */,
int bits /* = 4 */,
StreamOrDevice s /* = {} */) {
+ if (group_size != 64 && group_size != 128) {
+ std::ostringstream msg;
+ msg << "[quantize] The requested group size " << group_size
+ << " is not supported. The supported group sizes are 64 and 128.";
+ throw std::invalid_argument(msg.str());
+ }
+ if (bits != 2 && bits != 4 && bits != 8) {
+ std::ostringstream msg;
+ msg << "[quantize] The requested number of bits " << bits
+ << " is not supported. The supported bits are 2, 4 and 8.";
+ throw std::invalid_argument(msg.str());
+ }
if (w.ndim() != 2) {
throw std::invalid_argument("[quantize] Only matrices supported for now");
}
- if ((w.shape(0) % 32) != 0) {
- throw std::invalid_argument(
- "[quantize] All dimensions should be divisible by 32 for now");
- }
- if ((w.shape(-1) % group_size) != 0) {
+ if ((w.shape(1) % group_size) != 0) {
std::ostringstream msg;
msg << "[quantize] The last dimension of the matrix needs to be divisible by "
<< "the quantization group size " << group_size
@@ -2825,6 +2834,20 @@ std::tuple<array, array, array> quantize(
array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
shifts = reshape(shifts, {1, 1, -1}, s);
+ // Check that the w matrix will fill up a whole SIMD.
+ // This is an implementation detail which should be removed in the future but
+ // at least we bail out early which will result in a nice readable error.
+ //
+ // Hopefully nobody is quantizing matrices that small anyway.
+ if (w.shape(1) < 32 * el_per_int) {
+ std::ostringstream msg;
+ msg << "[quantize] The feature dimension (2nd dimension of the matrix) is "
+ << "too small for quantization. We support >=512 for 2 bits, "
+ << ">= 256 for 4 bits and >= 128 for 8 bits. The provided matrix has "
+ << "shape " << w.shape() << ".";
+ throw std::invalid_argument(msg.str());
+ }
// Compute scales and biases
array packed_w =
reshape(w, {w.shape(0), w.shape(1) / group_size, group_size}, s);
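
Taken together, the new checks in quantize() require group_size in {64, 128}, bits in {2, 4, 8}, a 2-D matrix whose last dimension is divisible by group_size, and a feature dimension of at least 32 * el_per_int with el_per_int = 32 / bits, i.e. at least 512, 256 and 128 columns for 2, 4 and 8 bits respectively. A standalone sketch of the same validation in plain C++ (illustrative, not the MLX implementation):

#include <stdexcept>

// Standalone sketch of the parameter validation added above.
// `cols` is the size of the matrix's last (feature) dimension.
void check_quantization_params(int cols, int group_size, int bits) {
  if (group_size != 64 && group_size != 128) {
    throw std::invalid_argument("group_size must be 64 or 128");
  }
  if (bits != 2 && bits != 4 && bits != 8) {
    throw std::invalid_argument("bits must be 2, 4 or 8");
  }
  if (cols % group_size != 0) {
    throw std::invalid_argument("last dimension must be divisible by group_size");
  }
  int el_per_int = 32 / bits;    // elements packed into one uint32
  if (cols < 32 * el_per_int) {  // 512 / 256 / 128 for 2 / 4 / 8 bits
    throw std::invalid_argument("feature dimension too small to fill a SIMD group");
  }
}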
@@ -2855,11 +2878,6 @@ array dequantize(
throw std::invalid_argument("[dequantize] Only matrices supported for now");
}
- if ((w.shape(0) % 32) != 0) {
- throw std::invalid_argument(
- "[dequantize] All dimensions should be divisible by 32 for now");
- }
if (w.shape(0) != scales.shape(0) || w.shape(0) != biases.shape(0)) {
throw std::invalid_argument(
"[dequantize] Shape of scales and biases does not match the matrix");