@mx.compile
def group_expert_select(
    gates,
    e_score_correction_bias,
    top_k,
    n_group,
    topk_group,
    routed_scaling_factor,
    norm_topk_prob,
):
    """Grouped top-k expert selection (DeepSeek-V3 "noaux_tc" routing).

    Experts are partitioned into ``n_group`` groups; each group is scored by
    the sum of its two best (bias-corrected) expert scores, the weakest
    ``n_group - topk_group`` groups are masked out, and the final ``top_k``
    experts are picked from the survivors.

    Args:
        gates: Router logits with experts on the last axis
            (assumed shape ``(..., n_routed_experts)`` — per the caller's
            ``x @ weight.T``).
        e_score_correction_bias: Per-expert bias, shape ``(n_routed_experts,)``.
            Steers *selection only*; it must not leak into the returned
            routing weights (see DeepSeek-V3 reference implementation).
        top_k: Number of experts routed per token.
        n_group: Number of expert groups.
        topk_group: Number of groups kept per token.
        routed_scaling_factor: Final multiplier on the routing weights.
        norm_topk_prob: If True (and ``top_k > 1``), normalize the selected
            weights to sum to 1 before scaling.

    Returns:
        Tuple ``(inds, scores)``: selected expert indices ``(..., top_k)``
        and the corresponding float32 routing weights.
    """
    scores = mx.sigmoid(gates.astype(mx.float32))
    # Keep the unbiased scores: the correction bias is used only to choose
    # experts; the returned weights are gathered from the raw sigmoid scores.
    orig_scores = scores
    scores = scores + e_score_correction_bias

    # Group stage: score each group by its two best experts, then zero out
    # the n_group - topk_group weakest groups so their experts can't win.
    scores = mx.unflatten(scores, axis=-1, shape=(n_group, -1))
    group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1, keepdims=True)
    k = n_group - topk_group
    group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]
    scores = mx.put_along_axis(scores, group_idx, mx.array(0.0), axis=-2)
    scores = mx.flatten(scores, -2, -1)

    # Expert stage: top_k experts among the surviving groups
    # (argpartition on the negated scores = unordered top-k, which is
    # sufficient since the weights are gathered explicitly below).
    k = top_k
    inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
    # Gather weights from the *unbiased* scores.
    scores = mx.take_along_axis(orig_scores, inds, axis=-1)
    if top_k > 1 and norm_topk_prob:
        denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
        scores = scores / denominator
    scores = scores * routed_scaling_factor

    return inds, scores
- bsz, seq_len = x.shape[:2] - scores = scores + self.e_score_correction_bias - scores = scores.reshape(bsz, seq_len, self.n_group, -1) - group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1) - k = self.n_group - self.topk_group - group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k] - batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2)) - seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2)) - scores[batch_idx, seq_idx, group_idx] = 0.0 - scores = scores.reshape(bsz, seq_len, -1) - - k = self.top_k - inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k] - scores = mx.take_along_axis(scores, inds, axis=-1) - if self.top_k > 1 and self.norm_topk_prob: - denominator = scores.sum(axis=-1, keepdims=True) + 1e-20 - scores = scores / denominator - scores = scores * self.routed_scaling_factor - - return inds, scores + return group_expert_select( + x @ self.weight.T, + self.e_score_correction_bias, + self.top_k, + self.n_group, + self.topk_group, + self.routed_scaling_factor, + self.norm_topk_prob, + ) class DeepseekV3MoE(nn.Module):