Merge pull request #10 from ricardo-larosa/fix-unsupported-scalartype

Fix unsupported ScalarType BFloat16
commit 62dcb3301f
Awni Hannun, 2023-12-07 08:05:44 -08:00 (committed by GitHub)
2 changed files with 7 additions and 2 deletions


@@ -32,7 +32,7 @@ Once you've converted the weights to MLX format, you can interact with the
 LLaMA model:
 
 ```
-python llama.py mlx_llama.npz tokenizer.model "hello"
+python llama.py mlx_llama_weights.npz <path_to_tokenizer.model> "hello"
 ```
 
 Run `python llama.py --help` for more details.


@@ -32,7 +32,12 @@ def map_torch_to_mlx(key, value):
     elif "rope" in key:
         return None, None
 
-    return key, value.numpy()
+    return (
+        key,
+        value.numpy()
+        if value.dtype != torch.bfloat16
+        else value.to(torch.float32).numpy(),
+    )
 
 
 if __name__ == "__main__":
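
The underlying issue: NumPy has no native bfloat16 dtype, so calling `.numpy()` on a bfloat16 tensor raises `TypeError: Got unsupported ScalarType BFloat16`. Below is a minimal sketch of the failure and the upcast workaround applied in this commit; the tensor `t` is illustrative, not taken from an actual checkpoint:

```
import torch

# Some LLaMA checkpoints store weights in bfloat16.
t = torch.randn(4, dtype=torch.bfloat16)

# Direct conversion fails because NumPy has no bfloat16 dtype.
try:
    t.numpy()
except TypeError as e:
    print(e)  # Got unsupported ScalarType BFloat16

# The commit's workaround: upcast to float32 before converting.
arr = t.to(torch.float32).numpy()
print(arr.dtype)  # float32
```

Upcasting to float32 doubles the memory used by those tensors during conversion, but it is lossless: every bfloat16 value is exactly representable in float32.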