Merge pull request #10 from ricardo-larosa/fix-unsupported-scalartype

Fix unsupported ScalarType BFloat16
This commit is contained in:
Awni Hannun 2023-12-07 08:05:44 -08:00 committed by GitHub
commit 62dcb3301f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 2 deletions

View File

@ -32,7 +32,7 @@ Once you've converted the weights to MLX format, you can interact with the
LLaMA model:
```
python llama.py mlx_llama.npz tokenizer.model "hello"
python llama.py mlx_llama_weights.npz <path_to_tokenizer.model> "hello"
```
Run `python llama.py --help` for more details.

View File

@ -32,7 +32,12 @@ def map_torch_to_mlx(key, value):
elif "rope" in key:
return None, None
return key, value.numpy()
return (
key,
value.numpy()
if value.dtype != torch.bfloat16
else value.to(torch.float32).numpy(),
)
if __name__ == "__main__":