mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
fix lora for openelm (#773)
This commit is contained in:
parent
fad9598372
commit
6f0a69e682
@ -64,7 +64,7 @@ class TrainingArgs:
|
|||||||
|
|
||||||
|
|
||||||
def default_loss(model, inputs, targets, lengths):
|
def default_loss(model, inputs, targets, lengths):
|
||||||
logits, _ = model(inputs)
|
logits = model(inputs)
|
||||||
logits = logits.astype(mx.float32)
|
logits = logits.astype(mx.float32)
|
||||||
|
|
||||||
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
|
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
|
||||||
|
@ -87,7 +87,9 @@ def linear_to_lora_layers(
|
|||||||
keys.add("mlp.shared_expert_gate")
|
keys.add("mlp.shared_expert_gate")
|
||||||
elif model.model_type == "olmo":
|
elif model.model_type == "olmo":
|
||||||
keys = set(["att_proj"])
|
keys = set(["att_proj"])
|
||||||
elif model.model_type in ["phi3", "openelm"]:
|
elif model.model_type == "openelm":
|
||||||
|
keys = set(["attn.qkv_proj"])
|
||||||
|
elif model.model_type == "phi3":
|
||||||
keys = set(["self_attn.qkv_proj"])
|
keys = set(["self_attn.qkv_proj"])
|
||||||
elif model.model_type == "phi-msft":
|
elif model.model_type == "phi-msft":
|
||||||
keys = set(["mixer.Wqkv", "moe.gate"])
|
keys = set(["mixer.Wqkv", "moe.gate"])
|
||||||
|
Loading…
Reference in New Issue
Block a user