From 6f0a69e682e575458b2966fae2fdba6a20aa5f8e Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Fri, 10 May 2024 09:51:41 -0700
Subject: [PATCH] fix lora for openelm (#773)

---
 llms/mlx_lm/tuner/trainer.py | 2 +-
 llms/mlx_lm/tuner/utils.py   | 4 +++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index e2b55db3..f5957782 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -64,7 +64,7 @@ class TrainingArgs:
 
 
 def default_loss(model, inputs, targets, lengths):
-    logits, _ = model(inputs)
+    logits = model(inputs)
     logits = logits.astype(mx.float32)
 
     length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]

diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index c522a8c3..0b529366 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -87,7 +87,9 @@ def linear_to_lora_layers(
         keys.add("mlp.shared_expert_gate")
     elif model.model_type == "olmo":
         keys = set(["att_proj"])
-    elif model.model_type in ["phi3", "openelm"]:
+    elif model.model_type == "openelm":
+        keys = set(["attn.qkv_proj"])
+    elif model.model_type == "phi3":
         keys = set(["self_attn.qkv_proj"])
     elif model.model_type == "phi-msft":
         keys = set(["mixer.Wqkv", "moe.gate"])
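
The patch above splits OpenELM out of the Phi-3 branch because the two models name their fused attention projection differently. A minimal, self-contained sketch of that idea follows; it is not the mlx_lm implementation, and the names `TARGET_KEYS`, `lora_targets`, and the example module paths are illustrative assumptions.

```python
# Sketch only: map a model_type to the module-name suffixes that should
# receive LoRA adapters, then select matching module paths. OpenELM exposes
# "attn.qkv_proj" while Phi-3 exposes "self_attn.qkv_proj", so a shared key
# would match nothing for OpenELM.

TARGET_KEYS = {
    "openelm": {"attn.qkv_proj"},
    "phi3": {"self_attn.qkv_proj"},
}


def lora_targets(model_type: str, module_paths: list[str]) -> list[str]:
    """Return the module paths whose suffix matches a LoRA target key."""
    keys = TARGET_KEYS[model_type]
    return [p for p in module_paths if any(p.endswith(k) for k in keys)]


# Hypothetical OpenELM paths: matching on the old "self_attn.qkv_proj" key
# would select no layers, so no adapters would be attached.
print(lora_targets("openelm", ["layers.0.attn.qkv_proj", "layers.0.ffn.proj_1"]))
# -> ['layers.0.attn.qkv_proj']
```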