add lm eval version

This commit is contained in:
Alex Barron 2024-12-06 16:21:20 -08:00
parent dc56226bf8
commit 80590f5ec2

View File

@ -4,10 +4,10 @@ import argparse
import json import json
import logging import logging
import os import os
from importlib.metadata import version
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import lm_eval
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np
@ -340,10 +340,9 @@ def main():
fewshot_random_seed=args.seed, fewshot_random_seed=args.seed,
) )
filename = f"eval_{args.model.replace('/', '_')}_{('_'.join(args.tasks))}_{args.num_shots:02d}.json" filename = f"eval_{args.model.replace('/', '_')}_{('_'.join(args.tasks))}_{args.num_shots:02d}_v_{version('lm_eval')}.json"
output_path = Path(args.output_dir) / filename output_path = Path(args.output_dir) / filename
output_path.write_text(json.dumps(results["results"], indent=4)) output_path.write_text(json.dumps(results["results"], indent=4))
print("Results:") print("Results:")
for result in results["results"].values(): for result in results["results"].values():
print(json.dumps(result, indent=4)) print(json.dumps(result, indent=4))
print()