From 0be87b3c53795d7b53e44fcd3017e9e7b1dbc441 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 5 Nov 2024 17:01:21 -0800
Subject: [PATCH] refactor sampler/processor and a few improvements

---
 llms/README.md                  |   5 +-
 llms/mlx_lm/cache_prompt.py     |   4 +-
 llms/mlx_lm/generate.py         |  19 +++-
 llms/mlx_lm/sample_utils.py     |  26 ++++-
 llms/mlx_lm/server.py           | 183 ++++++++++----------------
 llms/mlx_lm/tuner/trainer.py    |   2 +-
 llms/mlx_lm/utils.py            |  74 ++++++++-----
 llms/tests/test_generate.py     |   2 +-
 llms/tests/test_prompt_cache.py |   2 +-
 9 files changed, 153 insertions(+), 164 deletions(-)

diff --git a/llms/README.md b/llms/README.md
index 0e7dc7fb..eeb3ed6a 100644
--- a/llms/README.md
+++ b/llms/README.md
@@ -101,7 +101,8 @@ To see a description of all the arguments you can do:
 #### Streaming
 
 For streaming generation, use the `stream_generate` function. This returns a
-generator object which streams the output text. For example,
+generator object which streams the output text, token, and log probabilities.
+For example,
 
 ```python
 from mlx_lm import load, stream_generate
@@ -116,7 +117,7 @@ prompt = tokenizer.apply_chat_template(
     messages, tokenize=False, add_generation_prompt=True
 )
 
-for t in stream_generate(model, tokenizer, prompt, max_tokens=512):
-    print(t, end="", flush=True)
+for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512):
+    print(text, end="", flush=True)
 print()
 ```
diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py
index 7bb06411..987b640d 100644
--- a/llms/mlx_lm/cache_prompt.py
+++ b/llms/mlx_lm/cache_prompt.py
@@ -152,6 +152,7 @@ def main():
 
         model(y[:step_size][None], cache=cache)
         mx.eval([c.state for c in cache])
+        mx.metal.clear_cache()
         processed += min(y.size, step_size)
         y = y[step_size:]
         current = time.time()
@@ -165,14 +166,13 @@ def main():
     )
     print()
-    print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
+    print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB")
 
     print("Saving...")
     metadata = {}
     metadata["model"] = args.model
     metadata["chat_template"] = tokenizer.chat_template
     metadata["tokenizer_config"] = json.dumps(tokenizer_config)
-    print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
     save_prompt_cache(args.prompt_cache_file, cache, metadata)
 
diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 1820dd36..51169def 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -13,6 +13,8 @@ DEFAULT_PROMPT = "hello"
 DEFAULT_MAX_TOKENS = 100
 DEFAULT_TEMP = 0.0
 DEFAULT_TOP_P = 1.0
+DEFAULT_MIN_P = 0.0
+DEFAULT_MIN_TOKENS_TO_KEEP = 1
 DEFAULT_SEED = 0
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
 DEFAULT_QUANTIZED_KV_START = 5000
@@ -52,6 +54,7 @@ def setup_arg_parser():
     )
     parser.add_argument(
         "--prompt",
+        "-p",
         default=DEFAULT_PROMPT,
         help="Message to be processed by the model ('-' reads from stdin)",
     )
@@ -68,6 +71,15 @@ def setup_arg_parser():
     parser.add_argument(
         "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
     )
+    parser.add_argument(
+        "--min-p", type=float, default=DEFAULT_MIN_P, help="Sampling min-p"
+    )
+    parser.add_argument(
+        "--min-tokens-to-keep",
+        type=int,
+        default=DEFAULT_MIN_TOKENS_TO_KEEP,
+        help="Minimum tokens to keep for min-p sampling.",
+    )
     parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
     parser.add_argument(
         "--ignore-chat-template",
@@ -238,8 +250,6 @@ def main():
         raise ValueError("Cannot use --colorize with --verbose=False")
     formatter = colorprint_by_t0 if args.colorize else None
 
-    sampler = make_sampler(
-        args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
     response = generate(
         model,
         tokenizer,
@@ -247,7 +257,10 @@ def main():
         prompt,
         args.max_tokens,
         verbose=args.verbose,
         formatter=formatter,
-        sampler=sampler,
+        temp=args.temp,
+        top_p=args.top_p,
+        min_p=args.min_p,
+        min_tokens_to_keep=args.min_tokens_to_keep,
         max_kv_size=args.max_kv_size,
         prompt_cache=prompt_cache if using_cache else None,
         kv_bits=args.kv_bits,
diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py
index f1a5c1bb..c27b52d8 100644
--- a/llms/mlx_lm/sample_utils.py
+++ b/llms/mlx_lm/sample_utils.py
@@ -1,6 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.
 
 from functools import partial
+from typing import Callable, Dict, Optional
 
 import mlx.core as mx
 
@@ -25,7 +26,7 @@ def make_sampler(
             be filtered by min_p sampling.
 
     Returns:
-        Callabel[mx.array, mx.array]:
+        Callable[mx.array, mx.array]:
             A sampler which takes log-probabilities and returns tokens.
     """
     if temp == 0:
@@ -38,7 +39,11 @@ def make_sampler(
         return lambda x: categorical_sampling(x, temp)
 
 
-def make_logits_processors():
+def make_logits_processors(
+    logit_bias: Optional[Dict[int, float]] = None,
+    repetition_penalty: Optional[float] = None,
+    repetition_context_size: Optional[int] = 20,
+):
     """
     Make logits processors for use with ``generate_step``.
 
@@ -48,8 +53,13 @@ def make_logits_processors():
         repetition_context_size (int, optional): The number of tokens to
             consider for repetition penalty. Default: ``20``.
         logit_bias (dictionary, optional): Additive logit bias.
-    """
 
+    Returns:
+        List[Callable[[mx.array, mx.array], mx.array]]:
+            A list of logits processors. Each processor in the list is a
+            callable which takes an array of tokens and an array of logits
+            and returns the updated logits.
+    """
     logits_processors = []
     if logit_bias:
         indices = mx.array(list(logit_bias.keys()))
@@ -61,6 +71,12 @@ def make_logits_processors():
 
         logits_processors.append(logit_bias_processor)
 
+    if repetition_penalty and repetition_penalty != 0.0:
+        logits_processors.append(
+            make_repetition_penalty(repetition_penalty, repetition_context_size)
+        )
+    return logits_processors
+
 
 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
 def min_p_sampling(
@@ -159,7 +175,7 @@ def categorical_sampling(logits, temp):
     return mx.random.categorical(logits * (1 / temp))
 
 
-def repetition_penalty(penalty: float, context_size: int = 20):
+def make_repetition_penalty(penalty: float, context_size: int = 20):
     """
     Make repetition penalty processor.
 
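The helpers in this file, `make_sampler` and `make_logits_processors`, are exactly what `generate_step` now builds internally. A minimal sketch of using them directly (not part of the patch; `make_sampler` is called positionally, mirroring the call inside `generate_step`, and the array shapes are illustrative):

```python
import mlx.core as mx

from mlx_lm.sample_utils import make_logits_processors, make_sampler

# Arguments are temp, top_p, min_p, min_tokens_to_keep; temp=0 would mean greedy argmax.
sampler = make_sampler(0.8, 1.0, 0.05, 1)

# One processor per option: an additive bias on token id 0 and a
# repetition penalty over the last 20 generated tokens.
processors = make_logits_processors(
    logit_bias={0: -10.0},
    repetition_penalty=1.1,
    repetition_context_size=20,
)

tokens = mx.array([1, 2, 3])        # previously generated token ids
logits = mx.random.normal((1, 32))  # stand-in for a model's output logits
for processor in processors:
    # Processors take (tokens, logits) and return the updated logits.
    logits = processor(tokens, logits)

logprobs = logits - mx.logsumexp(logits, keepdims=True)
print(sampler(logprobs))  # a sampled token id
```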
@@ -177,7 +193,7 @@ def repetition_penalty(penalty: float, context_size: int = 20):
     if penalty < 0 or not isinstance(penalty, float):
         raise ValueError(f"penalty must be a non-negative float, got {penalty}")
 
-    def repetition_penalty_processor(logits, tokens):
+    def repetition_penalty_processor(tokens, logits):
         if len(tokens) > 0:
             tokens = tokens[-context_size:]
             selected_logits = logits[:, tokens]
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index e0d0921c..9c949291 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -27,7 +27,7 @@ from huggingface_hub import scan_cache_dir
 
 from ._version import __version__
 from .models.cache import make_prompt_cache
-from .utils import generate_step, load
+from .utils import load, stream_generate
 
 
 def get_system_fingerprint():
@@ -290,10 +290,7 @@ class APIHandler(BaseHTTPRequestHandler):
 
         # Call endpoint specific method
         prompt = endpoints[self.path]()
-
-        # Call method based on response type
-        method = self.handle_stream if self.stream else self.handle_completion
-        method(prompt, stop_id_sequences)
+        self.handle_completion(prompt, stop_id_sequences)
 
     def validate_model_parameters(self):
         """
@@ -452,25 +449,28 @@ class APIHandler(BaseHTTPRequestHandler):
             stop_id_sequences (List[List[int]]): A list of stop words passed to
                 the stopping_criteria function
         """
-        detokenizer = self.tokenizer.detokenizer
-        detokenizer.reset()
         tokens = []
         finish_reason = "length"
         stop_sequence_suffix = None
-        logging.debug(f"Starting completion:")
+        if self.stream:
+            self.end_headers()
+            logging.debug(f"Starting stream:")
+        else:
+            logging.debug(f"Starting completion:")
         token_logprobs = []
         top_tokens = []
 
-        prompt = mx.array(self.get_prompt_cache(prompt))
+        prompt = self.get_prompt_cache(prompt)
 
+        text = ""
         tic = time.perf_counter()
-        for _, (token, logprobs) in zip(
-            range(self.max_tokens),
-            generate_step(
-                prompt=prompt,
+        for n, (segment, token, logprobs) in enumerate(
+            stream_generate(
                 model=self.model,
+                tokenizer=self.tokenizer,
+                prompt=prompt,
+                max_tokens=self.max_tokens,
                 temp=self.temperature,
-                top_p=self.top_p,
                 repetition_penalty=self.repetition_penalty,
                 repetition_context_size=self.repetition_context_size,
                 logit_bias=self.logit_bias,
                 prompt_cache=self.prompt_cache.cache,
             ),
         ):
-            prompt_time = time.perf_counter() - tic
-            tic = time.perf_counter()
-            detokenizer.add_token(token)
-            logging.debug(detokenizer.text)
+            if n == 0:
+                prompt_time = time.perf_counter() - tic
+                tic = time.perf_counter()
+            text += segment
+            logging.debug(text)
             tokens.append(token)
 
             if self.logprobs > 0:
@@ -503,128 +503,63 @@ class APIHandler(BaseHTTPRequestHandler):
                     stop_sequence_suffix = self.tokenizer.decode(
                         tokens[-stop_condition.trim_length :]
                     )
+                    text = text[: -len(stop_sequence_suffix)]
                 break
 
+            if self.stream:
+                # If the end of tokens overlaps with a stop sequence, generate new
+                # tokens until we know if the stop sequence is hit or not
+                if any(
+                    (
+                        sequence_overlap(tokens, sequence)
+                        for sequence in stop_id_sequences
+                    )
+                ):
+                    continue
+                elif segment:
+                    response = self.generate_response(segment, None)
+                    self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+                    self.wfile.flush()
+
         self.prompt_cache.tokens.extend(tokens)
 
-        detokenizer.finalize()
-        text = (
-            detokenizer.text
-            if stop_sequence_suffix is None
-            else detokenizer.text[: -len(stop_sequence_suffix)]
-        )
+        gen_time = time.perf_counter() - tic
         prompt_tps = len(prompt) / prompt_time
         gen_tps = len(tokens) / gen_time
         peak_mem = mx.metal.get_peak_memory() / 1e9
-        response = self.generate_response(
-            text,
-            finish_reason,
-            len(prompt),
-            len(tokens),
-            token_logprobs=token_logprobs,
-            top_tokens=top_tokens,
-            tokens=tokens,
-        )
-
         logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
         logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec")
         logging.debug(f"Peak memory: {peak_mem:.3f} GB")
-        response_json = json.dumps(response).encode()
-        indent = "\t"  # Backslashes can't be inside of f-strings
-        logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
-        # Send an additional Content-Length header when it is known
-        self.send_header("Content-Length", str(len(response_json)))
-        self.end_headers()
-
-        self.wfile.write(response_json)
-        self.wfile.flush()
-
-    def handle_stream(
-        self,
-        prompt: List[int],
-        stop_id_sequences: List[List[int]],
-    ):
-        """
-        Generate response to prompt and foward it to the client using a Server
-        Sent Events (SSE) stream.
-
-        Args:
-            prompt (mx.array): The tokenized prompt
-            stop_id_sequences (List[List[int]]): A list of stop words passed to
-                the stopping_criteria function
-        """
-        # No additional headers are needed, call end_headers
-        self.end_headers()
-
-        detokenizer = self.tokenizer.detokenizer
-        detokenizer.reset()
-        tokens = []
-
-        stop_sequence_suffix = None
-        logging.debug(f"Starting stream:")
-
-        prompt = mx.array(self.get_prompt_cache(prompt))
-
-        for _, (token, _) in zip(
-            range(self.max_tokens),
-            generate_step(
-                prompt=prompt,
-                model=self.model,
-                temp=self.temperature,
-                top_p=self.top_p,
-                repetition_penalty=self.repetition_penalty,
-                repetition_context_size=self.repetition_context_size,
-                prompt_cache=self.prompt_cache.cache,
-            ),
-        ):
-            detokenizer.add_token(token)
-            logging.debug(detokenizer.text)
-            tokens.append(token)
-
-            stop_condition = stopping_criteria(
-                tokens,
-                stop_id_sequences,
-                self.tokenizer.eos_token_id,
-            )
-            if stop_condition.stop_met:
-                if stop_condition.trim_length:
-                    stop_sequence_suffix = self.tokenizer.decode(
-                        tokens[-stop_condition.trim_length :]
-                    )
-                break
-
-            # If the end of tokens overlaps with a stop sequence, generate new
-            # tokens until we know if the stop sequence is hit or not
-            if any(
-                (sequence_overlap(tokens, sequence) for sequence in stop_id_sequences)
-            ):
-                continue
-
-            new_text = detokenizer.last_segment
-            if new_text:
-                response = self.generate_response(new_text, None)
-                self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
-                self.wfile.flush()
-
-        self.prompt_cache.tokens.extend(tokens)
-
-        # check is there any remaining text to send
-        detokenizer.finalize()
-        last_segment = detokenizer.last_segment
-        if last_segment:
-            if stop_sequence_suffix is not None:
-                last_segment = last_segment[: -len(stop_sequence_suffix)]
-            response = self.generate_response(last_segment, "length")
+        if self.stream:
+            response = self.generate_response(segment, finish_reason)
             self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
             self.wfile.flush()
+            if self.stream_options is not None and self.stream_options["include_usage"]:
+                response = self.completion_usage_response(len(prompt), len(tokens))
+                self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+                self.wfile.flush()
+            self.wfile.write("data: [DONE]\n\n".encode())
+            self.wfile.flush()
+        else:
+            response = self.generate_response(
+                text,
+                finish_reason,
+                len(prompt),
+                len(tokens),
+                token_logprobs=token_logprobs,
+                top_tokens=top_tokens,
+                tokens=tokens,
+            )
+            response_json = json.dumps(response).encode()
+            indent = "\t"  # Backslashes can't be inside of f-strings
+            logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
 
-        if self.stream_options is not None and self.stream_options["include_usage"]:
-            response = self.completion_usage_response(len(prompt), len(tokens))
-            self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
-
-        self.wfile.write("data: [DONE]\n\n".encode())
-        self.wfile.flush()
+            # Send an additional Content-Length header when it is known
+            self.send_header("Content-Length", str(len(response_json)))
+            self.end_headers()
+            self.wfile.write(response_json)
+            self.wfile.flush()
 
     def completion_usage_response(
         self,
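With streaming folded into `handle_completion`, the server writes Server-Sent-Events chunks and finishes with a `data: [DONE]` sentinel. A client sketch (not part of the patch; the port, endpoint path, and chunk fields below are assumptions based on the server's defaults and the OpenAI-style chunk schema):

```python
import json

import requests  # third-party HTTP client, used here for brevity

response = requests.post(
    "http://localhost:8080/v1/chat/completions",  # assumed default host/port
    json={
        "messages": [{"role": "user", "content": "hello"}],
        "max_tokens": 64,
        "stream": True,
        "stream_options": {"include_usage": True},
    },
    stream=True,
)
for line in response.iter_lines():
    if not line or not line.startswith(b"data: "):
        continue
    payload = line[len(b"data: "):]
    if payload == b"[DONE]":  # final sentinel written by handle_completion
        break
    chunk = json.loads(payload)
    choices = chunk.get("choices", [])
    if choices:  # the usage chunk, if requested, has no choices
        print(choices[0].get("delta", {}).get("content", ""), end="", flush=True)
print()
```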
diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index 38619d95..21b1af18 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -285,7 +285,7 @@ def train(
             it_sec = args.steps_per_report / (stop - start)
             tokens_sec = float(n_tokens) / (stop - start)
             trained_tokens += n_tokens
-            peak_mem = mx.metal.get_peak_memory() / 2**30
+            peak_mem = mx.metal.get_peak_memory() / 1e9
             if rank == 0:
                 print(
                     f"Iter {it}: Train loss {train_loss:.3f}, "
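The `2**30` to `1e9` changes here and in `cache_prompt.py`, `server.py`, and `utils.py` make the printed number match its "GB" label: `2**30` bytes is one gibibyte, while `1e9` bytes is one gigabyte. The same peak usage therefore prints about 7% higher than before:

```python
peak_bytes = 17_179_869_184   # example: a 16 GiB peak
print(peak_bytes / 2**30)     # 16.0          (GiB, old divisor)
print(peak_bytes / 1e9)       # 17.179869184  (GB, new divisor)
# 2**30 = 1_073_741_824, roughly 7.4% larger than 1e9.
```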
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 1f1da440..9f852ae3 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -20,7 +20,7 @@ from transformers import PreTrainedTokenizer
 
 # Local imports
 from .models import cache
-from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
+from .sample_utils import make_logits_processors, make_sampler
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import dequantize as dequantize_model
 from .tuner.utils import load_adapters
@@ -35,7 +35,7 @@ MODEL_REMAPPING = {
 MAX_FILE_SIZE_GB = 5
 
 # A stream on the default device just for generation
-generation_stream = mx.new_stream(mx.default_device())
+generation_stream = mx.default_stream(mx.default_device())
 
 
 class ModelNotFoundError(Exception):
@@ -155,10 +156,16 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
 def generate_step(
     prompt: mx.array,
     model: nn.Module,
+    temp: float = 0.0,
+    repetition_penalty: Optional[float] = None,
+    repetition_context_size: Optional[int] = 20,
+    top_p: float = 1.0,
+    min_p: float = 0.0,
+    min_tokens_to_keep: int = 1,
     prefill_step_size: int = 512,
     max_kv_size: Optional[int] = None,
     prompt_cache: Optional[Any] = None,
-    sampler: Optional[Callable[mx.array, mx.array]] = None,
+    logit_bias: Optional[Dict[int, float]] = None,
     logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
     kv_bits: Optional[int] = None,
     kv_group_size: int = 64,
@@ -170,14 +177,24 @@ def generate_step(
 
     Args:
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
+        temp (float): The temperature for sampling, if 0 the argmax is used.
+            Default: ``0``.
+        repetition_penalty (float, optional): The penalty factor for repeating
+            tokens.
+        repetition_context_size (int, optional): The number of tokens to
+            consider for repetition penalty. Default: ``20``.
+        top_p (float, optional): Nucleus sampling, higher means model considers
+            more less-likely words.
+        min_p (float, optional): The minimum value (scaled by the top token's
+            probability) that a token probability must have to be considered.
+        min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
+            be filtered by min_p sampling.
         prefill_step_size (int): Step size for processing the prompt.
         max_kv_size (int, optional): Maximum size of the key-value cache. Old
             entries (except the first 4 tokens) will be overwritten.
         prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
             provided, the cache will be updated in place.
-        sampler (Callable[mx.array, mx.array], optional). A function which
-            takes log probabilities and returns tokens. If ``None`` then the
-            argmax is used. Default: ``None``.
+        logit_bias (dictionary, optional): Additive logit bias.
         logits_processors (List[Callable[[mx.array, mx.array], mx.array]],
            optional): A list of functions that take tokens and logits and
            return the processed logits. Default: ``None``.
@@ -204,7 +221,11 @@ def generate_step(
     elif len(prompt_cache) != len(model.layers):
         raise ValueError("Wrong number of layers in the prompt cache.")
 
-    sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
+    sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
+    logits_processors = list(logits_processors or [])
+    logits_processors.extend(
+        make_logits_processors(logit_bias, repetition_penalty, repetition_context_size)
+    )
 
     def _step(y):
         with mx.stream(generation_stream):
@@ -222,7 +243,7 @@ def generate_step(
                 prompt_cache, quantized_kv_start, kv_group_size, kv_bits
             )
 
-            logprobs = logits - mx.logsumexp(logits)
+            logprobs = logits - mx.logsumexp(logits, keepdims=True)
             y = sampler(logprobs)
             return y, logprobs.squeeze(0)
 
@@ -249,15 +270,15 @@ def generate_step(
 def stream_generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
-    prompt: str,
+    prompt: Union[str, List[int]],
     max_tokens: int = 100,
     **kwargs,
-) -> Union[str, Generator[str, None, None]]:
+) -> Generator[Tuple[str, int, mx.array], None, None]:
     """
     A generator producing text based on the given prompt from the model.
 
     Args:
-        prompt (mx.array): The input prompt.
+        prompt (Union[str, List[int]]): The input prompt.
         model (nn.Module): The model to use for generation.
-        max_tokens (int): The ma
+        max_tokens (int): The maximum number of tokens. Default: ``100``.
         kwargs: The remaining options get passed to :func:`generate_step`.
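Callers unpack the new three-tuple of decoded text segment, token id, and per-step log-probability vector. A sketch mirroring the README update in this patch (the model name is only an example):

```python
from mlx_lm import load, stream_generate

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "hello"}],
    tokenize=False,
    add_generation_prompt=True,
)

for text, token, logprobs in stream_generate(model, tokenizer, prompt, max_tokens=64):
    # `text` is the newly decoded segment; `token` and `logprobs` are the
    # raw token id and log-probability vector for this step.
    print(text, end="", flush=True)
print()
```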
@@ -269,23 +290,26 @@ def stream_generate(
     if not isinstance(tokenizer, TokenizerWrapper):
         tokenizer = TokenizerWrapper(tokenizer)
 
-    prompt_tokens = mx.array(tokenizer.encode(prompt))
+    prompt_tokens = mx.array(
+        prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
+    )
     detokenizer = tokenizer.detokenizer
 
-    detokenizer.reset()
-    for n, (token, _) in zip(
-        range(max_tokens),
-        generate_step(prompt_tokens, model, **kwargs),
-    ):
-        if token == tokenizer.eos_token_id:
-            break
-        detokenizer.add_token(token)
+    with wired_limit(model, [generation_stream]):
+        detokenizer.reset()
+        for n, (token, logprobs) in zip(
+            range(max_tokens),
+            generate_step(prompt_tokens, model, **kwargs),
+        ):
+            if token == tokenizer.eos_token_id:
+                break
+            detokenizer.add_token(token)
 
-        # Yield the last segment if streaming
-        yield detokenizer.last_segment
+            # Yield the last segment if streaming
+            yield detokenizer.last_segment, token, logprobs
 
-    detokenizer.finalize()
-    yield detokenizer.last_segment
+        detokenizer.finalize()
+        yield detokenizer.last_segment, token, logprobs
 
 
 def generate(
@@ -322,7 +346,7 @@ def generate(
     prompt_tokens = mx.array(tokenizer.encode(prompt))
     detokenizer = tokenizer.detokenizer
 
-    with wired_limit(model):
+    with wired_limit(model, [generation_stream]):
         tic = time.perf_counter()
         detokenizer.reset()
         for n, (token, logprobs) in zip(
@@ -361,7 +385,7 @@ def generate(
             f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
         )
         print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
-        peak_mem = mx.metal.get_peak_memory() / 2**30
+        peak_mem = mx.metal.get_peak_memory() / 1e9
         print(f"Peak memory: {peak_mem:.3f} GB")
 
     return detokenizer.text
diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py
index 68f1670b..e0a372a9 100644
--- a/llms/tests/test_generate.py
+++ b/llms/tests/test_generate.py
@@ -46,7 +46,7 @@ class TestGenerate(unittest.TestCase):
             "hello",
             max_tokens=5,
             verbose=False,
-            logits_processor=[logits_processor],
+            logits_processors=[logits_processor],
         )
         self.assertEqual(len(all_toks), len(init_toks) + 5)
 
diff --git a/llms/tests/test_prompt_cache.py b/llms/tests/test_prompt_cache.py
index 1e57bd86..0867ab56 100644
--- a/llms/tests/test_prompt_cache.py
+++ b/llms/tests/test_prompt_cache.py
@@ -299,7 +299,7 @@ class TestPromptCache(unittest.TestCase):
         ):
             i += 1
             self.assertEqual(tok, toks[i])
-            self.assertTrue(mx.allclose(logits, all_logits[i], rtol=2e-2))
+            self.assertTrue(mx.allclose(logits, all_logits[i], rtol=2e-2))
 
 
 if __name__ == "__main__":
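For downstream code, the two renames that matter are `logits_processor` to `logits_processors` (exercised by the test above) and the processor calling convention of `(tokens, logits)`. A sketch of a custom processor under the new convention (not part of the patch; the processor and model name are illustrative):

```python
import mlx.core as mx

from mlx_lm import generate, load

def ban_token_zero(tokens: mx.array, logits: mx.array) -> mx.array:
    # Processors receive the generated tokens first, then the logits,
    # and must return the (possibly modified) logits.
    logits[:, 0] = -float("inf")  # never sample token id 0
    return logits

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")
text = generate(
    model,
    tokenizer,
    "hello",
    max_tokens=16,
    verbose=False,
    logits_processors=[ban_token_zero],
)
print(text)
```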