mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Fix lora for qwen moe (#743)
* fix lora for qwen moe * use max seq length in test as well
This commit is contained in:
parent
5079af62db
commit
92430df0a0
@ -238,6 +238,7 @@ def run(args, training_callback: TrainingCallback = None):
|
|||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
num_batches=args.test_batches,
|
num_batches=args.test_batches,
|
||||||
|
max_seq_length=args.max_seq_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
test_ppl = math.exp(test_loss)
|
test_ppl = math.exp(test_loss)
|
||||||
|
@ -141,7 +141,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
|||||||
|
|
||||||
if self.training:
|
if self.training:
|
||||||
inds = np.array(inds)
|
inds = np.array(inds)
|
||||||
y = mx.zeros((B, ne, D), x.dtype)
|
y = mx.zeros((B * L, ne, D), x.dtype)
|
||||||
for e, expert in enumerate(self.experts):
|
for e, expert in enumerate(self.experts):
|
||||||
idx1, idx2 = map(mx.array, np.where(inds == e))
|
idx1, idx2 = map(mx.array, np.where(inds == e))
|
||||||
if idx1.size == 0:
|
if idx1.size == 0:
|
||||||
|
Loading…
Reference in New Issue
Block a user