generalize lora finetuning for llama and mistral

This commit is contained in:
Awni Hannun
2023-12-09 14:13:55 -08:00
parent 46c6bbe0a1
commit b8332a1e66
5 changed files with 354 additions and 293 deletions

View File

@@ -1,53 +1,61 @@
# Copyright © 2023 Apple Inc.
import argparse
from itertools import starmap
import json
import numpy as np
from pathlib import Path
import shutil
import os
import torch
def map_torch_to_mlx(key, value):
if "tok_embedding" in key:
key = "embedding.weight"
elif "norm" in key:
key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")
elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
key = key.replace("wq", "query_proj")
key = key.replace("wk", "key_proj")
key = key.replace("wv", "value_proj")
key = key.replace("wo", "out_proj")
elif "w1" in key or "w2" in key or "w3" in key:
# The FFN is a separate submodule in PyTorch
key = key.replace("feed_forward.w1", "linear1")
key = key.replace("feed_forward.w3", "linear2")
key = key.replace("feed_forward.w2", "linear3")
elif "output" in key:
key = key.replace("output", "out_proj")
elif "rope" in key:
return None, None
return (
key,
value.numpy()
if value.dtype != torch.bfloat16
else value.to(torch.float32).numpy(),
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
parser.add_argument("torch_weights")
parser.add_argument("output_file")
parser = argparse.ArgumentParser(
description="Convert Mistral or Llama models to MLX.",
)
parser.add_argument(
"--torch_model",
type=str,
default="mistral-7B-v0.1/",
help="The torch model directory",
)
parser.add_argument(
"--mlx_model",
type=str,
default="mlx-mistral-7B-v0.1/",
help="The directory to store the mlx model",
)
args = parser.parse_args()
state = torch.load(args.torch_weights)
torch_path = Path(args.torch_model)
if not os.path.exists(args.mlx_model):
os.makedirs(args.mlx_model)
mlx_path = Path(args.mlx_model)
state = torch.load(str(torch_path / "consolidated.00.pth"))
np.savez(
args.output_file,
**{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}
str(mlx_path / "weights.npz"),
**{k: v.to(torch.float16).numpy() for k, v in state.items()}
)
# Copy the tokenizer
shutil.copyfile(
str(torch_path / "tokenizer.model"),
str(mlx_path / "tokenizer.model"),
)
# Copy the params
with open(torch_path / "params.json", "r") as f:
config = json.loads(f.read())
if "sliding_window" in config:
config.pop("sliding_window")
if "n_kv_heads" not in config:
config["n_kv_heads"] = n_heads
if "head_dim" not in config:
config["head_dim"] = config["dim"] // n_heads
if "hidden_dim" not in config:
config["hidden_dim"] = state["layers.0.feed_forward.w1.weight"].shape
with open(mlx_path / "params.json", "w") as outfile:
json.dump(config, outfile)