diff --git a/llms/README.md b/llms/README.md index 0e7dc7fb..eeb3ed6a 100644 --- a/llms/README.md +++ b/llms/README.md @@ -101,7 +101,8 @@ To see a description of all the arguments you can do: #### Streaming For streaming generation, use the `stream_generate` function. This returns a -generator object which streams the output text. For example, +generator object which streams the output text, token, and log probabilities. +For example, ```python from mlx_lm import load, stream_generate @@ -116,7 +117,7 @@ prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) -for t in stream_generate(model, tokenizer, prompt, max_tokens=512): +for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512): print(t, end="", flush=True) print() ``` diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py index 7bb06411..987b640d 100644 --- a/llms/mlx_lm/cache_prompt.py +++ b/llms/mlx_lm/cache_prompt.py @@ -152,6 +152,7 @@ def main(): model(y[:step_size][None], cache=cache) mx.eval([c.state for c in cache]) + mx.metal.clear_cache() processed += min(y.size, step_size) y = y[step_size:] current = time.time() @@ -165,14 +166,13 @@ def main(): ) print() - print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB") + print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB") print("Saving...") metadata = {} metadata["model"] = args.model metadata["chat_template"] = tokenizer.chat_template metadata["tokenizer_config"] = json.dumps(tokenizer_config) - print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB") save_prompt_cache(args.prompt_cache_file, cache, metadata) diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py index 85d32d5f..c03056a6 100644 --- a/llms/mlx_lm/chat.py +++ b/llms/mlx_lm/chat.py @@ -74,7 +74,7 @@ def main(): prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) - for response in stream_generate( + for response, *_ in stream_generate( model, tokenizer, prompt, diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 29976da2..51169def 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -13,6 +13,8 @@ DEFAULT_PROMPT = "hello" DEFAULT_MAX_TOKENS = 100 DEFAULT_TEMP = 0.0 DEFAULT_TOP_P = 1.0 +DEFAULT_MIN_P = 0.0 +DEFAULT_MIN_TOKENS_TO_KEEP = 1 DEFAULT_SEED = 0 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" DEFAULT_QUANTIZED_KV_START = 5000 @@ -52,6 +54,7 @@ def setup_arg_parser(): ) parser.add_argument( "--prompt", + "-p", default=DEFAULT_PROMPT, help="Message to be processed by the model ('-' reads from stdin)", ) @@ -68,6 +71,15 @@ def setup_arg_parser(): parser.add_argument( "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p" ) + parser.add_argument( + "--min-p", type=float, default=DEFAULT_MIN_P, help="Sampling min-p" + ) + parser.add_argument( + "--min-tokens-to-keep", + type=float, + default=DEFAULT_MIN_TOKENS_TO_KEEP, + help="Minimum tokens to keep for min-p sampling.", + ) parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed") parser.add_argument( "--ignore-chat-template", @@ -247,6 +259,8 @@ def main(): formatter=formatter, temp=args.temp, top_p=args.top_p, + min_p=args.min_p, + min_tokens_to_keep=args.min_tokens_to_keep, max_kv_size=args.max_kv_size, prompt_cache=prompt_cache if using_cache else None, kv_bits=args.kv_bits, diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py index 20b008fa..c27b52d8 100644 --- a/llms/mlx_lm/sample_utils.py +++ b/llms/mlx_lm/sample_utils.py @@ -1,10 +1,83 @@ # Copyright © 2023-2024 Apple Inc. from functools import partial +from typing import Callable, Dict, Optional import mlx.core as mx +def make_sampler( + temp: float = 0.0, + top_p: float = 0.0, + min_p: float = 0.0, + min_tokens_to_keep: int = 1, +) -> Callable[mx.array, mx.array]: + """ + Make a sampler function for use with ``generate_step``. + + Args: + temp (float): The temperature for sampling, if 0 the argmax is used. + Default: ``0``. + top_p (float, optional): Nulceus sampling, higher means model considers + more less likely words. + min_p (float, optional): The minimum value (scaled by the top token's + probability) that a token probability must have to be considered. + min_tokens_to_keep (int, optional): Minimum number of tokens that cannot + be filtered by min_p sampling. + + Returns: + Callable[mx.array, mx.array]: + A sampler which takes log-probabilities and returns tokens. + """ + if temp == 0: + return lambda x: mx.argmax(x, axis=-1) + elif top_p > 0 and top_p < 1.0: + return lambda x: top_p_sampling(x, top_p, temp) + elif min_p != 0.0: + return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp) + else: + return lambda x: categorical_sampling(x, temp) + + +def make_logits_processors( + logit_bias: Optional[Dict[int, float]] = None, + repetition_penalty: Optional[float] = None, + repetition_context_size: Optional[int] = 20, +): + """ + Make logits processors for use with ``generate_step``. + + Args: + repetition_penalty (float, optional): The penalty factor for repeating + tokens. + repetition_context_size (int, optional): The number of tokens to + consider for repetition penalty. Default: ``20``. + logit_bias (dictionary, optional): Additive logit bias. + + Returns: + List[Callable[[mx.array, mx.array], mx.array]]: + A list of logits processors. Each processor in the list is a + callable which takes an array of tokens and an array of logits + and returns the updated logits. + """ + logits_processors = [] + if logit_bias: + indices = mx.array(list(logit_bias.keys())) + values = mx.array(list(logit_bias.values())) + + def logit_bias_processor(_, logits): + logits[:, indices] += values + return logits + + logits_processors.append(logit_bias_processor) + + if repetition_penalty and repetition_penalty != 0.0: + logits_processors.append( + make_repetition_penalty(repetition_penalty, repetition_context_size) + ) + return logits_processors + + @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) def min_p_sampling( logits: mx.array, @@ -100,3 +173,36 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) def categorical_sampling(logits, temp): return mx.random.categorical(logits * (1 / temp)) + + +def make_repetition_penalty(penalty: float, context_size: int = 20): + """ + Make repetition penalty processor. + + Paper: https://arxiv.org/abs/1909.05858 + + Args: + penalty (float): The repetition penalty factor to be applied. + context_size (int): The number of previous tokens to use. + Default: ``20``. + + Returns: + Callable[[mx.array, List[int]], mx.array]: + The repetition penalty processor. + """ + if penalty < 0 or not isinstance(penalty, float): + raise ValueError(f"penalty must be a non-negative float, got {penalty}") + + def repetition_penalty_processor(tokens, logits): + if len(tokens) > 0: + tokens = tokens[-context_size:] + selected_logits = logits[:, tokens] + selected_logits = mx.where( + selected_logits < 0, + selected_logits * penalty, + selected_logits / penalty, + ) + logits[:, tokens] = selected_logits + return logits + + return repetition_penalty_processor diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index ec659969..c1365b36 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -27,7 +27,7 @@ from huggingface_hub import scan_cache_dir from ._version import __version__ from .models.cache import make_prompt_cache -from .utils import generate_step, load +from .utils import load, stream_generate def get_system_fingerprint(): @@ -64,7 +64,7 @@ def stopping_criteria( end if it has (`trim_length`). """ if tokens and tokens[-1] == eos_token_id: - return StopCondition(stop_met=True, trim_length=1) + return StopCondition(stop_met=True, trim_length=0) for stop_ids in stop_id_sequences: if len(tokens) >= len(stop_ids): @@ -253,7 +253,7 @@ class APIHandler(BaseHTTPRequestHandler): self.max_tokens = self.body.get("max_completion_tokens", None) if self.max_tokens is None: self.max_tokens = self.body.get("max_tokens", 512) - self.temperature = self.body.get("temperature", 1.0) + self.temperature = self.body.get("temperature", 0.0) self.top_p = self.body.get("top_p", 1.0) self.repetition_penalty = self.body.get("repetition_penalty", 1.0) self.repetition_context_size = self.body.get("repetition_context_size", 20) @@ -290,10 +290,7 @@ class APIHandler(BaseHTTPRequestHandler): # Call endpoint specific method prompt = endpoints[self.path]() - - # Call method based on response type - method = self.handle_stream if self.stream else self.handle_completion - method(prompt, stop_id_sequences) + self.handle_completion(prompt, stop_id_sequences) def validate_model_parameters(self): """ @@ -452,32 +449,40 @@ class APIHandler(BaseHTTPRequestHandler): stop_id_sequences (List[List[int]]): A list of stop words passed to the stopping_criteria function """ - detokenizer = self.tokenizer.detokenizer - detokenizer.reset() tokens = [] finish_reason = "length" stop_sequence_suffix = None - logging.debug(f"Starting completion:") + if self.stream: + self.end_headers() + logging.debug(f"Starting stream:") + else: + logging.debug(f"Starting completion:") token_logprobs = [] top_tokens = [] prompt = self.get_prompt_cache(prompt) - for _, (token, logprobs) in zip( - range(self.max_tokens), - generate_step( - prompt=mx.array(prompt), + text = "" + tic = time.perf_counter() + for n, (segment, token, logprobs) in enumerate( + stream_generate( model=self.model, + tokenizer=self.tokenizer, + prompt=prompt, + max_tokens=self.max_tokens, temp=self.temperature, - top_p=self.top_p, repetition_penalty=self.repetition_penalty, repetition_context_size=self.repetition_context_size, logit_bias=self.logit_bias, prompt_cache=self.prompt_cache.cache, ), ): - detokenizer.add_token(token) - logging.debug(detokenizer.text) + if n == 0: + prompt_time = time.perf_counter() - tic + tic = time.perf_counter() + + text += segment + logging.debug(text) tokens.append(token) if self.logprobs > 0: @@ -498,121 +503,63 @@ class APIHandler(BaseHTTPRequestHandler): stop_sequence_suffix = self.tokenizer.decode( tokens[-stop_condition.trim_length :] ) + text = text[: -len(stop_sequence_suffix)] break - self.prompt_cache.tokens.extend(tokens) - detokenizer.finalize() - text = ( - detokenizer.text - if stop_sequence_suffix is None - else detokenizer.text[: -len(stop_sequence_suffix)] - ) - response = self.generate_response( - text, - finish_reason, - len(prompt), - len(tokens), - token_logprobs=token_logprobs, - top_tokens=top_tokens, - tokens=tokens, - ) - - response_json = json.dumps(response).encode() - indent = "\t" # Backslashes can't be inside of f-strings - logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}") - - # Send an additional Content-Length header when it is known - self.send_header("Content-Length", str(len(response_json))) - self.end_headers() - - self.wfile.write(response_json) - self.wfile.flush() - - def handle_stream( - self, - prompt: List[int], - stop_id_sequences: List[List[int]], - ): - """ - Generate response to prompt and foward it to the client using a Server - Sent Events (SSE) stream. - - Args: - prompt (mx.array): The tokenized prompt - stop_id_sequences (List[List[int]]): A list of stop words passed to - the stopping_criteria function - """ - # No additional headers are needed, call end_headers - self.end_headers() - - detokenizer = self.tokenizer.detokenizer - detokenizer.reset() - tokens = [] - - stop_sequence_suffix = None - logging.debug(f"Starting stream:") - - prompt = self.get_prompt_cache(prompt) - - for _, (token, _) in zip( - range(self.max_tokens), - generate_step( - prompt=mx.array(prompt), - model=self.model, - temp=self.temperature, - top_p=self.top_p, - repetition_penalty=self.repetition_penalty, - repetition_context_size=self.repetition_context_size, - prompt_cache=self.prompt_cache.cache, - ), - ): - detokenizer.add_token(token) - logging.debug(detokenizer.text) - tokens.append(token) - - stop_condition = stopping_criteria( - tokens, - stop_id_sequences, - self.tokenizer.eos_token_id, - ) - if stop_condition.stop_met: - if stop_condition.trim_length: - stop_sequence_suffix = self.tokenizer.decode( - tokens[-stop_condition.trim_length :] + if self.stream: + # If the end of tokens overlaps with a stop sequence, generate new + # tokens until we know if the stop sequence is hit or not + if any( + ( + sequence_overlap(tokens, sequence) + for sequence in stop_id_sequences ) - break - - # If the end of tokens overlaps with a stop sequence, generate new - # tokens until we know if the stop sequence is hit or not - if any( - (sequence_overlap(tokens, sequence) for sequence in stop_id_sequences) - ): - continue - - new_text = detokenizer.last_segment - if new_text: - response = self.generate_response(new_text, None) - self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) - self.wfile.flush() + ): + continue + elif segment: + response = self.generate_response(segment, None) + self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) + self.wfile.flush() self.prompt_cache.tokens.extend(tokens) - # check is there any remaining text to send - detokenizer.finalize() - last_segment = detokenizer.last_segment - if last_segment: - if stop_sequence_suffix is not None: - last_segment = last_segment[: -len(stop_sequence_suffix)] - response = self.generate_response(last_segment, "length") + gen_time = time.perf_counter() - tic + prompt_tps = len(prompt) / prompt_time + gen_tps = len(tokens) / gen_time + peak_mem = mx.metal.get_peak_memory() / 1e9 + logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec") + logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec") + logging.debug(f"Peak memory: {peak_mem:.3f} GB") + + if self.stream: + response = self.generate_response(segment, finish_reason) self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) self.wfile.flush() + if self.stream_options is not None and self.stream_options["include_usage"]: + response = self.completion_usage_response(len(prompt), len(tokens)) + self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) + self.wfile.flush() + self.wfile.write("data: [DONE]\n\n".encode()) + self.wfile.flush() + else: + response = self.generate_response( + text, + finish_reason, + len(prompt), + len(tokens), + token_logprobs=token_logprobs, + top_tokens=top_tokens, + tokens=tokens, + ) + response_json = json.dumps(response).encode() + indent = "\t" # Backslashes can't be inside of f-strings + logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}") - if self.stream_options is not None and self.stream_options["include_usage"]: - response = self.completion_usage_response(len(prompt), len(tokens)) - self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) - - self.wfile.write("data: [DONE]\n\n".encode()) - self.wfile.flush() + # Send an additional Content-Length header when it is known + self.send_header("Content-Length", str(len(response_json))) + self.end_headers() + self.wfile.write(response_json) + self.wfile.flush() def completion_usage_response( self, diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index 38619d95..21b1af18 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -285,7 +285,7 @@ def train( it_sec = args.steps_per_report / (stop - start) tokens_sec = float(n_tokens) / (stop - start) trained_tokens += n_tokens - peak_mem = mx.metal.get_peak_memory() / 2**30 + peak_mem = mx.metal.get_peak_memory() / 1e9 if rank == 0: print( f"Iter {it}: Train loss {train_loss:.3f}, " diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 7b440db6..8893b570 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -20,7 +20,7 @@ from transformers import PreTrainedTokenizer # Local imports from .models import cache -from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling +from .sample_utils import make_logits_processors, make_sampler from .tokenizer_utils import TokenizerWrapper, load_tokenizer from .tuner.utils import dequantize as dequantize_model from .tuner.utils import load_adapters @@ -34,6 +34,9 @@ MODEL_REMAPPING = { MAX_FILE_SIZE_GB = 5 +# A stream on the default device just for generation +generation_stream = mx.new_stream(mx.default_device()) + class ModelNotFoundError(Exception): def __init__(self, message): @@ -137,29 +140,6 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path return model_path -def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float): - """ - Apply repetition penalty to specific logits based on the given context. - - Paper: https://arxiv.org/abs/1909.05858 - - Args: - logits (mx.array): The logits produced by the language model. - tokens (mx.array): A list of N previous tokens. - penalty (float): The repetition penalty factor to be applied. - - Returns: - logits (mx.array): Logits with repetition penalty applied to generated tokens. - """ - if len(tokens) > 0: - selected_logits = logits[:, tokens] - selected_logits = mx.where( - selected_logits < 0, selected_logits * penalty, selected_logits / penalty - ) - logits[:, tokens] = selected_logits - return logits - - def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits): if ( kv_bits is not None @@ -185,7 +165,7 @@ def generate_step( max_kv_size: Optional[int] = None, prompt_cache: Optional[Any] = None, logit_bias: Optional[Dict[int, float]] = None, - logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, + logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, kv_bits: Optional[int] = None, kv_group_size: int = 64, quantized_kv_start: int = 0, @@ -214,7 +194,7 @@ def generate_step( prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if provided, the cache will be updated in place. logit_bias (dictionary, optional): Additive logit bias. - logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional): + logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional): A list of functions that take tokens and logits and return the processed logits. Default: ``None``. kv_bits (int, optional): Number of bits to use for KV cache quantization. @@ -224,53 +204,9 @@ def generate_step( when ``kv_bits`` is non-None. Default: ``0``. Yields: - Generator[Tuple[mx.array, mx.array], None, None]: A generator producing - one token and a vector of log probabilities. + Tuple[mx.array, mx.array]: One token and a vector of log probabilities. """ - def sample(logits: mx.array) -> Tuple[mx.array, float]: - logprobs = logits - mx.logsumexp(logits) - - if temp == 0: - token = mx.argmax(logits, axis=-1) - else: - if top_p > 0 and top_p < 1.0: - token = top_p_sampling(logits, top_p, temp) - elif min_p != 0.0: - token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp) - else: - token = categorical_sampling(logits, temp) - - return token, logprobs - - if repetition_penalty and ( - repetition_penalty < 0 or not isinstance(repetition_penalty, float) - ): - raise ValueError( - f"repetition_penalty must be a non-negative float, got {repetition_penalty}" - ) - - logits_processor = logits_processor or [] - - if repetition_penalty: - - def repetition_penalty_processor(tokens, logits): - return apply_repetition_penalty( - logits, tokens[-repetition_context_size:], repetition_penalty - ) - - logits_processor.append(repetition_penalty_processor) - - if logit_bias: - indices = mx.array(list(logit_bias.keys())) - values = mx.array(list(logit_bias.values())) - - def logit_bias_processor(_, logits): - logits[:, indices] += values - return logits - - logits_processor.append(logit_bias_processor) - y = prompt tokens = None @@ -283,24 +219,31 @@ def generate_step( elif len(prompt_cache) != len(model.layers): raise ValueError("Wrong number of layers in the prompt cache.") + sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep) + logits_processors = logits_processors or [] + logits_processors.extend( + make_logits_processors(logit_bias, repetition_penalty, repetition_context_size) + ) + def _step(y): + with mx.stream(generation_stream): + logits = model(y[None], cache=prompt_cache) + logits = logits[:, -1, :] - logits = model(y[None], cache=prompt_cache) - logits = logits[:, -1, :] + if logits_processors: + nonlocal tokens + tokens = mx.concat([tokens, y]) if tokens is not None else y - if logits_processor: - nonlocal tokens - tokens = mx.concat([tokens, y]) if tokens is not None else y + for processor in logits_processors: + logits = processor(tokens, logits) - for processor in logits_processor: - logits = processor(tokens, logits) + maybe_quantize_kv_cache( + prompt_cache, quantized_kv_start, kv_group_size, kv_bits + ) - maybe_quantize_kv_cache( - prompt_cache, quantized_kv_start, kv_group_size, kv_bits - ) - - y, logprobs = sample(logits) - return y, logprobs.squeeze(0) + logprobs = logits - mx.logsumexp(logits, keepdims=True) + y = sampler(logprobs) + return y, logprobs.squeeze(0) while y.size > prefill_step_size: model(y[:prefill_step_size][None], cache=prompt_cache) @@ -325,43 +268,51 @@ def generate_step( def stream_generate( model: nn.Module, tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], - prompt: str, + prompt: Union[str, List[int]], max_tokens: int = 100, **kwargs, -) -> Union[str, Generator[str, None, None]]: +) -> Generator[Tuple[str, int, mx.array], None, None]: """ A generator producing text based on the given prompt from the model. Args: - prompt (mx.array): The input prompt. model (nn.Module): The model to use for generation. - max_tokens (int): The ma + tokenizer (PreTrainedTokenizer): The tokenizer. + prompt (Union[str, List[int]]): The input prompt string or integer tokens. + max_tokens (int): The maximum number of tokens. Default: ``100``. kwargs: The remaining options get passed to :func:`generate_step`. See :func:`generate_step` for more details. Yields: - Generator[Tuple[mx.array, mx.array]]: A generator producing text. + Tuple[str, int, mx.array]: + The next text segment, token, and vector of log probabilities. """ if not isinstance(tokenizer, TokenizerWrapper): tokenizer = TokenizerWrapper(tokenizer) - prompt_tokens = mx.array(tokenizer.encode(prompt)) + prompt_tokens = mx.array( + prompt if isinstance(prompt, list) else tokenizer.encode(prompt) + ) detokenizer = tokenizer.detokenizer - detokenizer.reset() - for n, (token, _) in zip( - range(max_tokens), - generate_step(prompt_tokens, model, **kwargs), - ): - if token == tokenizer.eos_token_id: - break - detokenizer.add_token(token) + with wired_limit(model, [generation_stream]): + detokenizer.reset() + for n, (token, logits) in zip( + range(max_tokens), + generate_step(prompt_tokens, model, **kwargs), + ): + if token == tokenizer.eos_token_id: + break - # Yield the last segment if streaming - yield detokenizer.last_segment + detokenizer.add_token(token) - detokenizer.finalize() - yield detokenizer.last_segment + if n == (max_tokens - 1): + break + + yield detokenizer.last_segment, token, logits + + detokenizer.finalize() + yield detokenizer.last_segment, token, logits def generate( @@ -372,7 +323,7 @@ def generate( verbose: bool = False, formatter: Optional[Callable] = None, **kwargs, -) -> Union[str, Generator[str, None, None]]: +) -> str: """ Generate a complete response from the model. @@ -398,7 +349,7 @@ def generate( prompt_tokens = mx.array(tokenizer.encode(prompt)) detokenizer = tokenizer.detokenizer - with wired_limit(model): + with wired_limit(model, [generation_stream]): tic = time.perf_counter() detokenizer.reset() for n, (token, logprobs) in zip( @@ -416,8 +367,7 @@ def generate( if formatter: # We have to finalize so that the prob corresponds to the last segment detokenizer.finalize() - with mx.stream(mx.cpu): - prob = mx.exp(logprobs[token]).item() + prob = mx.exp(logprobs[token]).item() formatter(detokenizer.last_segment, prob) else: print(detokenizer.last_segment, end="", flush=True) @@ -438,7 +388,7 @@ def generate( f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec" ) print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec") - peak_mem = mx.metal.get_peak_memory() / 2**30 + peak_mem = mx.metal.get_peak_memory() / 1e9 print(f"Peak memory: {peak_mem:.3f} GB") return detokenizer.text @@ -623,7 +573,9 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str): f""" # {upload_repo} - The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) using mlx-lm version **{__version__}**. + The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was + converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) + using mlx-lm version **{__version__}**. ## Use with mlx diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py index 68f1670b..e0a372a9 100644 --- a/llms/tests/test_generate.py +++ b/llms/tests/test_generate.py @@ -46,7 +46,7 @@ class TestGenerate(unittest.TestCase): "hello", max_tokens=5, verbose=False, - logits_processor=[logits_processor], + logits_processors=[logits_processor], ) self.assertEqual(len(all_toks), len(init_toks) + 5) diff --git a/llms/tests/test_prompt_cache.py b/llms/tests/test_prompt_cache.py index 1e57bd86..0867ab56 100644 --- a/llms/tests/test_prompt_cache.py +++ b/llms/tests/test_prompt_cache.py @@ -299,7 +299,7 @@ class TestPromptCache(unittest.TestCase): ): i += 1 self.assertEqual(tok, toks[i]) - self.assertTrue(mx.allclose(logits, all_logits[i], rtol=1e-2)) + self.assertTrue(mx.allclose(logits, all_logits[i], rtol=2e-2)) if __name__ == "__main__":