mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Fix unsupported ScalarType BFloat16
This commit is contained in:
parent
0bf5d0e3bc
commit
429ddb30dc
@ -29,7 +29,7 @@ Once you've converted the weights to MLX format, you can interact with the
|
||||
LLaMA model:
|
||||
|
||||
```
|
||||
python llama.py mlx_llama.npz tokenizer.model "hello"
|
||||
python llama.py mlx_llama_weights.npz <path_to_tokenizer.model> "hello"
|
||||
```
|
||||
|
||||
Run `python llama.py --help` for more details.
|
||||
|
@ -12,7 +12,8 @@ def map_torch_to_mlx(key, value):
|
||||
key = "embedding.weight"
|
||||
|
||||
elif "norm" in key:
|
||||
key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")
|
||||
key = key.replace("attention_norm", "norm1").replace(
|
||||
"ffn_norm", "norm2")
|
||||
|
||||
elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
|
||||
key = key.replace("wq", "query_proj")
|
||||
@ -32,11 +33,12 @@ def map_torch_to_mlx(key, value):
|
||||
elif "rope" in key:
|
||||
return None, None
|
||||
|
||||
return key, value.numpy()
|
||||
return key, value.numpy() if value.dtype != torch.bfloat16 else value.to(torch.float32).numpy()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert Llama weights to MLX")
|
||||
parser.add_argument("torch_weights")
|
||||
parser.add_argument("output_file")
|
||||
args = parser.parse_args()
|
||||
|
Loading…
Reference in New Issue
Block a user