mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-09 18:36:38 +08:00
Fix unsupported ScalarType BFloat16
This commit is contained in:
parent
429ddb30dc
commit
71aff8c346
@ -12,8 +12,7 @@ def map_torch_to_mlx(key, value):
|
|||||||
key = "embedding.weight"
|
key = "embedding.weight"
|
||||||
|
|
||||||
elif "norm" in key:
|
elif "norm" in key:
|
||||||
key = key.replace("attention_norm", "norm1").replace(
|
key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")
|
||||||
"ffn_norm", "norm2")
|
|
||||||
|
|
||||||
elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
|
elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
|
||||||
key = key.replace("wq", "query_proj")
|
key = key.replace("wq", "query_proj")
|
||||||
@ -33,12 +32,16 @@ def map_torch_to_mlx(key, value):
|
|||||||
elif "rope" in key:
|
elif "rope" in key:
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
return key, value.numpy() if value.dtype != torch.bfloat16 else value.to(torch.float32).numpy()
|
return (
|
||||||
|
key,
|
||||||
|
value.numpy()
|
||||||
|
if value.dtype != torch.bfloat16
|
||||||
|
else value.to(torch.float32).numpy(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
|
||||||
description="Convert Llama weights to MLX")
|
|
||||||
parser.add_argument("torch_weights")
|
parser.add_argument("torch_weights")
|
||||||
parser.add_argument("output_file")
|
parser.add_argument("output_file")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
Loading…
Reference in New Issue
Block a user