Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-25 01:41:19 +08:00)
Fix prompt cache for models without chat template (#1250)
* fix deepseek sharding (#1242)
* fix prompt cache with no chat template
This commit is contained in:
parent 747c08e202
commit 52c41b5b5a
@@ -152,7 +152,7 @@ def main():
     print("Saving...")
     metadata = {}
     metadata["model"] = args.model
-    metadata["chat_template"] = tokenizer.chat_template
+    metadata["chat_template"] = json.dumps(tokenizer.chat_template)
     metadata["tokenizer_config"] = json.dumps(tokenizer_config)
     save_prompt_cache(args.prompt_cache_file, cache, metadata)
@@ -199,7 +199,7 @@ def main():
     if tokenizer.chat_template is None:
         tokenizer.chat_template = tokenizer.default_chat_template
     elif using_cache:
-        tokenizer.chat_template = metadata["chat_template"]
+        tokenizer.chat_template = json.loads(metadata["chat_template"])

     prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t")
     prompt = sys.stdin.read() if prompt == "-" else prompt
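The json.dumps / json.loads round trip matters because a model without a chat template stores None in tokenizer.chat_template, while the prompt-cache metadata holds strings (the existing tokenizer_config entry is already JSON-encoded for the same reason). A minimal standalone sketch of the round trip, not part of the commit, with a throwaway metadata dict for illustration:

import json

chat_template = None  # e.g. a base model that ships without a chat template

# Saving: JSON encodes None as the string "null", so it fits in the
# string-valued cache metadata.
metadata = {"chat_template": json.dumps(chat_template)}

# Loading: "null" decodes back to None, so the later
# `if tokenizer.chat_template is None` check still behaves correctly.
restored = json.loads(metadata["chat_template"])
assert restored is None

A real (string) chat template survives the same round trip unchanged.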
@@ -282,12 +282,12 @@ class MoEGate(nn.Module):
         if self.topk_method == "group_limited_greedy":
             bsz, seq_len = x.shape[:2]
             scores = scores.reshape(bsz, seq_len, self.n_group, -1)
-            group_scores = scores.max(axis=-1)
+            group_scores = scores.max(axis=-1, keepdims=True)
             k = self.n_group - self.topk_group
-            group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
-            batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
-            seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
-            scores[batch_idx, seq_idx, group_idx] = 0.0
+            group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]
+            scores = mx.put_along_axis(
+                scores, group_idx, mx.array(0.0, scores.dtype), axis=-2
+            )
             scores = scores.reshape(bsz, seq_len, -1)

         k = self.top_k
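This hunk replaces the scatter-style fancy-indexed assignment with mx.put_along_axis, per the "fix deepseek sharding" note in the commit message. A standalone sketch of the same masking as the + lines above, run on made-up toy shapes purely for illustration:

import mlx.core as mx

bsz, seq_len, n_group, experts_per_group = 1, 2, 4, 2
topk_group = 2  # keep the two best-scoring groups per token

scores = mx.random.uniform(shape=(bsz, seq_len, n_group * experts_per_group))
scores = scores.reshape(bsz, seq_len, n_group, -1)

# Score each group by its best expert; keepdims leaves the group axis at -2
# so the resulting indices line up with `scores` for put_along_axis.
group_scores = scores.max(axis=-1, keepdims=True)

# Indices of the (n_group - topk_group) lowest-scoring groups.
k = n_group - topk_group
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]

# Zero every expert in the dropped groups without writing through advanced
# indexing; the scalar 0.0 broadcasts over the expert axis.
scores = mx.put_along_axis(scores, group_idx, mx.array(0.0, scores.dtype), axis=-2)
scores = scores.reshape(bsz, seq_len, -1)
print(scores)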