mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-28 00:30:09 +08:00
fix deepseek sharding (#1242)
This commit is contained in:
parent
e2e5478da5
commit
6e6ba07b54
@ -282,12 +282,12 @@ class MoEGate(nn.Module):
|
||||
if self.topk_method == "group_limited_greedy":
|
||||
bsz, seq_len = x.shape[:2]
|
||||
scores = scores.reshape(bsz, seq_len, self.n_group, -1)
|
||||
group_scores = scores.max(axis=-1)
|
||||
group_scores = scores.max(axis=-1, keepdims=True)
|
||||
k = self.n_group - self.topk_group
|
||||
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
|
||||
batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
|
||||
seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
|
||||
scores[batch_idx, seq_idx, group_idx] = 0.0
|
||||
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]
|
||||
scores = mx.put_along_axis(
|
||||
scores, group_idx, mx.array(0.0, scores.dtype), axis=-2
|
||||
)
|
||||
scores = scores.reshape(bsz, seq_len, -1)
|
||||
|
||||
k = self.top_k
|
||||
|
Loading…
Reference in New Issue
Block a user