mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Improve names of quantization arguments (#235)
* Change the default quantization group_size to 64 * Rename groups to group_size and width to bits
This commit is contained in:
committed by
GitHub
parent
57fe918cf8
commit
b3916cbf2b
@@ -8,7 +8,7 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T, int width, int groups>
|
||||
template <typename T, int bits, int group_size>
|
||||
void _qmm_t(
|
||||
T* result,
|
||||
const T* x,
|
||||
@@ -18,10 +18,10 @@ void _qmm_t(
|
||||
int M,
|
||||
int N,
|
||||
int K) {
|
||||
constexpr int bitmask = (1 << width) - 1;
|
||||
constexpr int pack_factor = 32 / width;
|
||||
constexpr int packs_in_group = groups / pack_factor;
|
||||
const int Kg = K / groups;
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
constexpr int pack_factor = 32 / bits;
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
const int Kg = K / group_size;
|
||||
const int Kw = K / pack_factor;
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
@@ -32,7 +32,7 @@ void _qmm_t(
|
||||
for (int n = 0; n < N; n++) {
|
||||
const T* x_local = x;
|
||||
T sum = 0;
|
||||
for (int k = 0; k < K; k += groups) {
|
||||
for (int k = 0; k < K; k += group_size) {
|
||||
T scale = *scales_local++;
|
||||
T bias = *biases_local++;
|
||||
|
||||
@@ -42,7 +42,7 @@ void _qmm_t(
|
||||
#pragma clang loop unroll(full)
|
||||
for (int p = 0; p < pack_factor; p++) {
|
||||
sum += (*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
|
||||
wi >>= width;
|
||||
wi >>= bits;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -64,11 +64,11 @@ void _qmm_t_dispatch_typed(
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int width,
|
||||
int groups) {
|
||||
switch (width) {
|
||||
int group_size,
|
||||
int bits) {
|
||||
switch (bits) {
|
||||
case 2: {
|
||||
switch (groups) {
|
||||
switch (group_size) {
|
||||
case 64:
|
||||
return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
|
||||
case 128:
|
||||
@@ -76,7 +76,7 @@ void _qmm_t_dispatch_typed(
|
||||
}
|
||||
}
|
||||
case 4: {
|
||||
switch (groups) {
|
||||
switch (group_size) {
|
||||
case 64:
|
||||
return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
|
||||
case 128:
|
||||
@@ -84,7 +84,7 @@ void _qmm_t_dispatch_typed(
|
||||
}
|
||||
}
|
||||
case 8: {
|
||||
switch (groups) {
|
||||
switch (group_size) {
|
||||
case 64:
|
||||
return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
|
||||
case 128:
|
||||
@@ -93,9 +93,10 @@ void _qmm_t_dispatch_typed(
|
||||
}
|
||||
}
|
||||
std::ostringstream msg;
|
||||
msg << "Quantization type not supported. Provided bit width=" << width
|
||||
<< " and groups=" << groups << ". The supported options are width in "
|
||||
<< "{2, 4, 8} and groups in {64, 128}.";
|
||||
msg << "Quantization type not supported. Provided bits=" << bits
|
||||
<< " and group_size=" << group_size
|
||||
<< ". The supported options are bits in "
|
||||
<< "{2, 4, 8} and group_size in {64, 128}.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
@@ -105,8 +106,8 @@ void _qmm_t_dispatch(
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
int width,
|
||||
int groups) {
|
||||
int bits,
|
||||
int group_size) {
|
||||
int K = x.shape(-1);
|
||||
int M = x.size() / K;
|
||||
int N = w.shape(1);
|
||||
@@ -122,8 +123,8 @@ void _qmm_t_dispatch(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
width,
|
||||
groups);
|
||||
bits,
|
||||
group_size);
|
||||
break;
|
||||
case float16:
|
||||
_qmm_t_dispatch_typed<float16_t>(
|
||||
@@ -135,8 +136,8 @@ void _qmm_t_dispatch(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
width,
|
||||
groups);
|
||||
bits,
|
||||
group_size);
|
||||
break;
|
||||
case bfloat16:
|
||||
_qmm_t_dispatch_typed<bfloat16_t>(
|
||||
@@ -148,8 +149,8 @@ void _qmm_t_dispatch(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
width,
|
||||
groups);
|
||||
bits,
|
||||
group_size);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
@@ -177,7 +178,7 @@ void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
_qmm_t_dispatch(out, x, w, scales, biases, width_, groups_);
|
||||
_qmm_t_dispatch(out, x, w, scales, biases, group_size_, bits_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user