mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
generalize lora finetuning for llama and mistral
This commit is contained in:
@@ -1,53 +1,61 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
from itertools import starmap
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import os
|
||||
import torch
|
||||
|
||||
|
||||
def map_torch_to_mlx(key, value):
|
||||
if "tok_embedding" in key:
|
||||
key = "embedding.weight"
|
||||
|
||||
elif "norm" in key:
|
||||
key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")
|
||||
|
||||
elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
|
||||
key = key.replace("wq", "query_proj")
|
||||
key = key.replace("wk", "key_proj")
|
||||
key = key.replace("wv", "value_proj")
|
||||
key = key.replace("wo", "out_proj")
|
||||
|
||||
elif "w1" in key or "w2" in key or "w3" in key:
|
||||
# The FFN is a separate submodule in PyTorch
|
||||
key = key.replace("feed_forward.w1", "linear1")
|
||||
key = key.replace("feed_forward.w3", "linear2")
|
||||
key = key.replace("feed_forward.w2", "linear3")
|
||||
|
||||
elif "output" in key:
|
||||
key = key.replace("output", "out_proj")
|
||||
|
||||
elif "rope" in key:
|
||||
return None, None
|
||||
|
||||
return (
|
||||
key,
|
||||
value.numpy()
|
||||
if value.dtype != torch.bfloat16
|
||||
else value.to(torch.float32).numpy(),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
|
||||
parser.add_argument("torch_weights")
|
||||
parser.add_argument("output_file")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert Mistral or Llama models to MLX.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--torch_model",
|
||||
type=str,
|
||||
default="mistral-7B-v0.1/",
|
||||
help="The torch model directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mlx_model",
|
||||
type=str,
|
||||
default="mlx-mistral-7B-v0.1/",
|
||||
help="The directory to store the mlx model",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
state = torch.load(args.torch_weights)
|
||||
torch_path = Path(args.torch_model)
|
||||
if not os.path.exists(args.mlx_model):
|
||||
os.makedirs(args.mlx_model)
|
||||
mlx_path = Path(args.mlx_model)
|
||||
|
||||
state = torch.load(str(torch_path / "consolidated.00.pth"))
|
||||
np.savez(
|
||||
args.output_file,
|
||||
**{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}
|
||||
str(mlx_path / "weights.npz"),
|
||||
**{k: v.to(torch.float16).numpy() for k, v in state.items()}
|
||||
)
|
||||
|
||||
# Copy the tokenizer
|
||||
shutil.copyfile(
|
||||
str(torch_path / "tokenizer.model"),
|
||||
str(mlx_path / "tokenizer.model"),
|
||||
)
|
||||
|
||||
# Copy the params
|
||||
with open(torch_path / "params.json", "r") as f:
|
||||
config = json.loads(f.read())
|
||||
if "sliding_window" in config:
|
||||
config.pop("sliding_window")
|
||||
if "n_kv_heads" not in config:
|
||||
config["n_kv_heads"] = n_heads
|
||||
if "head_dim" not in config:
|
||||
config["head_dim"] = config["dim"] // n_heads
|
||||
if "hidden_dim" not in config:
|
||||
config["hidden_dim"] = state["layers.0.feed_forward.w1.weight"].shape
|
||||
with open(mlx_path / "params.json", "w") as outfile:
|
||||
json.dump(config, outfile)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user