mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-27 03:05:20 +08:00
Merge pull request #40 from Jacksonzhang0316/main
fix: Unsupported BFloat16 Data Type Issue with MPS Backend
This commit is contained in:
commit
6259c9a048
@ -32,7 +32,12 @@ def map_torch_to_mlx(key, value):
|
|||||||
elif "rope" in key:
|
elif "rope" in key:
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
return key, value.numpy()
|
return (
|
||||||
|
key,
|
||||||
|
value.numpy()
|
||||||
|
if value.dtype != torch.bfloat16
|
||||||
|
else value.to(torch.float32).numpy(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user