Fix lora for qwen moe (#743)

* fix lora for qwen moe

* use max seq length in test as well
Awni Hannun 2024-05-02 21:55:09 -07:00 committed by GitHub
parent 5079af62db
commit 92430df0a0
2 changed files with 2 additions and 1 deletion


@@ -238,6 +238,7 @@ def run(args, training_callback: TrainingCallback = None):
             tokenizer=tokenizer,
             batch_size=args.batch_size,
             num_batches=args.test_batches,
+            max_seq_length=args.max_seq_length,
         )
 
         test_ppl = math.exp(test_loss)
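
The first hunk threads the existing args.max_seq_length through to the test-time evaluate() call, so the held-out perplexity is presumably computed over sequences capped at the same length used during training rather than at whatever default evaluate would otherwise apply. A rough illustration of the effect, assuming test sequences are truncated to that cap before the loss is computed (truncate_batch is a hypothetical helper, not part of mlx_lm):

    def truncate_batch(token_ids, max_seq_length):
        # Hypothetical helper: cap each tokenized sequence at max_seq_length
        # so the test loss sees the same context window as training did.
        return [seq[:max_seq_length] for seq in token_ids]

    # e.g. test batches would be built from something like
    # truncate_batch(tokenized_test_sequences, args.max_seq_length)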


@@ -141,7 +141,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         if self.training:
             inds = np.array(inds)
-            y = mx.zeros((B, ne, D), x.dtype)
+            y = mx.zeros((B * L, ne, D), x.dtype)
             for e, expert in enumerate(self.experts):
                 idx1, idx2 = map(mx.array, np.where(inds == e))
                 if idx1.size == 0:
                     continue
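
The second hunk fixes the dense training path LoRA uses in the Qwen2 MoE block: the router indices are computed per token over the flattened (B * L, D) hidden states, so the buffer that gathers each token's expert outputs needs B * L rows, one per token, not B. Below is a minimal sketch of that path under those assumptions; gate, experts, and ne are illustrative stand-ins for the module's real attributes and the routing details are simplified:

    import mlx.core as mx
    import numpy as np

    def moe_training_forward_sketch(x, gate, experts, ne):
        # x: (B, L, D) hidden states; ne: number of experts routed per token.
        B, L, D = x.shape
        x = x.reshape(-1, D)                      # (B * L, D): one row per token

        gates = mx.softmax(gate(x), axis=-1)      # (B * L, num_experts)
        inds = np.array(mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne])

        # The buffer needs one row per token; the pre-fix shape (B, ne, D)
        # only had one row per sequence, which breaks the per-token scatter below.
        y = mx.zeros((B * L, ne, D), x.dtype)
        for e, expert in enumerate(experts):
            idx1, idx2 = map(mx.array, np.where(inds == e))
            if idx1.size == 0:
                continue
            y[idx1, idx2] = expert(x[idx1])

        scores = mx.take_along_axis(gates, mx.array(inds), axis=-1)  # (B * L, ne)
        out = (y * scores[..., None]).sum(axis=1)                    # (B * L, D)
        return out.reshape(B, L, D)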