2023-12-08 18:14:11 +08:00
|
|
|
import argparse
|
2023-12-21 02:22:25 +08:00
|
|
|
|
2023-12-08 18:14:11 +08:00
|
|
|
import numpy
|
2024-03-20 08:21:33 +08:00
|
|
|
from transformers import AutoModel
|
2023-12-08 18:14:11 +08:00
|
|
|
|
|
|
|
|
|
|
|
def replace_key(key: str) -> str:
    """Rename a Hugging Face BERT state-dict key for the converted weights file.

    Each (old, new) substring pair below is applied in sequence with
    ``str.replace``; a key containing none of the patterns is returned
    unchanged.

    Args:
        key: Parameter name as it appears in the source model's state dict.

    Returns:
        The key with every matching substring rewritten.
    """
    # Order matters: the specific ".attention.output.*" patterns must be
    # rewritten before the generic ".output.*" / ".LayerNorm." ones, which
    # would otherwise match the same text first.
    substitutions = (
        (".layer.", ".layers."),
        (".self.key.", ".key_proj."),
        (".self.query.", ".query_proj."),
        (".self.value.", ".value_proj."),
        (".attention.output.dense.", ".attention.out_proj."),
        (".attention.output.LayerNorm.", ".ln1."),
        (".output.LayerNorm.", ".ln2."),
        (".intermediate.dense.", ".linear1."),
        (".output.dense.", ".linear2."),
        (".LayerNorm.", ".norm."),
        ("pooler.dense.", "pooler."),
    )
    renamed = key
    for old, new in substitutions:
        renamed = renamed.replace(old, new)
    return renamed
|
|
|
|
|
|
|
|
|
|
|
|
def convert(bert_model: str, mlx_model: str) -> None:
    """Load a pretrained model and save its weights as a NumPy ``.npz`` archive.

    The model is loaded with ``AutoModel.from_pretrained``, every entry of its
    state dict is renamed with ``replace_key`` and converted with
    ``tensor.numpy()``, and the resulting arrays are written with
    ``numpy.savez``.

    Args:
        bert_model: Model identifier passed to ``AutoModel.from_pretrained``.
        mlx_model: Output path for the weights archive. Missing parent
            directories are created.
    """
    from pathlib import Path

    model = AutoModel.from_pretrained(bert_model)

    # Rename each parameter and convert it to a NumPy array.
    tensors = {
        replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items()
    }

    # numpy.savez does not create directories; without this the default
    # "weights/..." output path fails when the folder is absent, after the
    # model has already been loaded.
    Path(mlx_model).parent.mkdir(parents=True, exist_ok=True)

    # save the tensors
    numpy.savez(mlx_model, **tensors)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: read the source model name and the output
    # path from the arguments, then run the conversion.
    cli = argparse.ArgumentParser(description="Convert BERT weights to MLX.")
    cli.add_argument(
        "--bert-model",
        default="bert-base-uncased",
        type=str,
        help="The huggingface name of the BERT model to save. Any BERT-like model can be specified.",
    )
    cli.add_argument(
        "--mlx-model",
        default="weights/bert-base-uncased.npz",
        type=str,
        help="The output path for the MLX BERT weights.",
    )
    options = cli.parse_args()

    convert(bert_model=options.bert_model, mlx_model=options.mlx_model)
|