Compare subcommand.

This commit is contained in:
antirez
2023-12-28 17:24:05 +01:00
parent 2a599dc5d0
commit 54946cbf14
2 changed files with 66 additions and 0 deletions

View File

@@ -4,6 +4,7 @@
#include <string.h>
#include <assert.h>
#include <errno.h>
#include <math.h>
#include "gguflib.h"
#include "sds.h"
@@ -372,6 +373,60 @@ void gguf_tools_inspect_weights(const char *filename, const char *tname, uint64_
return;
}
/* ========================== 'compare' subcommand ========================== */
/* Given two tensors of the same length, return the average difference
* of their weights. */
double tensors_avg_diff(gguf_tensor *t1, gguf_tensor *t2) {
float *weights1 = gguf_tensor_to_float(t1);
float *weights2 = gguf_tensor_to_float(t2);
if (weights1 == NULL || weights2 == NULL) {
perror("Error while decoding tensor weights");
exit(1);
}
double tot_diff = 0;
for (uint64_t j = 0; j < t1->num_weights; j++)
tot_diff += fabs(weights1[j]-weights2[j]);
free(weights1);
free(weights2);
return tot_diff/t1->num_weights;
}
void gguf_tools_compare(const char *file1, const char *file2) {
gguf_ctx *ctx1 = gguf_init(file1);
gguf_ctx *ctx2 = gguf_init(file2);
if (ctx1 == NULL || ctx2 == NULL) {
perror("Opening GGUF files");
exit(1);
}
/* Skip all the key-value pairs. */
gguf_skip_key_values_section(ctx1);
/* For each tensor of the first net... */
gguf_tensor tensor1, tensor2;
while (gguf_get_tensor(ctx1,&tensor1)) {
gguf_skip_key_values_section(ctx2);
while (gguf_get_tensor(ctx2,&tensor2)) {
/* Search for a tensor with the same name. */
if (tensor2.namelen == tensor1.namelen &&
memcmp(tensor2.name,tensor1.name,tensor1.namelen) == 0)
{
printf("%.*s: ", (int)tensor1.namelen, tensor1.name);
fflush(stdout);
if (tensor1.num_weights != tensor2.num_weights) {
printf("size mismatch\n");
} else {
printf("avg weights difference: %f\n",
tensors_avg_diff(&tensor1, &tensor2));
}
}
}
gguf_rewind(ctx2);
}
}
/* ======================= Main and CLI options parsing ===================== */
void gguf_tools_usage(const char *progname) {
@@ -379,6 +434,7 @@ void gguf_tools_usage(const char *progname) {
"Subcommands:\n"
" show <filename> -- show GGUF model keys and tensors.\n"
" inspect-tensor <filename> <tensor-name> [count] -- show tensor weights.\n"
" compare <file1> <file2> -- avg weights diff for matching tensor names.\n"
" split-mixtral <ids...> mixtral.gguf out.gguf -- extract expert.\n"
"Example:\n"
" split-mixtral 65230776370407150546470161412165 mixtral.gguf out.gguf\n"
@@ -391,6 +447,8 @@ int main(int argc, char **argv) {
if (!strcmp(argv[1],"show") && argc == 3) {
gguf_tools_show(argv[2]);
} else if (!strcmp(argv[1],"compare") && argc == 4) {
gguf_tools_compare(argv[2],argv[3]);
} else if (!strcmp(argv[1],"inspect-tensor") && (argc == 4 || argc == 5)) {
gguf_tools_inspect_weights(argv[2],argv[3],
argc == 5 ? atoi(argv[4]) : 0);