From 77301a2973cf9ec2bb8bc100c4ca04c810820ef9 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 9 Oct 2024 16:55:52 -0700 Subject: [PATCH] nits --- llms/mlx_lm/models/deepseek_v2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py index 17d061a8..bb3e5184 100644 --- a/llms/mlx_lm/models/deepseek_v2.py +++ b/llms/mlx_lm/models/deepseek_v2.py @@ -220,17 +220,17 @@ class DeepseekV2Attention(nn.Module): k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1) - k_pe = mx.concatenate([k_pe] * self.num_heads, axis=1) - if cache is not None: q_pe = self.rope(q_pe, cache.offset) k_pe = self.rope(k_pe, cache.offset) + k_pe = mx.repeat(k_pe, self.num_heads, axis=1) keys, values = cache.update_and_fetch( mx.concatenate([k_nope, k_pe], axis=-1), values ) else: q_pe = self.rope(q_pe) k_pe = self.rope(k_pe) + k_pe = mx.repeat(k_pe, self.num_heads, axis=1) keys = mx.concatenate([k_nope, k_pe], axis=-1) queries = mx.concatenate([q_nope, q_pe], axis=-1) @@ -291,7 +291,7 @@ class MoEGate(nn.Module): scores = scores.reshape(bsz, seq_len, -1) k = self.top_k - inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]) + inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k] scores = mx.take_along_axis(scores, inds, axis=-1) scores = scores * self.routed_scaling_factor