mlx-examples/mistral/convert.py

28 lines
734 B
Python
Raw Normal View History

2023-12-06 03:02:52 +08:00
# Copyright © 2023 Apple Inc.
import argparse
import numpy as np
import torch
def convert_weights(torch_model, mlx_model):
    """Convert a PyTorch checkpoint into a NumPy ``.npz`` archive of float16 arrays.

    Args:
        torch_model: Path to the checkpoint file read with ``torch.load``.
            It is expected to hold a mapping of parameter name -> tensor.
        mlx_model: Path of the ``.npz`` archive to write with ``np.savez``.
    """
    # NOTE(security): torch.load unpickles its input; only run this on
    # checkpoints obtained from a trusted source.
    # map_location="cpu" makes the load work on machines without the device
    # the checkpoint was saved from, and guarantees that .numpy() below is
    # called on CPU tensors (it raises for tensors on any other device).
    state = torch.load(torch_model, map_location="cpu")

    # Every tensor is cast to float16 before being handed to NumPy; the
    # archive keeps the checkpoint's parameter names as its array names.
    np.savez(
        mlx_model,
        **{name: tensor.to(torch.float16).numpy() for name, tensor in state.items()},
    )


def main():
    """Parse the command-line arguments and run the weight conversion."""
    parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.")
    parser.add_argument(
        "--torch_model",
        type=str,
        default="mistral-7B-v0.1/consolidated.00.pth",
        help="The path to the torch model weights",
    )
    parser.add_argument(
        "--mlx_model",
        type=str,
        default="mistral-7B-v0.1/mlx_mistral_7b.npz",
        help="The path to store the mlx model weights",
    )
    args = parser.parse_args()
    convert_weights(args.torch_model, args.mlx_model)


if __name__ == "__main__":
    main()