mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 09:21:18 +08:00)
Fix lora for qwen moe (#743)
* fix lora for qwen moe
* use max seq length in test as well
This commit is contained in:
parent 5079af62db
commit 92430df0a0
@@ -238,6 +238,7 @@ def run(args, training_callback: TrainingCallback = None):
         tokenizer=tokenizer,
         batch_size=args.batch_size,
         num_batches=args.test_batches,
+        max_seq_length=args.max_seq_length,
     )
 
     test_ppl = math.exp(test_loss)
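This hunk covers the second commit-message item: the test evaluation previously omitted a sequence-length cap, so test batches could contain longer sequences than training ever saw. A minimal sketch of how the full call reads after the change; the surrounding evaluate call and the model and test_set arguments are assumptions inferred from the visible keywords, not shown in the diff:

    test_loss = evaluate(
        model=model,                         # assumed: the LoRA-adapted model
        dataset=test_set,                    # assumed: the held-out test split
        tokenizer=tokenizer,
        batch_size=args.batch_size,
        num_batches=args.test_batches,
        max_seq_length=args.max_seq_length,  # newly threaded through by this commit
    )

    test_ppl = math.exp(test_loss)           # perplexity is exp of the mean test loss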
@@ -141,7 +141,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
 
         if self.training:
             inds = np.array(inds)
-            y = mx.zeros((B, ne, D), x.dtype)
+            y = mx.zeros((B * L, ne, D), x.dtype)
             for e, expert in enumerate(self.experts):
                 idx1, idx2 = map(mx.array, np.where(inds == e))
                 if idx1.size == 0:
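This hunk is the actual Qwen MoE fix (the first commit-message item). In the training path the tokens are flattened from shape (B, L, D) to (B * L, D) before expert routing, so the accumulator y needs one row per token; allocating it as (B, ne, D) provided only one row per batch element. A runnable shape sketch of that logic, using made-up sizes and an identity stand-in for the expert call:

    import mlx.core as mx
    import numpy as np

    B, L, D, ne, num_experts = 2, 4, 8, 2, 4              # illustrative sizes only

    x = mx.random.normal((B, L, D)).reshape(-1, D)        # (B * L, D): flattened tokens
    inds = np.random.randint(0, num_experts, (B * L, ne)) # top-ne expert ids per token

    y = mx.zeros((B * L, ne, D), x.dtype)                 # the fix: B * L rows, not B
    for e in range(num_experts):
        idx1, idx2 = map(mx.array, np.where(inds == e))
        if idx1.size == 0:
            continue                                      # no tokens routed to expert e
        y[idx1, idx2] = x[idx1]                           # stand-in for expert(x[idx1])

    print(y.shape)                                        # (B * L, ne, D), here (8, 2, 8)

With only B rows, the scatter above has nowhere to put the outputs for tokens beyond the first B flattened positions; downstream, the per-token expert outputs are weighted by the router scores and summed over the ne axis.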