Merge pull request #40 from Jacksonzhang0316/main

fix: Unsupported BFloat16 Data Type Issue with MPS Backend
This commit is contained in:
Awni Hannun 2023-12-08 06:48:21 -08:00 committed by GitHub
commit 6259c9a048
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -32,7 +32,12 @@ def map_torch_to_mlx(key, value):
elif "rope" in key:
return None, None
return key, value.numpy()
return (
key,
value.numpy()
if value.dtype != torch.bfloat16
else value.to(torch.float32).numpy(),
)
if __name__ == "__main__":