diff --git a/llama/README.md b/llama/README.md
index 63aba7d9..3a3e9e89 100644
--- a/llama/README.md
+++ b/llama/README.md
@@ -32,7 +32,7 @@ Once you've converted the weights to MLX format, you can interact with the
 LLaMA model:
 
 ```
-python llama.py mlx_llama.npz tokenizer.model "hello"
+python llama.py mlx_llama_weights.npz "hello"
 ```
 
 Run `python llama.py --help` for more details.
diff --git a/llama/convert.py b/llama/convert.py
index 153e056f..2ce247a3 100644
--- a/llama/convert.py
+++ b/llama/convert.py
@@ -32,7 +32,12 @@ def map_torch_to_mlx(key, value):
     elif "rope" in key:
         return None, None
 
-    return key, value.numpy()
+    return (
+        key,
+        value.numpy()
+        if value.dtype != torch.bfloat16
+        else value.to(torch.float32).numpy(),
+    )
 
 
 if __name__ == "__main__":
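
For context, here is a minimal sketch (values are illustrative, assuming `torch` is installed) of why the bfloat16 branch in `convert.py` is needed: NumPy has no bfloat16 dtype, so calling `.numpy()` directly on a bfloat16 tensor raises a `TypeError`, while casting to float32 first converts cleanly.

```
import torch

# bfloat16 tensors cannot be handed to NumPy directly: NumPy has no bfloat16 dtype.
x = torch.ones(4, dtype=torch.bfloat16)

try:
    x.numpy()
except TypeError as err:
    print(f"direct .numpy() fails: {err}")

# Casting to float32 first, as the diff above does, succeeds.
print(x.to(torch.float32).numpy())
```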