This commit is contained in:
Awni Hannun 2024-10-09 16:55:52 -07:00
parent c9e9c75c66
commit 77301a2973

View File

@ -220,17 +220,17 @@ class DeepseekV2Attention(nn.Module):
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
k_pe = mx.concatenate([k_pe] * self.num_heads, axis=1)
if cache is not None:
q_pe = self.rope(q_pe, cache.offset)
k_pe = self.rope(k_pe, cache.offset)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys, values = cache.update_and_fetch(
mx.concatenate([k_nope, k_pe], axis=-1), values
)
else:
q_pe = self.rope(q_pe)
k_pe = self.rope(k_pe)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys = mx.concatenate([k_nope, k_pe], axis=-1)
queries = mx.concatenate([q_nope, q_pe], axis=-1)
@ -291,7 +291,7 @@ class MoEGate(nn.Module):
scores = scores.reshape(bsz, seq_len, -1)
k = self.top_k
inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k])
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
scores = mx.take_along_axis(scores, inds, axis=-1)
scores = scores * self.routed_scaling_factor