diff --git a/llms/README.md b/llms/README.md index 4976c39e..60f68353 100644 --- a/llms/README.md +++ b/llms/README.md @@ -100,8 +100,9 @@ To see a description of all the arguments you can do: #### Streaming -For streaming generation, use the `stream_generate` function. This returns a -generator object which streams the output text, token, and log probabilities. +For streaming generation, use the `stream_generate` function. This yields +a generation response object. + For example, ```python diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py index 1c135ad1..7795d8d7 100644 --- a/llms/mlx_lm/chat.py +++ b/llms/mlx_lm/chat.py @@ -6,6 +6,7 @@ import json import mlx.core as mx from .models.cache import make_prompt_cache +from .sample_utils import make_sampler from .utils import load, stream_generate DEFAULT_TEMP = 0.0 @@ -79,8 +80,7 @@ def main(): tokenizer, prompt, args.max_tokens, - temp=args.temp, - top_p=args.top_p, + sampler=make_sampler(args.temp, args.top_p), prompt_cache=prompt_cache, ): print(response.text, flush=True, end="") diff --git a/llms/mlx_lm/examples/chat.py b/llms/mlx_lm/examples/chat.py index 3bf01688..c7512b3c 100644 --- a/llms/mlx_lm/examples/chat.py +++ b/llms/mlx_lm/examples/chat.py @@ -42,7 +42,6 @@ response = generate( tokenizer, prompt=prompt, verbose=True, - temp=0.0, prompt_cache=prompt_cache, ) diff --git a/llms/mlx_lm/examples/generate_response.py b/llms/mlx_lm/examples/generate_response.py index 25730617..e6535b47 100644 --- a/llms/mlx_lm/examples/generate_response.py +++ b/llms/mlx_lm/examples/generate_response.py @@ -23,14 +23,6 @@ max_tokens = 1_000 # Specify if tokens and timing information will be printed verbose = True -# Some optional arguments for causal language model generation -generation_args = { - "temp": 0.7, - "repetition_penalty": 1.2, - "repetition_context_size": 20, - "top_p": 0.95, -} - # Generate a response with the specified settings response = generate( model=model, @@ -38,5 +30,4 @@ response = generate( 
prompt=prompt, max_tokens=max_tokens, verbose=verbose, - **generation_args, ) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 09849632..9e96fbdc 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -7,6 +7,7 @@ import sys import mlx.core as mx from .models.cache import QuantizedKVCache, load_prompt_cache +from .sample_utils import make_sampler from .utils import generate, load DEFAULT_PROMPT = "hello" @@ -218,16 +219,14 @@ def main(): else: prompt = args.prompt + sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep) response = generate( model, tokenizer, prompt, max_tokens=args.max_tokens, verbose=args.verbose, - temp=args.temp, - top_p=args.top_p, - min_p=args.min_p, - min_tokens_to_keep=args.min_tokens_to_keep, + sampler=sampler, max_kv_size=args.max_kv_size, prompt_cache=prompt_cache if using_cache else None, kv_bits=args.kv_bits, diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index d96a3f72..badc6dd3 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -27,6 +27,7 @@ from huggingface_hub import scan_cache_dir from ._version import __version__ from .models.cache import make_prompt_cache +from .sample_utils import make_logits_processors, make_sampler from .utils import load, stream_generate @@ -464,15 +465,17 @@ class APIHandler(BaseHTTPRequestHandler): text = "" tic = time.perf_counter() + sampler = make_sampler(self.temperature) + logits_processors = make_logits_processors( + self.logit_bias, self.repetition_penalty, self.repetition_context_size + ) for gen_response in stream_generate( model=self.model, tokenizer=self.tokenizer, prompt=prompt, max_tokens=self.max_tokens, - temp=self.temperature, - repetition_penalty=self.repetition_penalty, - repetition_context_size=self.repetition_context_size, - logit_bias=self.logit_bias, + sampler=sampler, + logits_processors=logits_processors, prompt_cache=self.prompt_cache.cache, ): segment = gen_response.text diff --git 
a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py index 9d390733..0fa41ac0 100644 --- a/llms/mlx_lm/tokenizer_utils.py +++ b/llms/mlx_lm/tokenizer_utils.py @@ -73,16 +73,16 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer): def reset(self): self.offset = 0 - self._tokens = [] + self.tokens = [] self._text = "" self._current_tokens = [] self._current_text = "" def add_token(self, token): self._current_tokens.append(token) + self.tokens.append(token) def finalize(self): - self._tokens.extend(self._current_tokens) self._text += self._tokenizer.decode(self._current_tokens) self._current_tokens = [] self._current_text = "" @@ -97,16 +97,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer): ): self._current_text = self._current_text[:-1] if self._current_text and self._current_text[-1] == "\n": - self._tokens.extend(self._current_tokens) self._text += self._current_text self._current_tokens.clear() self._current_text = "" return self._text + self._current_text - @property - def tokens(self): - return self._tokens - class SPMStreamingDetokenizer(StreamingDetokenizer): """A streaming detokenizer for SPM models. 
@@ -143,6 +138,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer): self.text += text def add_token(self, token): + self.tokens.append(token) v = self.tokenmap[token] if v.startswith(self._sep): self._flush() @@ -200,6 +196,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer): return current_text def add_token(self, token): + self.tokens.append(token) v = self.tokenmap[token] is_added = token in self._added_ids if is_added or self._byte_decoder[v[0]] == 32: diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 15b0af2d..496ae4fc 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -182,20 +182,21 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_ def generate_step( prompt: mx.array, model: nn.Module, - temp: float = 0.0, - repetition_penalty: Optional[float] = None, - repetition_context_size: Optional[int] = 20, - top_p: float = 1.0, - min_p: float = 0.0, - min_tokens_to_keep: int = 1, - prefill_step_size: int = 512, + *, + sampler: Optional[Callable[[mx.array], mx.array]] = None, + logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, max_kv_size: Optional[int] = None, prompt_cache: Optional[Any] = None, - logit_bias: Optional[Dict[int, float]] = None, - logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, + prefill_step_size: int = 512, kv_bits: Optional[int] = None, kv_group_size: int = 64, quantized_kv_start: int = 0, + temp: Optional[float] = None, + repetition_penalty: Optional[float] = None, + repetition_context_size: Optional[int] = None, + top_p: Optional[float] = None, + min_p: Optional[float] = None, + min_tokens_to_keep: Optional[int] = None, ) -> Generator[Tuple[mx.array, mx.array], None, None]: """ A generator producing token ids based on the given prompt from the model. @@ -203,32 +204,21 @@ def generate_step( Args: prompt (mx.array): The input prompt. model (nn.Module): The model to use for generation. 
- temp (float): The temperature for sampling, if 0 the argmax is used. - Default: ``0``. - repetition_penalty (float, optional): The penalty factor for repeating - tokens. - repetition_context_size (int, optional): The number of tokens to - consider for repetition penalty. Default: ``20``. - top_p (float, optional): Nulceus sampling, higher means model considers - more less likely words. - min_p (float, optional): The minimum value (scaled by the top token's - probability) that a token probability must have to be considered. - min_tokens_to_keep (int, optional): Minimum number of tokens that cannot - be filtered by min_p sampling. prefill_step_size (int): Step size for processing the prompt. max_kv_size (int, optional): Maximum size of the key-value cache. Old entries (except the first 4 tokens) will be overwritten. prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if provided, the cache will be updated in place. - logit_bias (dictionary, optional): Additive logit bias. + sampler (Callable[[mx.array], mx.array], optional): A sampler for sampling a + token from a vector of log probabilities. Default: ``None``. logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional): - A list of functions that take tokens and logits and return the processed - logits. Default: ``None``. + A list of functions that take tokens and logits and return the processed + logits. Default: ``None``. kv_bits (int, optional): Number of bits to use for KV cache quantization. - None implies no cache quantization. Default: ``None``. + None implies no cache quantization. Default: ``None``. kv_group_size (int): Group size for KV cache quantization. Default: ``64``. quantized_kv_start (int): Step to begin using a quantized KV cache. - when ``kv_bits`` is non-None. Default: ``0``. + when ``kv_bits`` is non-None. Default: ``0``. Yields: Tuple[mx.array, mx.array]: One token and a vector of log probabilities. 
@@ -246,10 +236,22 @@ def generate_step( elif len(prompt_cache) != len(model.layers): raise ValueError("Wrong number of layers in the prompt cache.") - sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep) - logits_processors = logits_processors or [] - logits_processors.extend( - make_logits_processors(logit_bias, repetition_penalty, repetition_context_size) + if any(v is not None for v in (temp, top_p, min_p, min_tokens_to_keep)): + print( + "[Warning] Specifying sampling arguments to ``generate_step`` is " + "deprecated. Pass in a ``sampler`` instead." + ) + if repetition_penalty is not None: + print( + "[Warning] Specifying ``repetition_penalty`` is deprecated. " + "Pass in ``logits_processors`` instead." + ) + + sampler = sampler or make_sampler( + temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1 + ) + logits_processors = logits_processors or make_logits_processors( + None, repetition_penalty, repetition_context_size or 20 ) def _step(y): @@ -385,7 +387,10 @@ def generate( See :func:`stream_generate` for more details. """ if formatter is not None: - print("Text formatting is deprecated and will be removed in the next version.") + print( + "[Warning] Text formatting is deprecated and no longer used. " + "The argument will be removed in a future version." + ) if verbose: print("=" * 10) print("Prompt:", prompt) diff --git a/llms/tests/test_tokenizers.py b/llms/tests/test_tokenizers.py index 9c30d51e..db6b9f9e 100644 --- a/llms/tests/test_tokenizers.py +++ b/llms/tests/test_tokenizers.py @@ -34,10 +34,11 @@ class TestTokenizers(unittest.TestCase): detokenizer = tokenizer.detokenizer detokenizer.reset() text = "" - for t in tokens: + for e, t in enumerate(tokens): detokenizer.add_token(t) seg = detokenizer.last_segment text += seg + self.assertEqual(detokenizer.tokens, tokens[: e + 1]) detokenizer.finalize() text += detokenizer.last_segment self.assertEqual(text, expected_text)