diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 84dc63ca..afb1394e 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -1,7 +1,6 @@ # Copyright © 2023-2024 Apple Inc. import argparse -import codecs import json import sys @@ -189,8 +188,8 @@ def main(): elif using_cache: tokenizer.chat_template = metadata["chat_template"] - prompt = codecs.decode(args.prompt, "unicode_escape") - + prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t") + prompt = sys.stdin.read() if prompt == "-" else prompt if not args.ignore_chat_template and ( hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None @@ -199,12 +198,7 @@ def main(): messages = [{"role": "system", "content": args.system_prompt}] else: messages = [] - messages.append( - { - "role": "user", - "content": sys.stdin.read() if prompt == "-" else prompt, - } - ) + messages.append({"role": "user", "content": prompt}) prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py index c77f056a..c48a32cf 100644 --- a/llms/mlx_lm/sample_utils.py +++ b/llms/mlx_lm/sample_utils.py @@ -12,6 +12,7 @@ def make_sampler( top_p: float = 0.0, min_p: float = 0.0, min_tokens_to_keep: int = 1, + top_k: int = -1, ) -> Callable[mx.array, mx.array]: """ Make a sampler function for use with ``generate_step``. @@ -25,6 +26,8 @@ def make_sampler( probability) that a token probability must have to be considered. min_tokens_to_keep (int, optional): Minimum number of tokens that cannot be filtered by min_p sampling. + top_k (int, optional): The top k tokens ranked by probability to constrain + the sampling to. Returns: Callable[mx.array, mx.array]: @@ -36,6 +39,8 @@ def make_sampler( return lambda x: top_p_sampling(x, top_p, temp) elif min_p != 0.0: return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp) + elif top_k > 0: + return lambda x: top_k_sampling(x, top_k, temp) else: return lambda x: categorical_sampling(x, temp) @@ -79,6 +84,33 @@ def make_logits_processors( return logits_processors +@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) +def top_k_sampling( + logprobs: mx.array, + top_k: int, + temperature=1.0, +) -> mx.array: + """ + Sample from only the top K tokens ranked by probability. + + Args: + logprobs: A vector of log probabilities. + top_k (int): Top k tokens to sample from. + """ + vocab_size = logprobs.shape[-1] + if not isinstance(top_k, int) or not (0 < top_k < vocab_size): + raise ValueError( + f"`top_k` has to be an integer in the (0, {vocab_size}] interval," + f" but is {top_k}." + ) + logprobs = logprobs * (1 / temperature) + mask_idx = mx.argpartition(-logprobs, kth=top_k - 1, axis=-1)[..., top_k:] + masked_logprobs = mx.put_along_axis( + logprobs, mask_idx, mx.array(-float("inf"), logprobs.dtype), axis=-1 + ) + return mx.random.categorical(masked_logprobs, axis=-1) + + @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) def min_p_sampling( logprobs: mx.array, @@ -87,7 +119,7 @@ def min_p_sampling( temperature=1.0, ) -> mx.array: """ - Apply min-p sampling to the logits. + Apply min-p sampling to the logprobs. Min-p keeps all tokens that are above a minimum probability, scaled by the probability of the most likely token. As a result, the filter is more diff --git a/llms/tests/test_sample_utils.py b/llms/tests/test_sample_utils.py index ebc90ce8..c45fa443 100644 --- a/llms/tests/test_sample_utils.py +++ b/llms/tests/test_sample_utils.py @@ -1,7 +1,7 @@ import unittest import mlx.core as mx -from mlx_lm.sample_utils import min_p_sampling, top_p_sampling +from mlx_lm.sample_utils import min_p_sampling, top_k_sampling, top_p_sampling class TestSampleUtils(unittest.TestCase): @@ -42,6 +42,27 @@ class TestSampleUtils(unittest.TestCase): token = min_p_sampling(logits, 0.05) self.assertTrue(token in (0, 3)) + def test_top_k_sampling(self): + probs = mx.array([0.9, 0.0, 0.0, 0.1])[None] + logits = mx.log(probs) + + token = top_k_sampling(logits, 1).item() + self.assertEqual(token, 0) + + probs = mx.array([0.5, 0.0, 0.0, 0.5])[None] + tokens = set() + for _ in range(100): + token = top_k_sampling(logits, 2) + tokens.add(token.item()) + self.assertEqual(tokens, {0, 3}) + + # Batch mode works + probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]]) + logits = mx.log(probs) + + tokens = top_k_sampling(logits, 1) + self.assertEqual(tokens.tolist(), [0, 1]) + if __name__ == "__main__": unittest.main()