Fix unsupported ScalarType BFloat16

ricardo-larosa 2023-12-06 13:30:59 +01:00
parent 0bf5d0e3bc
commit 429ddb30dc
2 changed files with 6 additions and 4 deletions
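
The error named in the title comes from NumPy, which has no bfloat16 dtype: calling `.numpy()` on a bfloat16 PyTorch tensor raises `TypeError: Got unsupported ScalarType BFloat16`. A minimal reproduction of the failure and of the upcast the commit applies (an illustrative sketch, not part of the diff):

```
import torch

x = torch.zeros(2, dtype=torch.bfloat16)
# x.numpy() raises "TypeError: Got unsupported ScalarType BFloat16"
# because NumPy has no bfloat16 dtype; upcasting first works:
print(x.to(torch.float32).numpy())  # [0. 0.], dtype=float32
```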


@@ -29,7 +29,7 @@ Once you've converted the weights to MLX format, you can interact with the
 LLaMA model:
 
 ```
-python llama.py mlx_llama.npz tokenizer.model "hello"
+python llama.py mlx_llama_weights.npz <path_to_tokenizer.model> "hello"
 ```
 
 Run `python llama.py --help` for more details.


@@ -12,7 +12,8 @@ def map_torch_to_mlx(key, value):
         key = "embedding.weight"
 
     elif "norm" in key:
-        key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")
+        key = key.replace("attention_norm", "norm1").replace(
+            "ffn_norm", "nor2".replace("2", "m2"))
@@ -32,11 +33,12 @@ def map_torch_to_mlx(key, value):
     elif "rope" in key:
         return None, None
 
-    return key, value.numpy()
+    return key, value.numpy() if value.dtype != torch.bfloat16 else value.to(torch.float32).numpy()
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
+    parser = argparse.ArgumentParser(
+        description="Convert Llama weights to MLX")
     parser.add_argument("torch_weights")
     parser.add_argument("output_file")
     args = parser.parse_args()
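
The fixed return statement inlines the dtype guard. The same logic as a self-contained helper, shown only as a sketch (the name `to_numpy` and the sample tensors are mine, not the script's):

```
import numpy as np
import torch

def to_numpy(value: torch.Tensor) -> np.ndarray:
    # Mirrors the fixed return expression: upcast bfloat16 to float32
    # before handing the tensor to NumPy; pass other dtypes through.
    if value.dtype == torch.bfloat16:
        value = value.to(torch.float32)
    return value.numpy()

print(to_numpy(torch.ones(3, dtype=torch.bfloat16)).dtype)  # float32
print(to_numpy(torch.ones(3, dtype=torch.float16)).dtype)   # float16
```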