Compare subcommand: just skip tensors we can't yet dequantize.

This commit is contained in:
antirez
2023-12-30 10:13:38 +01:00
parent 400f60b75b
commit 3663d73c22

View File

@@ -376,13 +376,15 @@ void gguf_tools_inspect_weights(const char *filename, const char *tname, uint64_
/* ========================== 'compare' subcommand ========================== */
/* Given two tensors of the same length, return the average difference
* of their weights. */
double tensors_avg_diff(gguf_tensor *t1, gguf_tensor *t2) {
* of their weights. Returns 1 on success, 0 if one or both the provided
* tensors can't be dequantized. */
int tensors_avg_diff(gguf_tensor *t1, gguf_tensor *t2, double *diff) {
float *weights1 = gguf_tensor_to_float(t1);
float *weights2 = gguf_tensor_to_float(t2);
if (weights1 == NULL || weights2 == NULL) {
perror("Error while decoding tensor weights");
exit(1);
if (weights1) free(weights1);
if (weights2) free(weights2);
return 0;
}
double tot_diff = 0;
@@ -390,7 +392,8 @@ double tensors_avg_diff(gguf_tensor *t1, gguf_tensor *t2) {
tot_diff += fabs(weights1[j]-weights2[j]);
free(weights1);
free(weights2);
return tot_diff/t1->num_weights;
*diff = tot_diff/t1->num_weights;
return 1;
}
void gguf_tools_compare(const char *file1, const char *file2) {
@@ -418,8 +421,12 @@ void gguf_tools_compare(const char *file1, const char *file2) {
if (tensor1.num_weights != tensor2.num_weights) {
printf("size mismatch\n");
} else {
printf("avg weights difference: %f\n",
tensors_avg_diff(&tensor1, &tensor2));
double diff;
if (tensors_avg_diff(&tensor1, &tensor2, &diff)) {
printf("avg weights difference: %f\n", diff);
} else {
printf("dequantization function missing...\n");
}
}
}
}