Fix prompt cache for models without chat template (#1250)

* fix deepseek sharding (#1242)

* fix prompt cache with no chat template
This commit is contained in:
Awni Hannun
2025-02-06 11:10:58 -08:00
committed by GitHub
parent 747c08e202
commit 52c41b5b5a
3 changed files with 7 additions and 7 deletions

View File

@@ -282,12 +282,12 @@ class MoEGate(nn.Module):
if self.topk_method == "group_limited_greedy":
bsz, seq_len = x.shape[:2]
scores = scores.reshape(bsz, seq_len, self.n_group, -1)
group_scores = scores.max(axis=-1)
group_scores = scores.max(axis=-1, keepdims=True)
k = self.n_group - self.topk_group
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
scores[batch_idx, seq_idx, group_idx] = 0.0
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]
scores = mx.put_along_axis(
scores, group_idx, mx.array(0.0, scores.dtype), axis=-2
)
scores = scores.reshape(bsz, seq_len, -1)
k = self.top_k