Fix LoRA for Qwen MoE (#743)

* fix LoRA for Qwen MoE

* use max seq length in test as well
This commit is contained in:
Awni Hannun
2024-05-02 21:55:09 -07:00
committed by GitHub
parent 5079af62db
commit 92430df0a0
2 changed files with 2 additions and 1 deletion

View File

@@ -141,7 +141,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
if self.training:
inds = np.array(inds)
y = mx.zeros((B, ne, D), x.dtype)
y = mx.zeros((B * L, ne, D), x.dtype)
for e, expert in enumerate(self.experts):
idx1, idx2 = map(mx.array, np.where(inds == e))
if idx1.size == 0: