Improve names of quantization arguments (#235)

* Change the default quantization group_size to 64
* Rename groups to group_size and width to bits
Author: Angelos Katharopoulos
Date: 2023-12-20 16:53:53 -08:00 (committed by GitHub)
Parent: 57fe918cf8
Commit: b3916cbf2b

11 changed files with 184 additions and 180 deletions

View File

@@ -19,12 +19,12 @@ void _qmm_t_4_64(
int M,
int N,
int K) {
-constexpr int width = 4;
-constexpr int groups = 64;
-constexpr int bitmask = (1 << width) - 1;
-constexpr int pack_factor = 32 / width;
-constexpr int packs_in_group = groups / pack_factor;
-const int Kg = K / groups;
+constexpr int bits = 4;
+constexpr int group_size = 64;
+constexpr int bitmask = (1 << bits) - 1;
+constexpr int pack_factor = 32 / bits;
+constexpr int packs_in_group = group_size / pack_factor;
+const int Kg = K / group_size;
const int Kw = K / pack_factor;
for (int m = 0; m < M; m++) {
@@ -35,7 +35,7 @@ void _qmm_t_4_64(
for (int n = 0; n < N; n++) {
const simd_float16* x_local = (simd_float16*)x;
simd_float16 sum = 0;
-for (int k = 0; k < K; k += groups) {
+for (int k = 0; k < K; k += group_size) {
float scale = *scales_local++;
float bias = *biases_local++;
@@ -46,7 +46,7 @@ void _qmm_t_4_64(
uint32_t wii = *w_local++;
for (int p = 0; p < 8; p++) {
wi[e * 8 + p] = wii & bitmask;
-wii >>= width;
+wii >>= bits;
}
}
simd_float16 wf = simd_float(wi);
@@ -85,7 +85,7 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error("x, scales and biases should be row contiguous.");
}
-if (x.dtype() == float32 && width_ == 4 && groups_ == 64) {
+if (x.dtype() == float32 && bits_ == 4 && group_size_ == 64) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
int K = x.shape(-1);
int M = x.size() / K;
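
For reference, the arithmetic behind the renamed constants in this specialized 4-bit, group-size-64 kernel: each uint32_t holds pack_factor = 32 / bits = 8 quantized values, and packs_in_group = group_size / pack_factor = 8 packed words share one scale/bias pair. A minimal standalone sketch of the unpacking step (the packed word, scale, and bias below are made-up values, not MLX code):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int bits = 4;
  constexpr int pack_factor = 32 / bits;          // 8 elements per uint32
  constexpr uint32_t bitmask = (1u << bits) - 1;  // 0xF
  uint32_t wii = 0x76543210u;  // hypothetical packed 4-bit weights
  float scale = 0.5f;          // hypothetical per-group scale
  float bias = -1.0f;          // hypothetical per-group bias
  for (int p = 0; p < pack_factor; p++) {
    float w = scale * static_cast<float>(wii & bitmask) + bias;  // dequantize
    std::printf("element %d -> %.2f\n", p, w);
    wii >>= bits;  // advance to the next 4-bit value
  }
  return 0;
}
```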

View File

@@ -8,7 +8,7 @@ namespace mlx::core {
namespace {
-template <typename T, int width, int groups>
+template <typename T, int bits, int group_size>
void _qmm_t(
T* result,
const T* x,
@@ -18,10 +18,10 @@ void _qmm_t(
int M,
int N,
int K) {
-constexpr int bitmask = (1 << width) - 1;
-constexpr int pack_factor = 32 / width;
-constexpr int packs_in_group = groups / pack_factor;
-const int Kg = K / groups;
+constexpr int bitmask = (1 << bits) - 1;
+constexpr int pack_factor = 32 / bits;
+constexpr int packs_in_group = group_size / pack_factor;
+const int Kg = K / group_size;
const int Kw = K / pack_factor;
for (int m = 0; m < M; m++) {
@@ -32,7 +32,7 @@ void _qmm_t(
for (int n = 0; n < N; n++) {
const T* x_local = x;
T sum = 0;
-for (int k = 0; k < K; k += groups) {
+for (int k = 0; k < K; k += group_size) {
T scale = *scales_local++;
T bias = *biases_local++;
@@ -42,7 +42,7 @@ void _qmm_t(
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
sum += (*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
-wi >>= width;
+wi >>= bits;
}
}
}
@@ -64,11 +64,11 @@ void _qmm_t_dispatch_typed(
int M,
int N,
int K,
-int width,
-int groups) {
-switch (width) {
+int group_size,
+int bits) {
+switch (bits) {
case 2: {
-switch (groups) {
+switch (group_size) {
case 64:
return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
case 128:
@@ -76,7 +76,7 @@ void _qmm_t_dispatch_typed(
}
}
case 4: {
-switch (groups) {
+switch (group_size) {
case 64:
return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
case 128:
@@ -84,7 +84,7 @@ void _qmm_t_dispatch_typed(
}
}
case 8: {
-switch (groups) {
+switch (group_size) {
case 64:
return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
case 128:
@@ -93,9 +93,10 @@ void _qmm_t_dispatch_typed(
}
}
std::ostringstream msg;
msg << "Quantization type not supported. Provided bit width=" << width
<< " and groups=" << groups << ". The supported options are width in "
<< "{2, 4, 8} and groups in {64, 128}.";
msg << "Quantization type not supported. Provided bits=" << bits
<< " and group_size=" << group_size
<< ". The supported options are bits in "
<< "{2, 4, 8} and group_size in {64, 128}.";
throw std::invalid_argument(msg.str());
}
@@ -105,8 +106,8 @@ void _qmm_t_dispatch(
const array& w,
const array& scales,
const array& biases,
-int width,
-int groups) {
+int bits,
+int group_size) {
int K = x.shape(-1);
int M = x.size() / K;
int N = w.shape(1);
@@ -122,8 +123,8 @@ void _qmm_t_dispatch(
M,
N,
K,
-width,
-groups);
+bits,
+group_size);
break;
case float16:
_qmm_t_dispatch_typed<float16_t>(
@@ -135,8 +136,8 @@ void _qmm_t_dispatch(
M,
N,
K,
-width,
-groups);
+bits,
+group_size);
break;
case bfloat16:
_qmm_t_dispatch_typed<bfloat16_t>(
@@ -148,8 +149,8 @@ void _qmm_t_dispatch(
M,
N,
K,
-width,
-groups);
+bits,
+group_size);
break;
default:
throw std::invalid_argument(
@@ -177,7 +178,7 @@ void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
-_qmm_t_dispatch(out, x, w, scales, biases, width_, groups_);
+_qmm_t_dispatch(out, x, w, scales, biases, group_size_, bits_);
}
} // namespace mlx::core
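
The dispatch functions above follow a common pattern: runtime bits and group_size values are mapped onto compile-time template parameters through nested switches, so the inner loops see constexpr constants. A simplified sketch of that pattern (not the MLX code; the function names here are placeholders):

```cpp
#include <stdexcept>

// Placeholder kernel: in the real code this is _qmm_t<T, bits, group_size>.
template <int bits, int group_size>
void kernel_impl() { /* tight loop sees bits/group_size as compile-time constants */ }

// Nested switches turn runtime (bits, group_size) into template arguments.
inline void dispatch(int bits, int group_size) {
  switch (bits) {
    case 4:
      switch (group_size) {
        case 64:
          return kernel_impl<4, 64>();
        case 128:
          return kernel_impl<4, 128>();
      }
      break;
  }
  throw std::invalid_argument("unsupported bits/group_size combination");
}

int main() {
  dispatch(4, 64);  // resolves to kernel_impl<4, 64>
  return 0;
}
```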

View File

@@ -14,7 +14,7 @@ using namespace metal;
MLX_MTL_CONST int SIMD_SIZE = 32;
-template <typename T, const int BM, const int BN, const int groups, const int width>
+template <typename T, const int BM, const int BN, const int group_size, const int bits>
[[kernel]] void qmv(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
@@ -30,10 +30,10 @@ template <typename T, const int BM, const int BN, const int groups, const int wi
static_assert(BN == SIMD_SIZE, "qmv expects BN to be equal to SIMD_SIZE");
-constexpr int bitmask = (1 << width) - 1;
-constexpr int el_per_thread = 32 / width;
+constexpr int bitmask = (1 << bits) - 1;
+constexpr int el_per_thread = 32 / bits;
constexpr int colgroup = BN * el_per_thread;
-constexpr int groups_per_block = colgroup / groups;
+constexpr int groups_per_block = colgroup / group_size;
constexpr int simdgroups_fetching_vec = colgroup / SIMD_SIZE;
threadgroup T scales_block[BM * groups_per_block];
@@ -48,7 +48,7 @@ template <typename T, const int BM, const int BN, const int groups, const int wi
// Adjust positions
const int in_vec_size_w = in_vec_size / el_per_thread;
-const int in_vec_size_g = in_vec_size / groups;
+const int in_vec_size_g = in_vec_size / group_size;
int out_row = tid.y * BM + simd_gid;
w += out_row * in_vec_size_w;
scales += out_row * in_vec_size_g;
@@ -66,11 +66,11 @@ template <typename T, const int BM, const int BN, const int groups, const int wi
if (simd_lid == 0) {
#pragma clang loop unroll(full)
for (int j=0; j<groups_per_block; j++) {
-scales_block[simd_gid * groups_per_block + j] = scales[i / groups + j];
+scales_block[simd_gid * groups_per_block + j] = scales[i / group_size + j];
}
#pragma clang loop unroll(full)
for (int j=0; j<groups_per_block; j++) {
-biases_block[simd_gid * groups_per_block + j] = biases[i / groups + j];
+biases_block[simd_gid * groups_per_block + j] = biases[i / group_size + j];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -80,8 +80,8 @@ template <typename T, const int BM, const int BN, const int groups, const int wi
for (int j=0; j<el_per_thread; j++) {
x_thread[j] = x_block[simd_lid*el_per_thread + j];
}
-scale = scales_block[simd_gid * groups_per_block + simd_lid * el_per_thread / groups];
-bias = biases_block[simd_gid * groups_per_block + simd_lid * el_per_thread / groups];
+scale = scales_block[simd_gid * groups_per_block + simd_lid * el_per_thread / group_size];
+bias = biases_block[simd_gid * groups_per_block + simd_lid * el_per_thread / group_size];
// Load the matrix elements
w_local = w[i / el_per_thread + simd_lid];
@@ -90,7 +90,7 @@ template <typename T, const int BM, const int BN, const int groups, const int wi
#pragma clang loop unroll(full)
for (int k=0; k<el_per_thread; k++) {
result += (scale * static_cast<T>(w_local & bitmask) + bias) * x_thread[k];
-w_local >>= width;
+w_local >>= bits;
}
}
@@ -104,7 +104,7 @@ template <typename T, const int BM, const int BN, const int groups, const int wi
}
-template <typename T, const int BM, const int BK, const int BN, const int groups, const int width>
+template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
[[kernel]] void qmm_t(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
@@ -126,10 +126,10 @@ template <typename T, const int BM, const int BK, const int BN, const int groups
constexpr int WM = 2;
constexpr int WN = 2;
-constexpr int bitmask = (1 << width) - 1;
-constexpr int el_per_int = 32 / width;
+constexpr int bitmask = (1 << bits) - 1;
+constexpr int el_per_int = 32 / bits;
constexpr int ints_per_block = BK / el_per_int;
-constexpr int groups_per_block = (BK / groups > 0) ? (BK / groups) : 1;
+constexpr int groups_per_block = (BK / group_size > 0) ? (BK / group_size) : 1;
constexpr int groups_per_simd = BN / (WM * WN);
constexpr int w_els_per_thread = (BN * BK / el_per_int) / (SIMD_SIZE * WM * WN);
@@ -145,7 +145,7 @@ template <typename T, const int BM, const int BK, const int BN, const int groups
// Set the block
const int K_w = K / el_per_int;
-const int K_g = K / groups;
+const int K_g = K / group_size;
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
x += y_row * K;
@@ -172,8 +172,8 @@ template <typename T, const int BM, const int BK, const int BN, const int groups
if (simd_lid == 0) {
threadgroup T *scales_block_local = scales_block + lidy * groups_per_block * groups_per_simd;
threadgroup T *biases_block_local = biases_block + lidy * groups_per_block * groups_per_simd;
-const device T *scales_local = scales + lidy * groups_per_simd * K_g + k / groups;
-const device T *biases_local = biases + lidy * groups_per_simd * K_g + k / groups;
+const device T *scales_local = scales + lidy * groups_per_simd * K_g + k / group_size;
+const device T *biases_local = biases + lidy * groups_per_simd * K_g + k / group_size;
#pragma clang loop unroll(full)
for (int gs=0; gs<groups_per_simd; gs++) {
#pragma clang loop unroll(full)
@@ -199,13 +199,13 @@ template <typename T, const int BM, const int BK, const int BN, const int groups
threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
uint32_t wi = *w_local;
-T scale = scales_block[offset_row * groups_per_block + offset_col / (groups / el_per_int)];
-T bias = biases_block[offset_row * groups_per_block + offset_col / (groups / el_per_int)];
+T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
#pragma clang loop unroll(full)
for (int t=0; t<el_per_int; t++) {
Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
-wi >>= width;
+wi >>= bits;
}
}
}
@@ -231,9 +231,9 @@ template <typename T, const int BM, const int BK, const int BN, const int groups
}
-#define instantiate_qmv(name, itype, groups, width) \
-template [[host_name("qmv_n_" #name "_groups_" #groups "_width_" #width)]] \
-[[kernel]] void qmv<itype, 32, 32, groups, width>( \
+#define instantiate_qmv(name, itype, group_size, bits) \
+template [[host_name("qmv_n_" #name "_gs_" #group_size "_b_" #bits)]] \
+[[kernel]] void qmv<itype, 32, 32, group_size, bits>( \
const device uint32_t* w [[buffer(0)]], \
const device itype* scales [[buffer(1)]], \
const device itype* biases [[buffer(2)]], \
@@ -246,10 +246,10 @@ template <typename T, const int BM, const int BK, const int BN, const int groups
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
-#define instantiate_qmv_types(groups, width) \
-instantiate_qmv(float32, float, groups, width) \
-instantiate_qmv(float16, half, groups, width) \
-instantiate_qmv(bfloat16, bfloat16_t, groups, width)
+#define instantiate_qmv_types(group_size, bits) \
+instantiate_qmv(float32, float, group_size, bits) \
+instantiate_qmv(float16, half, group_size, bits) \
+instantiate_qmv(bfloat16, bfloat16_t, group_size, bits)
instantiate_qmv_types(128, 2)
instantiate_qmv_types(128, 4)
@@ -258,9 +258,9 @@ instantiate_qmv_types( 64, 2)
instantiate_qmv_types( 64, 4)
instantiate_qmv_types( 64, 8)
-#define instantiate_qmm_t(name, itype, groups, width) \
-template [[host_name("qmm_t_" #name "_groups_" #groups "_width_" #width)]] \
-[[kernel]] void qmm_t<itype, 32, 64, 32, groups, width>( \
+#define instantiate_qmm_t(name, itype, group_size, bits) \
+template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits)]] \
+[[kernel]] void qmm_t<itype, 32, 64, 32, group_size, bits>( \
const device itype* x [[buffer(0)]], \
const device uint32_t* w [[buffer(1)]], \
const device itype* scales [[buffer(2)]], \
@@ -274,10 +274,10 @@ instantiate_qmv_types( 64, 8)
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
-#define instantiate_qmm_t_types(groups, width) \
-instantiate_qmm_t(float32, float, groups, width) \
-instantiate_qmm_t(float16, half, groups, width) \
-instantiate_qmm_t(bfloat16, bfloat16_t, groups, width)
+#define instantiate_qmm_t_types(group_size, bits) \
+instantiate_qmm_t(float32, float, group_size, bits) \
+instantiate_qmm_t(float16, half, group_size, bits) \
+instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits)
instantiate_qmm_t_types(128, 2)
instantiate_qmm_t_types(128, 4)
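
As a sanity check on the renamed template parameters, the qmv constants for one supported configuration (bits = 4, group_size = 64, BN = SIMD_SIZE = 32) work out as below; this is the same arithmetic evaluated on the host, not additional Metal code:

```cpp
// Host-side check of the qmv constants for bits = 4, group_size = 64.
int main() {
  constexpr int SIMD_SIZE = 32;
  constexpr int BN = SIMD_SIZE;
  constexpr int bits = 4;
  constexpr int group_size = 64;
  constexpr int el_per_thread = 32 / bits;                 // 8 values per uint32
  constexpr int colgroup = BN * el_per_thread;             // 256 input columns per pass
  constexpr int groups_per_block = colgroup / group_size;  // 4 scale/bias pairs staged
  static_assert(el_per_thread == 8, "");
  static_assert(colgroup == 256, "");
  static_assert(groups_per_block == 4, "");
  return 0;
}
```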

View File

@@ -58,7 +58,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
if (B == 1) {
std::ostringstream kname;
kname << "qmv_" << (w_transposed ? "n_" : "t_") << type_to_name(out)
<< "_groups_" << groups_ << "_width_" << width_;
<< "_gs_" << group_size_ << "_b_" << bits_;
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
@@ -87,7 +87,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
else {
std::ostringstream kname;
kname << "qmm_" << (w_transposed ? "t_" : "n_") << type_to_name(out)
<< "_groups_" << groups_ << "_width_" << width_;
<< "_gs_" << group_size_ << "_b_" << bits_;
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
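
The strings built here have to keep matching the host_name strings produced by the renamed instantiation macros in the Metal file. A minimal sketch of the new naming scheme for a hypothetical float16 matrix-vector case (the concrete values, and the assumption that type_to_name(out) yields "float16", are illustrative only):

```cpp
#include <iostream>
#include <sstream>
#include <string>

int main() {
  // Hypothetical values standing in for the members/helpers used above.
  bool w_transposed = true;
  std::string type_name = "float16";  // stand-in for type_to_name(out)
  int group_size_ = 64;
  int bits_ = 4;

  std::ostringstream kname;
  kname << "qmv_" << (w_transposed ? "n_" : "t_") << type_name
        << "_gs_" << group_size_ << "_b_" << bits_;
  std::cout << kname.str() << std::endl;  // prints: qmv_n_float16_gs_64_b_4
  return 0;
}
```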

View File

@@ -2583,8 +2583,8 @@ array quantized_matmul(
const array& w,
const array& scales,
const array& biases,
-int groups /* = 128 */,
-int width /* = 4 */,
+int group_size /* = 64 */,
+int bits /* = 4 */,
StreamOrDevice s /* = {} */) {
auto x = in_x;
@@ -2611,24 +2611,25 @@ array quantized_matmul(
x = reshape(x, {-1, x_inner_dims}, s);
}
-int w_inner_dims = w.shape(0) * (32 / width);
+int w_inner_dims = w.shape(0) * (32 / bits);
if (w_inner_dims != x_inner_dims) {
std::ostringstream msg;
msg << "[quantized_matmul] Last dimension of first input with "
<< "shape (..., " << x_inner_dims
<< ") does not match the expanded first "
<< "dimension of the quantized matrix " << w_inner_dims
<< ", computed from shape " << w.shape() << " with groups=" << groups
<< " and width=" << width;
<< ", computed from shape " << w.shape()
<< " with group_size=" << group_size << " and bits=" << bits;
throw std::invalid_argument(msg.str());
}
int n_groups = x_inner_dims / groups;
int n_groups = x_inner_dims / group_size;
if (scales.shape(-1) != n_groups || biases.shape(-1) != n_groups) {
std::ostringstream msg;
msg << "[quantized_matmul] Scales and biases provided do not match the "
<< "quantization arguments (groups=" << groups << ", width=" << width
<< "). Expected shapes (" << w.shape(1) << ", " << x_inner_dims / groups
<< "quantization arguments (group_size=" << group_size
<< ", bits=" << bits << "). Expected shapes (" << w.shape(1) << ", "
<< x_inner_dims / group_size
<< "), but got scales.shape=" << scales.shape()
<< " and biases.shape=" << biases.shape();
throw std::invalid_argument(msg.str());
@@ -2637,7 +2638,7 @@ array quantized_matmul(
auto out = array(
{x.shape(0), w.shape(1)},
x.dtype(),
-std::make_unique<QuantizedMatmul>(to_stream(s), groups, width),
+std::make_unique<QuantizedMatmul>(to_stream(s), group_size, bits),
{x, w, scales, biases});
// If needed reshape x to the original batch shape
@@ -2651,8 +2652,8 @@ array quantized_matmul(
std::tuple<array, array, array> quantize(
const array& w,
-int groups /* = 128 */,
-int width /* = 4 */,
+int group_size /* = 64 */,
+int bits /* = 4 */,
StreamOrDevice s /* = {} */) {
if (w.ndim() != 2) {
throw std::invalid_argument("[quantize] Only matrices supported for now");
@@ -2663,23 +2664,24 @@ std::tuple<array, array, array> quantize(
"[quantize] All dimensions should be divisible by 32 for now");
}
-if ((w.shape(-1) % groups) != 0) {
+if ((w.shape(-1) % group_size) != 0) {
std::ostringstream msg;
msg << "[quantize] The last dimension of the matrix needs to be divisible by "
-<< "the quantization group size " << groups
-<< ". However the provided matrix"
-<< " has shape " << w.shape();
+<< "the quantization group size " << group_size
+<< ". However the provided "
+<< " matrix has shape " << w.shape();
throw std::invalid_argument(msg.str());
}
// Compute some constants used for the quantization
-int n_bins = (1 << width) - 1; // 2**width - 1
-int el_per_int = 32 / width;
-array shifts = power(array(2, uint32), arange(0, 32, width, uint32, s), s);
+int n_bins = (1 << bits) - 1; // 2**bits - 1
+int el_per_int = 32 / bits;
+array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
shifts = reshape(shifts, {1, 1, -1}, s);
// Compute scales and biases
-array packed_w = reshape(w, {w.shape(0), w.shape(1) / groups, groups}, s);
+array packed_w =
+reshape(w, {w.shape(0), w.shape(1) / group_size, group_size}, s);
array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
array delta = divide(subtract(w_max, w_min, s), array(n_bins, w.dtype()), s);
@@ -2700,8 +2702,8 @@ array dequantize(
const array& w,
const array& scales,
const array& biases,
-int groups /* = 128 */,
-int width /* = 4 */,
+int group_size /* = 64 */,
+int bits /* = 4 */,
StreamOrDevice s /* = {} */) {
if (w.ndim() != 2 || scales.ndim() != 2 || biases.ndim() != 2) {
throw std::invalid_argument("[dequantize] Only matrices supported for now");
@@ -2723,22 +2725,22 @@ array dequantize(
}
// Compute some constants for the dequantization
-int el_per_int = 32 / width;
+int el_per_int = 32 / bits;
-if (w.shape(1) * el_per_int != scales.shape(1) * groups) {
+if (w.shape(1) * el_per_int != scales.shape(1) * group_size) {
std::ostringstream msg;
msg << "[dequantize] Shape of scales and biases does not match the matrix "
<< "given the quantization parameters. Provided matrix of shape "
<< w.shape() << " and scales/biases of shape " << scales.shape()
<< " with groups=" << groups << " and width=" << width << ".";
<< " with group_size=" << group_size << " and bits=" << bits << ".";
throw std::invalid_argument(msg.str());
}
// Extract the pieces from the passed quantized matrix
std::vector<array> parts;
-for (int start = 0; start < 32; start += width) {
+for (int start = 0; start < 32; start += bits) {
// TODO: Implement bitwise operators for integral types
-int shift_left = 32 - (start + width);
+int shift_left = 32 - (start + bits);
int shift_right = shift_left + start;
array p = multiply(w, array(1 << shift_left, uint32), s);
p = floor_divide(p, array(1 << shift_right, uint32), s);
@@ -2748,7 +2750,7 @@ array dequantize(
array w_full = concatenate(parts, -1, s);
// Dequantize
-w_full = reshape(w_full, {w.shape(0), -1, groups}, s);
+w_full = reshape(w_full, {w.shape(0), -1, group_size}, s);
w_full = multiply(w_full, expand_dims(scales, -1, s), s);
w_full = add(w_full, expand_dims(biases, -1, s), s);
w_full = reshape(w_full, {w.shape(0), -1}, s);
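
For context on what group_size and bits mean in quantize(): each contiguous group of group_size weights shares one scale and one bias, n_bins = 2^bits - 1 levels are available per value, and 32 / bits quantized values are packed into each uint32_t. The scalar sketch below illustrates that scheme for a single group; it is a simplified illustration of the idea (the MLX code above works on whole arrays and packs via the shifts trick), not the actual implementation:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Quantize one group of weights: one (scale, bias) pair for the whole group,
// 32 / bits values packed per uint32_t. Assumes w.size() is a multiple of
// 32 / bits and that the group is not constant (w_max > w_min).
std::vector<uint32_t> quantize_group(
    const std::vector<float>& w, int bits, float& scale, float& bias) {
  const int n_bins = (1 << bits) - 1;  // 2**bits - 1
  const int el_per_int = 32 / bits;
  const float w_max = *std::max_element(w.begin(), w.end());
  const float w_min = *std::min_element(w.begin(), w.end());
  scale = (w_max - w_min) / n_bins;  // plays the role of "delta" above
  bias = w_min;
  std::vector<uint32_t> packed(w.size() / el_per_int, 0);
  for (size_t i = 0; i < w.size(); i++) {
    uint32_t q = static_cast<uint32_t>(std::round((w[i] - bias) / scale));
    packed[i / el_per_int] |= q << (bits * (i % el_per_int));  // shift-and-or pack
  }
  return packed;
}

int main() {
  std::vector<float> group(64);
  for (int i = 0; i < 64; i++) group[i] = 0.01f * i - 0.3f;  // toy data
  float scale = 0.0f, bias = 0.0f;
  auto packed = quantize_group(group, /* bits= */ 4, scale, bias);
  uint32_t q0 = packed[0] & 0xF;     // first 4-bit code
  float w0_hat = scale * q0 + bias;  // dequantized approximation of group[0]
  (void)w0_hat;
  return 0;
}
```

Dequantization then recovers scale * q + bias per element, which is what the CPU and Metal kernels earlier in this diff compute on the fly.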

View File

@@ -1037,15 +1037,15 @@ array quantized_matmul(
const array& w,
const array& scales,
const array& biases,
-int groups = 128,
-int width = 4,
+int group_size = 64,
+int bits = 4,
StreamOrDevice s = {});
/** Quantize a matrix along its last axis */
std::tuple<array, array, array> quantize(
const array& w,
-int groups = 128,
-int width = 4,
+int group_size = 64,
+int bits = 4,
StreamOrDevice s = {});
/** Dequantize a matrix produced by quantize() */
@@ -1053,8 +1053,8 @@ array dequantize(
const array& w,
const array& scales,
const array& biases,
-int groups = 128,
-int width = 4,
+int group_size = 64,
+int bits = 4,
StreamOrDevice s = {});
} // namespace mlx::core
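
With these declarations, callers now pass group_size and bits (or rely on the new defaults of 64 and 4) where they previously passed groups and width. A minimal usage sketch, assuming mlx/mlx.h is available; the random weight matrix and its shape are arbitrary choices that satisfy the divisibility checks:

```cpp
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // Arbitrary weight matrix; the last dimension must be divisible by 32 and
  // by group_size for quantize() to accept it.
  array w = random::normal({512, 512});

  // New argument names and defaults: group_size = 64, bits = 4.
  auto [wq, scales, biases] = quantize(w, /* group_size= */ 64, /* bits= */ 4);

  // Round-trip with the same parameters to get an approximation of w.
  array w_hat = dequantize(wq, scales, biases, 64, 4);
  eval({w_hat});
  return 0;
}
```

quantized_matmul takes the same trailing group_size/bits pair; only the argument names and the default group size change in this commit, not the packed layout.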

View File

@@ -1718,7 +1718,7 @@ array QuantizedMatmul::jvp(
bool QuantizedMatmul::is_equivalent(const Primitive& other) const {
const QuantizedMatmul& qm_other = static_cast<const QuantizedMatmul&>(other);
-return groups_ == qm_other.groups_ && width_ == qm_other.width_;
+return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_;
}
std::pair<array, int> RandomBits::vmap(

View File

@@ -1112,8 +1112,8 @@ class Power : public Primitive {
class QuantizedMatmul : public Primitive {
public:
-explicit QuantizedMatmul(Stream stream, int groups, int width)
-: Primitive(stream), groups_(groups), width_(width){};
+explicit QuantizedMatmul(Stream stream, int group_size, int bits)
+: Primitive(stream), group_size_(group_size), bits_(bits){};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1127,8 +1127,8 @@ class QuantizedMatmul : public Primitive {
bool is_equivalent(const Primitive& other) const override;
private:
-int groups_;
-int width_;
+int group_size_;
+int bits_;
void eval(const std::vector<array>& inputs, array& out);
};