From 6e6ba07b54eafc147e6f99e929b3aef2e59078ee Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 3 Feb 2025 16:59:50 -0800 Subject: [PATCH] fix deepseek sharding (#1242) --- llms/mlx_lm/models/deepseek_v2.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py index 3581fcbe..f22b2e3f 100644 --- a/llms/mlx_lm/models/deepseek_v2.py +++ b/llms/mlx_lm/models/deepseek_v2.py @@ -282,12 +282,12 @@ class MoEGate(nn.Module): if self.topk_method == "group_limited_greedy": bsz, seq_len = x.shape[:2] scores = scores.reshape(bsz, seq_len, self.n_group, -1) - group_scores = scores.max(axis=-1) + group_scores = scores.max(axis=-1, keepdims=True) k = self.n_group - self.topk_group - group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k] - batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2)) - seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2)) - scores[batch_idx, seq_idx, group_idx] = 0.0 + group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :] + scores = mx.put_along_axis( + scores, group_idx, mx.array(0.0, scores.dtype), axis=-2 + ) scores = scores.reshape(bsz, seq_len, -1) k = self.top_k