mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00
Faster DSv2/3 expert score computation (#1257)
* Fix DeepSeek sharding (#1242)
* Compile and use put-along-axis in the DeepSeek routing function
This commit is contained in:
parent
52c41b5b5a
commit
6120a5f376
@ -271,6 +271,38 @@ class DeepseekV3MLP(nn.Module):
|
||||
return down_proj
|
||||
|
||||
|
||||
@mx.compile
def group_expert_select(
    gates,
    e_score_correction_bias,
    top_k,
    n_group,
    topk_group,
    routed_scaling_factor,
    norm_topk_prob,
):
    """Select the top-k routed experts for DeepSeek V2/V3 MoE gating.

    Implements the "noaux_tc" group-limited routing: experts are split
    into ``n_group`` groups, the weakest groups are masked out, and the
    final ``top_k`` experts are chosen from the surviving groups.

    Args:
        gates: Raw router logits with trailing dimension n_routed_experts.
        e_score_correction_bias: Per-expert bias added to the sigmoid
            scores before group ranking.
        top_k: Number of experts selected per token.
        n_group: Number of expert groups.
        topk_group: Number of groups kept per token.
        routed_scaling_factor: Multiplier applied to the final scores.
        norm_topk_prob: If True (and ``top_k > 1``), normalize the selected
            scores to sum to one before scaling.

    Returns:
        Tuple ``(inds, scores)``: indices of the selected experts and their
        routing weights, each with trailing dimension ``top_k``.
    """
    scores = mx.sigmoid(gates.astype(mx.float32))
    scores = scores + e_score_correction_bias
    # View scores as (..., n_group, experts_per_group).
    scores = mx.unflatten(scores, axis=-1, shape=(n_group, -1))
    # Rank each group by the sum of its two best expert scores.
    group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1, keepdims=True)
    # Zero the (n_group - topk_group) weakest groups so none of their
    # experts can be picked below. argpartition avoids a full sort.
    k = n_group - topk_group
    group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]
    scores = mx.put_along_axis(scores, group_idx, mx.array(0.0), axis=-2)
    scores = mx.flatten(scores, -2, -1)

    # Pick the top_k experts from the surviving groups; partitioning the
    # negated scores gives the k largest in the first k slots.
    k = top_k
    inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
    scores = mx.take_along_axis(scores, inds, axis=-1)
    if top_k > 1 and norm_topk_prob:
        # Epsilon guards against division by zero if all scores are masked.
        denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
        scores = scores / denominator
    scores = scores * routed_scaling_factor

    return inds, scores
|
||||
|
||||
|
||||
class MoEGate(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
@ -279,38 +311,22 @@ class MoEGate(nn.Module):
|
||||
self.norm_topk_prob = config.norm_topk_prob
|
||||
self.n_routed_experts = config.n_routed_experts
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.topk_method = config.topk_method
|
||||
self.n_group = config.n_group
|
||||
self.topk_group = config.topk_group
|
||||
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
|
||||
self.e_score_correction_bias = mx.zeros((self.n_routed_experts,))
|
||||
assert config.topk_method == "noaux_tc", "Unsupported topk method."
|
||||
|
||||
def __call__(self, x):
    """Compute expert assignments for the input tokens.

    Projects ``x`` onto the router weight to obtain per-expert logits,
    then delegates the group-limited top-k selection to the compiled
    ``group_expert_select`` helper.

    Args:
        x: Input activations of shape (batch, seq_len, hidden_size).

    Returns:
        Tuple ``(inds, scores)``: selected expert indices and their
        routing weights, each with trailing dimension ``top_k``.
    """
    # The superseded inline implementation (manual reshape/argpartition
    # routing) is removed here: its early `return inds, scores` made the
    # delegation below unreachable, and the same logic now lives in
    # group_expert_select. The `topk_method` check is asserted once in
    # __init__ rather than on every forward call.
    return group_expert_select(
        x @ self.weight.T,
        self.e_score_correction_bias,
        self.top_k,
        self.n_group,
        self.topk_group,
        self.routed_scaling_factor,
        self.norm_topk_prob,
    )
|
||||
|
||||
|
||||
class DeepseekV3MoE(nn.Module):
|
||||
|
Loading…
Reference in New Issue
Block a user