From 52c41b5b5abfdd4ee1c35bd362162b1dc7a62138 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 6 Feb 2025 11:10:58 -0800
Subject: [PATCH] Fix prompt cache for models without chat template (#1250)

* fix deepseek sharding (#1242)

* fix prompt cache with no chat template
---
 llms/mlx_lm/cache_prompt.py       |  2 +-
 llms/mlx_lm/generate.py           |  2 +-
 llms/mlx_lm/models/deepseek_v2.py | 10 +++++-----
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py
index c18f1bae..fff64f78 100644
--- a/llms/mlx_lm/cache_prompt.py
+++ b/llms/mlx_lm/cache_prompt.py
@@ -152,7 +152,7 @@ def main():
     print("Saving...")
     metadata = {}
     metadata["model"] = args.model
-    metadata["chat_template"] = tokenizer.chat_template
+    metadata["chat_template"] = json.dumps(tokenizer.chat_template)
     metadata["tokenizer_config"] = json.dumps(tokenizer_config)
     save_prompt_cache(args.prompt_cache_file, cache, metadata)
 
diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 0d286c75..e7994750 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -199,7 +199,7 @@ def main():
     if tokenizer.chat_template is None:
         tokenizer.chat_template = tokenizer.default_chat_template
     elif using_cache:
-        tokenizer.chat_template = metadata["chat_template"]
+        tokenizer.chat_template = json.loads(metadata["chat_template"])
 
     prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t")
     prompt = sys.stdin.read() if prompt == "-" else prompt
diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py
index 3581fcbe..f22b2e3f 100644
--- a/llms/mlx_lm/models/deepseek_v2.py
+++ b/llms/mlx_lm/models/deepseek_v2.py
@@ -282,12 +282,12 @@ class MoEGate(nn.Module):
         if self.topk_method == "group_limited_greedy":
             bsz, seq_len = x.shape[:2]
             scores = scores.reshape(bsz, seq_len, self.n_group, -1)
-            group_scores = scores.max(axis=-1)
+            group_scores = scores.max(axis=-1, keepdims=True)
             k = self.n_group - self.topk_group
-            group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
-            batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
-            seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
-            scores[batch_idx, seq_idx, group_idx] = 0.0
+            group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]
+            scores = mx.put_along_axis(
+                scores, group_idx, mx.array(0.0, scores.dtype), axis=-2
+            )
             scores = scores.reshape(bsz, seq_len, -1)
 
         k = self.top_k
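
Note on the cache fix: the json.dumps/json.loads round-trip presumably handles models whose tokenizer has no chat template, where tokenizer.chat_template is None; the prompt-cache metadata only stores strings, and "null" survives that round-trip where a bare None does not.

The sketch below is a standalone illustration (not part of the patch) of the group-limited greedy gating path as rewritten in deepseek_v2.py, using mx.put_along_axis in place of in-place fancy indexing. The sizes bsz, seq_len, n_group, experts_per_group, and topk_group are illustrative assumptions, not values from any model config.

import mlx.core as mx

# Illustrative sizes (assumed for this sketch only).
bsz, seq_len, n_group, experts_per_group = 2, 4, 8, 16
topk_group = 3

# Fake per-expert router scores, reshaped so experts are grouped.
scores = mx.random.uniform(shape=(bsz, seq_len, n_group * experts_per_group))
scores = scores.reshape(bsz, seq_len, n_group, -1)

# Score each group by its best expert; keepdims keeps a trailing axis so the
# indices line up with `scores` along the group axis (-2).
group_scores = scores.max(axis=-1, keepdims=True)

# Indices of the (n_group - topk_group) weakest groups along the group axis.
k = n_group - topk_group
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]

# Zero every expert in the non-selected groups; only the topk_group strongest
# groups keep their scores before the per-expert top-k that follows.
scores = mx.put_along_axis(scores, group_idx, mx.array(0.0, scores.dtype), axis=-2)
scores = scores.reshape(bsz, seq_len, -1)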