diff --git a/lora/convert.py b/lora/convert.py index 02dd06fb..cc29f7dc 100644 --- a/lora/convert.py +++ b/lora/convert.py @@ -47,6 +47,7 @@ if __name__ == "__main__": # Copy the params with open(torch_path / "params.json", "r") as f: config = json.loads(f.read()) + n_heads = config["n_heads"] if "sliding_window" in config: config.pop("sliding_window") if "n_kv_heads" not in config: