mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Faster DSv2/3 expert score computation (#1257)
* fix deepseek sharding (#1242)
* compile and use put_along_axis in the DeepSeek routing function
This commit is contained in:
parent
52c41b5b5a
commit
6120a5f376
@ -271,6 +271,38 @@ class DeepseekV3MLP(nn.Module):
|
|||||||
return down_proj
|
return down_proj
|
||||||
|
|
||||||
|
|
||||||
|
@mx.compile
def group_expert_select(
    gates,
    e_score_correction_bias,
    top_k,
    n_group,
    topk_group,
    routed_scaling_factor,
    norm_topk_prob,
):
    """Pick the ``top_k`` routed experts per token with group-limited routing.

    The expert axis (last axis of ``gates``) is split into ``n_group`` equal
    groups. Each group is ranked by the sum of its two highest bias-corrected
    scores, the ``n_group - topk_group`` lowest-ranked groups are zeroed out,
    and the ``top_k`` experts are then chosen from what remains.

    The correction bias only influences *which* experts are chosen. The
    returned gating weights are taken from the un-biased sigmoid affinities,
    as in the DeepSeek-V3 reference gate.

    Args:
        gates: Router logits with the experts on the last axis. The caller
            passes ``x @ weight.T``.
        e_score_correction_bias: Per-expert bias added to the sigmoid scores
            for selection purposes only; must broadcast against ``gates``.
        top_k: Number of experts to select per token (Python int).
        n_group: Number of expert groups; must divide the expert count.
        topk_group: Number of groups kept per token.
        routed_scaling_factor: Multiplier applied to the returned weights.
        norm_topk_prob: If truthy (and ``top_k > 1``), normalise the selected
            weights so they sum to one before scaling.

    Returns:
        A tuple ``(inds, scores)``: the indices of the selected experts along
        the last axis, and their float32 gating weights. The order within the
        ``top_k`` slots is whatever ``mx.argpartition`` yields (unsorted).
    """
    # Raw affinities are kept separately: they become the gating weights.
    affinities = mx.sigmoid(gates.astype(mx.float32))
    # Bias-corrected copy, used exclusively for ranking groups and experts.
    selection_scores = affinities + e_score_correction_bias

    # The scalar arguments are plain Python values, so this branch is resolved
    # at trace time. Skipping it when nothing has to be dropped avoids calling
    # argpartition with kth=-1 and scattering through an empty index array.
    n_dropped = n_group - topk_group
    if n_dropped > 0:
        grouped = mx.unflatten(selection_scores, axis=-1, shape=(n_group, -1))
        # Rank each group by the sum of its two best experts.
        group_scores = mx.topk(grouped, 2, axis=-1).sum(axis=-1, keepdims=True)
        # Indices of the n_dropped lowest-ranked groups (along the group axis).
        dropped_groups = mx.argpartition(
            group_scores, kth=n_dropped - 1, axis=-2
        )[..., :n_dropped, :]
        grouped = mx.put_along_axis(
            grouped, dropped_groups, mx.array(0.0), axis=-2
        )
        selection_scores = mx.flatten(grouped, -2, -1)

    # Negate so that the first top_k partition slots hold the largest scores.
    inds = mx.argpartition(-selection_scores, kth=top_k - 1, axis=-1)[..., :top_k]
    scores = mx.take_along_axis(affinities, inds, axis=-1)
    if top_k > 1 and norm_topk_prob:
        # Small epsilon guards against a zero denominator.
        denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
        scores = scores / denominator
    scores = scores * routed_scaling_factor

    return inds, scores
|
||||||
|
|
||||||
|
|
||||||
class MoEGate(nn.Module):
|
class MoEGate(nn.Module):
|
||||||
def __init__(self, config: ModelArgs):
|
def __init__(self, config: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -279,38 +311,22 @@ class MoEGate(nn.Module):
|
|||||||
self.norm_topk_prob = config.norm_topk_prob
|
self.norm_topk_prob = config.norm_topk_prob
|
||||||
self.n_routed_experts = config.n_routed_experts
|
self.n_routed_experts = config.n_routed_experts
|
||||||
self.routed_scaling_factor = config.routed_scaling_factor
|
self.routed_scaling_factor = config.routed_scaling_factor
|
||||||
self.topk_method = config.topk_method
|
|
||||||
self.n_group = config.n_group
|
self.n_group = config.n_group
|
||||||
self.topk_group = config.topk_group
|
self.topk_group = config.topk_group
|
||||||
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
|
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
|
||||||
self.e_score_correction_bias = mx.zeros((self.n_routed_experts,))
|
self.e_score_correction_bias = mx.zeros((self.n_routed_experts,))
|
||||||
|
assert config.topk_method == "noaux_tc", "Unsupported topk method."
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
gates = x @ self.weight.T
|
return group_expert_select(
|
||||||
|
x @ self.weight.T,
|
||||||
scores = mx.sigmoid(gates.astype(mx.float32))
|
self.e_score_correction_bias,
|
||||||
|
self.top_k,
|
||||||
assert self.topk_method == "noaux_tc", "Unsupported topk method."
|
self.n_group,
|
||||||
bsz, seq_len = x.shape[:2]
|
self.topk_group,
|
||||||
scores = scores + self.e_score_correction_bias
|
self.routed_scaling_factor,
|
||||||
scores = scores.reshape(bsz, seq_len, self.n_group, -1)
|
self.norm_topk_prob,
|
||||||
group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1)
|
)
|
||||||
k = self.n_group - self.topk_group
|
|
||||||
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
|
|
||||||
batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
|
|
||||||
seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
|
|
||||||
scores[batch_idx, seq_idx, group_idx] = 0.0
|
|
||||||
scores = scores.reshape(bsz, seq_len, -1)
|
|
||||||
|
|
||||||
k = self.top_k
|
|
||||||
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
|
|
||||||
scores = mx.take_along_axis(scores, inds, axis=-1)
|
|
||||||
if self.top_k > 1 and self.norm_topk_prob:
|
|
||||||
denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
|
|
||||||
scores = scores / denominator
|
|
||||||
scores = scores * self.routed_scaling_factor
|
|
||||||
|
|
||||||
return inds, scores
|
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV3MoE(nn.Module):
|
class DeepseekV3MoE(nn.Module):
|
||||||
|
Loading…
Reference in New Issue
Block a user