Q4_K dequantization.

This commit is contained in:
antirez
2023-12-28 11:25:08 +01:00
parent c25ccfa02a
commit e2062eea2c
3 changed files with 107 additions and 23 deletions

View File

@@ -511,9 +511,11 @@ float *gguf_tensor_to_float(gguf_tensor *tensor) {
gguf_get_tensor_type_features(tensor->type);
uint64_t block_size = tf->bytes_per_block;
float *f = malloc(tensor->num_weights*sizeof(float));
if (tensor->type == GUFF_TYPE_Q8_0) {
if (tensor->type == GGUF_TYPE_Q8_0) {
/* Very simple layout: |16 bit delta|32 x 8bit weights|
* Each weight is delta * quantized_weight[0..31] */
int8_t *block = (int8_t*)tensor->weights_data;
uint64_t i = 0;
uint64_t i = 0; // i-th weight to dequantize.
while(i < tensor->num_weights) {
/* For each block get the delta and convert all the
* weights in the block. */
@@ -524,6 +526,79 @@ float *gguf_tensor_to_float(gguf_tensor *tensor) {
}
block += block_size; // Go to the next block.
}
} else if (tensor->type == GGUF_TYPE_Q4_K) {
uint8_t *block = (uint8_t*)tensor->weights_data;
uint64_t i = 0; // i-th weight to dequantize.
while(i < tensor->num_weights) {
/* Q4_K super-blocks have 256 total weights, split in 8 sub-block.
* Each 8 sub-blocks have a different set of deltas/mins, so
* there are 16 total values for deltas/mins, but the deltas/mins
* are also quantized (6 bits each) using two different deltas:
* delta_of_deltas and delta_of_mins, that are two FP16 values
* at the start of the super block, so:
*
* |FP16 d_of_deltas | +
* |FP16 d_of_mins | +
* |16 6 bit integers d,m pairs, one per sub-block of 32 ele | +
* |256 x 4bit weights|
*/
float deltas_delta = from_half(*((uint16_t*)block));
float mins_delta = from_half(*((uint16_t*)(block+2)));
block += 4;
/* Extract the 16 x 6 bit values deltas-mins pairs. The
* encoding of those values is odd because of performance
* reasons:
*
* dddddddd dddddddd dddddddd dddddddd mmmmmmmm mmmmmmmm
* 44000000|55111111|66222222|77333333|44000000|55111111
*
* mmmmmmmm mmmmmmmm mmmmdddd mmmmdddd mmmmdddd mmmmdddd
* 66222222|77333333|44444444|55555555|66666666|77777777
*
* In the above diagram you can see the 12 bytes and the
* deltas/mins 6 bits encodings. */
/* Scale deltas/mins. */
float deltas[8], mins[8];
for (int j = 0; j < 8; j++) {
uint8_t d,m;
if (j < 4) {
d = block[j] & 63;
m = block[j+4] & 63;
} else {
d = (block[j+4] & 0xF) | ((block[j-4] >> 6) << 4);
m = (block[j+4] >> 4) | ((block[j-0] >> 6) << 4);
}
deltas[j] = d * deltas_delta;
mins[j] = m * mins_delta;
}
block += 12; // Seek 4-bit weights start.
/* Finally we can extract the 256 weights.
* We process two blocks per time, because each
* 32 bytes have 64 weights stored like this:
* First 32 weights of the first block are the higher 4
* bits of each byte. Second 32 weights of the second
* block are lower 4 bits of each byte. */
for (uint32_t b = 0; b < 8; b += 2) {
float delta = deltas[b];
float min = mins[b];
/* First set: higher bits. */
for (uint32_t j = 0; j < 32; j++) {
uint8_t w = block[j] & 0xf;
f[i++] = w * delta - min;
if (i == tensor->num_weights) return f;
}
/* Second set: lower bits. */
for (uint32_t j = 0; j < 32; j++) {
uint8_t w = block[j] >> 4;
f[i++] = w * delta - min;
if (i == tensor->num_weights) return f;
}
block += 32; // Skip the two processed blocks.
}
}
} else {
errno = EINVAL;
return NULL;