diff --git a/llama/README.md b/llama/README.md index 1314ca86..4ac4aa84 100644 --- a/llama/README.md +++ b/llama/README.md @@ -29,7 +29,7 @@ Once you've converted the weights to MLX format, you can interact with the LLaMA model: ``` -python llama.py mlx_llama.npz tokenizer.model "hello" +python llama.py mlx_llama_weights.npz "hello" ``` Run `python llama.py --help` for more details. diff --git a/llama/convert.py b/llama/convert.py index 153e056f..9e30a8d4 100644 --- a/llama/convert.py +++ b/llama/convert.py @@ -12,7 +12,8 @@ def map_torch_to_mlx(key, value): key = "embedding.weight" elif "norm" in key: - key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2") + key = key.replace("attention_norm", "norm1").replace( + "ffn_norm", "norm2") elif "wq" in key or "wk" in key or "wv" in key or "wo" in key: key = key.replace("wq", "query_proj") @@ -32,11 +33,12 @@ def map_torch_to_mlx(key, value): elif "rope" in key: return None, None - return key, value.numpy() + return key, value.numpy() if value.dtype != torch.bfloat16 else value.to(torch.float32).numpy() if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Convert Llama weights to MLX") + parser = argparse.ArgumentParser( + description="Convert Llama weights to MLX") parser.add_argument("torch_weights") parser.add_argument("output_file") args = parser.parse_args()