From 71aff8c346a7afe18ab321b1aeefcf5272799d98 Mon Sep 17 00:00:00 2001
From: ricardo-larosa
Date: Wed, 6 Dec 2023 13:30:59 +0100
Subject: [PATCH] Fix unsupported ScalarType BFloat16

---
 llama/convert.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/llama/convert.py b/llama/convert.py
index 9e30a8d4..2ce247a3 100644
--- a/llama/convert.py
+++ b/llama/convert.py
@@ -12,8 +12,7 @@ def map_torch_to_mlx(key, value):
         key = "embedding.weight"
 
     elif "norm" in key:
-        key = key.replace("attention_norm", "norm1").replace(
-            "ffn_norm", "norm2")
+        key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")
 
     elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
         key = key.replace("wq", "query_proj")
@@ -33,12 +32,16 @@ def map_torch_to_mlx(key, value):
     elif "rope" in key:
         return None, None
 
-    return key, value.numpy() if value.dtype != torch.bfloat16 else value.to(torch.float32).numpy()
+    return (
+        key,
+        value.numpy()
+        if value.dtype != torch.bfloat16
+        else value.to(torch.float32).numpy(),
+    )
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Convert Llama weights to MLX")
+    parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
     parser.add_argument("torch_weights")
     parser.add_argument("output_file")
     args = parser.parse_args()
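
For reference, a minimal standalone sketch of the dtype handling this patch adds: NumPy has no bfloat16 dtype, so calling .numpy() on a bfloat16 tensor raises "Got unsupported ScalarType BFloat16", and upcasting to float32 first avoids the error. The helper name to_numpy_safe is illustrative only and is not part of convert.py.

import torch

def to_numpy_safe(value: torch.Tensor):
    # NumPy cannot represent bfloat16, so convert to float32 before .numpy().
    if value.dtype == torch.bfloat16:
        return value.to(torch.float32).numpy()
    return value.numpy()

if __name__ == "__main__":
    t = torch.ones(2, 2, dtype=torch.bfloat16)
    print(to_numpy_safe(t).dtype)  # float32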