Awni Hannun
2024-06-13 07:47:16 -07:00
parent fda41545a6
commit 7c6ced183d
2 changed files with 186 additions and 1 deletion

@@ -122,8 +122,10 @@ def linear_to_lora_layers(
         keys = set(["norm_attn_norm.attn.Wqkv", "ffn.router.layer"])
     elif model.model_type == "internlm2":
         keys = set(["attention.wqkv", "attention.wo"])
+    elif model.model_type == "openlm":
+        keys = set(["attention.in_proj", "attention.out_proj"])
     else:
-        raise ValueError(f"Lora does not support {model.model_type}")
+        raise ValueError(f"LoRA does not support {model.model_type}")

     for l in model.layers[num_layers - num_lora_layers :]:
         lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
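
For readers skimming the diff: the `keys` set lists the dotted submodule names (as reported by `named_modules()`) that should be swapped for trainable LoRA wrappers, so adding support for a new architecture such as openlm is mostly a matter of naming its attention projections. Below is a minimal sketch of that key-based substitution pattern; `LoRALinear` and `linear_to_lora_layers_sketch` here are illustrative stand-ins rather than the repository's exact classes, and the sketch assumes an MLX-style module API (`named_modules`, `update_modules`, `tree_unflatten`).

```python
# Illustrative sketch only: a simplified LoRA wrapper plus the key-based
# substitution loop from the diff. `LoRALinear` is hypothetical; the real
# project defines its own version.
import math

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten


class LoRALinear(nn.Module):
    """Adds a low-rank update to a linear layer: y = linear(x) + scale * x A^T B^T."""

    def __init__(self, linear: nn.Linear, rank: int = 8, scale: float = 20.0):
        super().__init__()
        out_dims, in_dims = linear.weight.shape
        self.linear = linear  # base layer; callers typically freeze it first
        self.scale = scale
        # A gets a small random init; B starts at zero, so the wrapped layer
        # initially computes exactly what the original layer did.
        bound = 1.0 / math.sqrt(in_dims)
        self.lora_a = mx.random.uniform(low=-bound, high=bound, shape=(rank, in_dims))
        self.lora_b = mx.zeros((out_dims, rank))

    def __call__(self, x):
        return self.linear(x) + self.scale * (x @ self.lora_a.T) @ self.lora_b.T


def linear_to_lora_layers_sketch(model, num_layers, num_lora_layers, keys):
    # Mirrors the loop in the diff: only modules whose dotted name appears in
    # `keys` (e.g. "attention.in_proj" for openlm) are replaced, and only in
    # the last `num_lora_layers` transformer blocks.
    to_lora = lambda lin: LoRALinear(lin)
    for l in model.layers[num_layers - num_lora_layers :]:
        lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
        l.update_modules(tree_unflatten(lora_layers))
```

Because the base weights pass through untouched and only the small `lora_a`/`lora_b` factors need gradients, a short per-model key list like the one added for openlm above is all the function needs to enable fine-tuning for a new architecture.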