Faster DSv2/3 expert score computation (#1257)

* fix deepseek sharding (#1242)

* compile and use put along axis in deep seek routing function
This commit is contained in:
Awni Hannun 2025-02-07 10:24:57 -08:00 committed by GitHub
parent 52c41b5b5a
commit 6120a5f376
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -271,6 +271,38 @@ class DeepseekV3MLP(nn.Module):
return down_proj
@mx.compile
def group_expert_select(
gates,
e_score_correction_bias,
top_k,
n_group,
topk_group,
routed_scaling_factor,
norm_topk_prob,
):
k = top_k
scores = mx.sigmoid(gates.astype(mx.float32))
scores = scores + e_score_correction_bias
scores = mx.unflatten(scores, axis=-1, shape=(n_group, -1))
group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1, keepdims=True)
k = n_group - topk_group
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]
scores = mx.put_along_axis(scores, group_idx, mx.array(0.0), axis=-2)
scores = mx.flatten(scores, -2, -1)
k = top_k
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
scores = mx.take_along_axis(scores, inds, axis=-1)
if top_k > 1 and norm_topk_prob:
denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
scores = scores / denominator
scores = scores * routed_scaling_factor
return inds, scores
class MoEGate(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
@ -279,38 +311,22 @@ class MoEGate(nn.Module):
self.norm_topk_prob = config.norm_topk_prob
self.n_routed_experts = config.n_routed_experts
self.routed_scaling_factor = config.routed_scaling_factor
self.topk_method = config.topk_method
self.n_group = config.n_group
self.topk_group = config.topk_group
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
self.e_score_correction_bias = mx.zeros((self.n_routed_experts,))
assert config.topk_method == "noaux_tc", "Unsupported topk method."
def __call__(self, x):
gates = x @ self.weight.T
scores = mx.sigmoid(gates.astype(mx.float32))
assert self.topk_method == "noaux_tc", "Unsupported topk method."
bsz, seq_len = x.shape[:2]
scores = scores + self.e_score_correction_bias
scores = scores.reshape(bsz, seq_len, self.n_group, -1)
group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1)
k = self.n_group - self.topk_group
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
scores[batch_idx, seq_idx, group_idx] = 0.0
scores = scores.reshape(bsz, seq_len, -1)
k = self.top_k
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
scores = mx.take_along_axis(scores, inds, axis=-1)
if self.top_k > 1 and self.norm_topk_prob:
denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
scores = scores / denominator
scores = scores * self.routed_scaling_factor
return inds, scores
return group_expert_select(
x @ self.weight.T,
self.e_score_correction_bias,
self.top_k,
self.n_group,
self.topk_group,
self.routed_scaling_factor,
self.norm_topk_prob,
)
class DeepseekV3MoE(nn.Module):