Improve names of quantization arguments (#235)

* Change the default quantization group_size to 64
* Rename groups to group_size and width to bits
commit b3916cbf2b (parent 57fe918cf8)
Author: Angelos Katharopoulos
Date:   2023-12-20 16:53:53 -08:00
Committed by: GitHub
11 changed files with 184 additions and 180 deletions
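
For orientation, the renamed arguments describe the quantization layout: bits (formerly width) is the bit width of one quantized weight, and group_size (formerly groups) is the number of consecutive weights that share a scale and bias. Below is a minimal sketch of how the kernel constants in the diff follow from the two parameters; the concrete values (4-bit weights, groups of 64, matching the new default) are only an illustration:

    #include <cstdint>

    // Illustrative constants only; the names mirror the _qmm_t kernel in this diff.
    constexpr int bits = 4;        // width of one quantized value
    constexpr int group_size = 64; // weights sharing one scale/bias pair

    constexpr uint32_t bitmask = (1u << bits) - 1;            // 0xF
    constexpr int pack_factor = 32 / bits;                    // 8 values per uint32 word
    constexpr int packs_in_group = group_size / pack_factor;  // 8 words per group

    static_assert(pack_factor == 8 && packs_in_group == 8, "4-bit / 64 example");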

@@ -8,7 +8,7 @@ namespace mlx::core {
namespace {
-template <typename T, int width, int groups>
+template <typename T, int bits, int group_size>
void _qmm_t(
T* result,
const T* x,
@@ -18,10 +18,10 @@ void _qmm_t(
int M,
int N,
int K) {
-constexpr int bitmask = (1 << width) - 1;
-constexpr int pack_factor = 32 / width;
-constexpr int packs_in_group = groups / pack_factor;
-const int Kg = K / groups;
+constexpr int bitmask = (1 << bits) - 1;
+constexpr int pack_factor = 32 / bits;
+constexpr int packs_in_group = group_size / pack_factor;
+const int Kg = K / group_size;
const int Kw = K / pack_factor;
for (int m = 0; m < M; m++) {
@@ -32,7 +32,7 @@ void _qmm_t(
for (int n = 0; n < N; n++) {
const T* x_local = x;
T sum = 0;
-for (int k = 0; k < K; k += groups) {
+for (int k = 0; k < K; k += group_size) {
T scale = *scales_local++;
T bias = *biases_local++;
@@ -42,7 +42,7 @@ void _qmm_t(
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
sum += (*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
-wi >>= width;
+wi >>= bits;
}
}
}
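
The hunk above is the hot inner loop: each packed 32-bit word wi holds pack_factor quantized values, and each one is masked out with bitmask, dequantized as scale * q + bias, and multiplied by the next activation. Below is a standalone sketch of that step, assuming float activations; the helper name and signature are illustrative and not part of the commit:

    #include <cstdint>

    // Sketch: consume one packed word, accumulating pack_factor products.
    inline float dot_packed_word(
        uint32_t wi,       // packed word holding 32 / bits quantized values
        const float* x,    // pack_factor activations
        float scale,
        float bias,
        int bits) {
      const uint32_t bitmask = (1u << bits) - 1;
      const int pack_factor = 32 / bits;
      float sum = 0.0f;
      for (int p = 0; p < pack_factor; p++) {
        float w = scale * static_cast<float>(wi & bitmask) + bias; // dequantize
        sum += x[p] * w;
        wi >>= bits; // advance to the next packed value
      }
      return sum;
    }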
@@ -64,11 +64,11 @@ void _qmm_t_dispatch_typed(
int M,
int N,
int K,
-int width,
-int groups) {
-switch (width) {
+int group_size,
+int bits) {
+switch (bits) {
case 2: {
-switch (groups) {
+switch (group_size) {
case 64:
return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
case 128:
@@ -76,7 +76,7 @@ void _qmm_t_dispatch_typed(
}
}
case 4: {
-switch (groups) {
+switch (group_size) {
case 64:
return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
case 128:
@@ -84,7 +84,7 @@ void _qmm_t_dispatch_typed(
}
}
case 8: {
-switch (groups) {
+switch (group_size) {
case 64:
return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
case 128:
@@ -93,9 +93,10 @@ void _qmm_t_dispatch_typed(
}
}
std::ostringstream msg;
-msg << "Quantization type not supported. Provided bit width=" << width
-<< " and groups=" << groups << ". The supported options are width in "
-<< "{2, 4, 8} and groups in {64, 128}.";
+msg << "Quantization type not supported. Provided bits=" << bits
+<< " and group_size=" << group_size
+<< ". The supported options are bits in "
+<< "{2, 4, 8} and group_size in {64, 128}.";
throw std::invalid_argument(msg.str());
}
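
As the rewritten message states, kernels are only instantiated for bits in {2, 4, 8} combined with group_size in {64, 128}; any other pair falls through to the throw. A hedged sketch of that validation pulled out as a standalone helper (the function name is hypothetical; the message text comes from the diff):

    #include <sstream>
    #include <stdexcept>

    // Hypothetical helper mirroring the validation performed by the dispatch.
    void check_quantization_args(int group_size, int bits) {
      const bool bits_ok = bits == 2 || bits == 4 || bits == 8;
      const bool group_ok = group_size == 64 || group_size == 128;
      if (!bits_ok || !group_ok) {
        std::ostringstream msg;
        msg << "Quantization type not supported. Provided bits=" << bits
            << " and group_size=" << group_size
            << ". The supported options are bits in "
            << "{2, 4, 8} and group_size in {64, 128}.";
        throw std::invalid_argument(msg.str());
      }
    }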
@@ -105,8 +106,8 @@ void _qmm_t_dispatch(
const array& w,
const array& scales,
const array& biases,
-int width,
-int groups) {
+int bits,
+int group_size) {
int K = x.shape(-1);
int M = x.size() / K;
int N = w.shape(1);
@@ -122,8 +123,8 @@ void _qmm_t_dispatch(
M,
N,
K,
-width,
-groups);
+bits,
+group_size);
break;
case float16:
_qmm_t_dispatch_typed<float16_t>(
@@ -135,8 +136,8 @@ void _qmm_t_dispatch(
M,
N,
K,
-width,
-groups);
+bits,
+group_size);
break;
case bfloat16:
_qmm_t_dispatch_typed<bfloat16_t>(
@@ -148,8 +149,8 @@ void _qmm_t_dispatch(
M,
N,
K,
-width,
-groups);
+bits,
+group_size);
break;
default:
throw std::invalid_argument(
@@ -177,7 +178,7 @@ void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
-_qmm_t_dispatch(out, x, w, scales, biases, width_, groups_);
+_qmm_t_dispatch(out, x, w, scales, biases, group_size_, bits_);
}
} // namespace mlx::core