From b9607f9510889fe5c5c166061203d33259aa1098 Mon Sep 17 00:00:00 2001 From: Yifan Date: Mon, 25 Dec 2023 20:18:26 +0800 Subject: [PATCH] QWEN: Fix unsupported ScalarType BFloat16 Fix unsupported ScalarType BFloat16. Env: Mac M1 Ultra torch: torch-2.0.0, metal Apple clang version 15.0.0 (clang-1500.0.40.1) Target: arm64-apple-darwin23.1.0 Thread model: posix InstalledDir: /Library/Developer/CommandLineTools/usr/bin ``` Traceback (most recent call last): File "/Volumes/v1/models/mlx-examples-main/llms/qwen/convert.py", line 110, in <module> convert(args) File "/Volumes/v1/models/mlx-examples-main/llms/qwen/convert.py", line 63, in convert weights = {replace_key(k): v.numpy() for k, v in state_dict.items()} File "/Volumes/v1/models/mlx-examples-main/llms/qwen/convert.py", line 63, in <dictcomp> weights = {replace_key(k): v.numpy() for k, v in state_dict.items()} TypeError: Got unsupported ScalarType BFloat16 ``` Fix: almost the same as [#10](https://github.com/ml-explore/mlx-examples/pull/10/commits/429ddb30dca199c9cfbfe2280cf47875cc3f0be9) --- llms/qwen/convert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/qwen/convert.py b/llms/qwen/convert.py index 88135208..e91be263 100644 --- a/llms/qwen/convert.py +++ b/llms/qwen/convert.py @@ -60,7 +60,7 @@ def convert(args): args.model, trust_remote_code=True, torch_dtype=torch.float16 ) state_dict = model.state_dict() - weights = {replace_key(k): v.numpy() for k, v in state_dict.items()} + weights = {replace_key(k): (v.numpy() if v.dtype != torch.bfloat16 else v.to(torch.float32).numpy()) for k, v in state_dict.items()} config = model.config.to_dict() if args.quantize: