From 92430df0a03355ef2bc07f4a442bd0381780b4b9 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 2 May 2024 21:55:09 -0700
Subject: [PATCH] Fix lora for qwen moe (#743)

* fix lora for qwen moe

* use max seq length in test as well
---
 llms/mlx_lm/lora.py             | 1 +
 llms/mlx_lm/models/qwen2_moe.py | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 333e447c..df382cfe 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -238,6 +238,7 @@ def run(args, training_callback: TrainingCallback = None):
             tokenizer=tokenizer,
             batch_size=args.batch_size,
             num_batches=args.test_batches,
+            max_seq_length=args.max_seq_length,
         )
         test_ppl = math.exp(test_loss)

diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py
index 536d2e1b..abe9452c 100644
--- a/llms/mlx_lm/models/qwen2_moe.py
+++ b/llms/mlx_lm/models/qwen2_moe.py
@@ -141,7 +141,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):

         if self.training:
             inds = np.array(inds)
-            y = mx.zeros((B, ne, D), x.dtype)
+            y = mx.zeros((B * L, ne, D), x.dtype)
             for e, expert in enumerate(self.experts):
                 idx1, idx2 = map(mx.array, np.where(inds == e))
                 if idx1.size == 0:
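
Note on the shape fix: the MoE block routes per token, so the input is flattened from `(B, L, D)` to one row per token before top-k expert selection. The per-token expert indices `inds` are therefore indexed over `B * L` rows, and the accumulation buffer must match, which is why `mx.zeros((B, ne, D), ...)` was wrong whenever `L > 1`. Below is a minimal NumPy sketch of that dispatch pattern; the shapes, the stand-in router, and the dummy experts are illustrative assumptions, not taken from the patch.

```python
import numpy as np

# Hypothetical sizes chosen for illustration only.
B, L, D = 2, 3, 4          # batch, sequence length, hidden size
num_experts, ne = 8, 2     # total experts, experts selected per token (top-k)

rng = np.random.default_rng(0)
x = rng.normal(size=(B, L, D)).astype(np.float32)

# Routing happens per token, so flatten to (B * L, D) first.
xt = x.reshape(-1, D)

# Stand-in router: top-k expert indices and (unnormalized) weights per token.
logits = rng.normal(size=(B * L, num_experts))
inds = np.argsort(-logits, axis=-1)[:, :ne]          # (B * L, ne)
scores = np.take_along_axis(logits, inds, axis=-1)   # (B * L, ne)

# One output slot per (token, selected expert): (B * L, ne, D).
# Allocating (B, ne, D) here, as the pre-fix code did, drops L - 1
# tokens per batch entry and misaligns the idx1 indexing below.
y = np.zeros((B * L, ne, D), dtype=xt.dtype)

# Dummy experts standing in for the real MLP experts.
experts = [lambda t, e=e: t + e for e in range(num_experts)]
for e, expert in enumerate(experts):
    idx1, idx2 = np.where(inds == e)   # idx1 ranges over 0 .. B*L - 1
    if idx1.size == 0:
        continue
    y[idx1, idx2] = expert(xt[idx1])

# Weighted mixture over the selected experts, reshaped back to (B, L, D).
out = (y * scores[..., None]).sum(axis=1).reshape(B, L, D)
print(out.shape)  # (2, 3, 4)
```

The `lora.py` change is simpler plumbing: the test-set evaluation now receives the same `max_seq_length` as training, so test perplexity is computed over sequences truncated consistently with the training run.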