From 4018aed335e6ce21e171fea718169455afe81a18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=98=89=E8=B1=AA?= Date: Fri, 8 Dec 2023 16:19:35 +0800 Subject: [PATCH] fix: Unsupported BFloat16 Data Type Issue with MPS Backend --- lora/convert.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/lora/convert.py b/lora/convert.py index 153e056f..2ce247a3 100644 --- a/lora/convert.py +++ b/lora/convert.py @@ -32,7 +32,12 @@ def map_torch_to_mlx(key, value): elif "rope" in key: return None, None - return key, value.numpy() + return ( + key, + value.numpy() + if value.dtype != torch.bfloat16 + else value.to(torch.float32).numpy(), + ) if __name__ == "__main__":