fix lora for openelm (#773)

This commit is contained in:
Awni Hannun 2024-05-10 09:51:41 -07:00 committed by GitHub
parent fad9598372
commit 6f0a69e682
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 4 additions and 2 deletions

View File

@@ -64,7 +64,7 @@ class TrainingArgs:
 def default_loss(model, inputs, targets, lengths):
-    logits, _ = model(inputs)
+    logits = model(inputs)
     logits = logits.astype(mx.float32)
     length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]

View File

@@ -87,7 +87,9 @@ def linear_to_lora_layers(
         keys.add("mlp.shared_expert_gate")
     elif model.model_type == "olmo":
         keys = set(["att_proj"])
-    elif model.model_type in ["phi3", "openelm"]:
+    elif model.model_type == "openelm":
+        keys = set(["attn.qkv_proj"])
+    elif model.model_type == "phi3":
         keys = set(["self_attn.qkv_proj"])
     elif model.model_type == "phi-msft":
         keys = set(["mixer.Wqkv", "moe.gate"])