From 738448c2d4d55585cbb178c8b54f55664a8cbe89 Mon Sep 17 00:00:00 2001
From: Yifan
Date: Mon, 25 Dec 2023 22:10:01 +0800
Subject: [PATCH] QWEN: Fix unsupported ScalarType BFloat16 (#187)

Fix unsupported ScalarType BFloat16.
---
 llms/qwen/convert.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llms/qwen/convert.py b/llms/qwen/convert.py
index 88135208..e91be263 100644
--- a/llms/qwen/convert.py
+++ b/llms/qwen/convert.py
@@ -60,7 +60,7 @@ def convert(args):
         args.model, trust_remote_code=True, torch_dtype=torch.float16
     )
     state_dict = model.state_dict()
-    weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
+    weights = {replace_key(k): (v.numpy() if v.dtype != torch.bfloat16 else v.to(torch.float32).numpy()) for k, v in state_dict.items()}
     config = model.config.to_dict()

     if args.quantize: