diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index f4112b6a..3afc8b85 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -5,7 +5,7 @@ import json
 import logging
 import time
 from pathlib import Path
-from typing import Any, Callable, Dict, Generator, Tuple, Union
+from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -80,10 +80,36 @@ def get_model_path(path_or_hf_repo: str) -> Path:
     return model_path
 
 
+def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: float):
+    """
+    Apply repetition penalty to specific logits based on the given context.
+
+    Paper: https://arxiv.org/abs/1909.05858
+
+    Args:
+        logits (mx.array): The logits produced by the language model.
+        generated_tokens (any): A list of N previous tokens.
+        penalty (float): The repetition penalty factor to be applied.
+
+    Returns:
+        logits (mx.array): Logits with repetition penalty applied to generated tokens.
+    """
+    if len(generated_tokens) > 0:
+        indices = mx.array([token for token in generated_tokens])
+        selected_logits = logits[:, indices]
+        selected_logits = mx.where(
+            selected_logits < 0, selected_logits * penalty, selected_logits / penalty
+        )
+        logits[:, indices] = selected_logits
+    return logits
+
+
 def generate_step(
     prompt: mx.array,
     model: nn.Module,
     temp: float = 0.0,
+    repetition_penalty: Optional[float] = None,
+    repetition_context_size: Optional[int] = 20,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing text based on the given prompt from the model.
@@ -92,6 +118,9 @@
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
         temp (float): The temperature for sampling, if 0 the argmax is used.
+        repetition_penalty (float, optional): The penalty factor for repeating tokens.
+        repetition_context_size (int, optional): The number of tokens to consider for repetition penalty (default 20).
+
     Yields:
         Generator[Tuple[mx.array, mx.array]]: A generator producing
         one token and probability per call.
@@ -108,12 +137,37 @@
         prob = softmax_logits[0, token]
         return token, prob
 
+    if repetition_penalty and (
+        repetition_penalty < 0 or not isinstance(repetition_penalty, float)
+    ):
+        raise ValueError(
+            f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
+        )
+
     y = prompt
     cache = None
+
+    repetition_context = prompt.tolist()
+
+    if repetition_context_size:
+        repetition_context = repetition_context[-repetition_context_size:]
+
     while True:
         logits, cache = model(y[None], cache=cache)
         logits = logits[:, -1, :]
-        y, prob = sample(logits)
+
+        if repetition_penalty:
+            logits = apply_repetition_penalty(
+                logits, repetition_context, repetition_penalty
+            )
+            y, prob = sample(logits)
+            repetition_context.append(y.item())
+        else:
+            y, prob = sample(logits)
+
+        if repetition_context_size:
+            if len(repetition_context) > repetition_context_size:
+                repetition_context = repetition_context[-repetition_context_size:]
         yield y, prob
 
 
@@ -125,6 +179,8 @@
     max_tokens: int = 100,
     verbose: bool = False,
     formatter: Callable = None,
+    repetition_penalty: Optional[float] = None,
+    repetition_context_size: Optional[int] = None,
 ) -> str:
     """
     Generate text from the model.
@@ -139,20 +195,31 @@
           (default ``False``).
         formatter (Optional[Callable]): A function which takes a token and a
           probability and displays it.
+        repetition_penalty (float, optional): The penalty factor for repeating tokens.
+        repetition_context_size (int, optional): The number of tokens to consider for repetition penalty.
     """
     if verbose:
         print("=" * 10)
         print("Prompt:", prompt)
 
-    prompt = mx.array(tokenizer.encode(prompt))
+    prompt_tokens = mx.array(tokenizer.encode(prompt))
 
     tic = time.perf_counter()
 
     tokens = []
     skip = 0
     REPLACEMENT_CHAR = "\ufffd"
 
-    for (token, prob), n in zip(generate_step(prompt, model, temp), range(max_tokens)):
+    for (token, prob), n in zip(
+        generate_step(
+            prompt_tokens,
+            model,
+            temp,
+            repetition_penalty,
+            repetition_context_size,
+        ),
+        range(max_tokens),
+    ):
         if token == tokenizer.eos_token_id:
             break
         if n == 0:
@@ -179,7 +246,7 @@ def generate(
         if token_count == 0:
            print("No tokens generated for this prompt")
            return
-        prompt_tps = prompt.size / prompt_time
+        prompt_tps = prompt_tokens.size / prompt_time
         gen_tps = (token_count - 1) / gen_time
         print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
         print(f"Generation: {gen_tps:.3f} tokens-per-sec")
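
Note on the penalty rule (commentary, not part of the patch): `apply_repetition_penalty` follows the CTRL paper's formulation, in which the logits of previously generated tokens are divided by the penalty when positive and multiplied by it when negative, so values above 1.0 make recent tokens less likely to repeat. A minimal standalone sketch of that transform, using toy values:

```python
import mlx.core as mx

# Toy logits for a 5-token vocabulary, batch size 1 (illustrative values only).
logits = mx.array([[2.0, -1.0, 0.5, 3.0, -0.5]])
repetition_context = [0, 3]   # token ids generated recently
penalty = 1.3                 # > 1.0 discourages repetition

indices = mx.array(repetition_context)
selected = logits[:, indices]
# Positive logits shrink and negative logits become more negative,
# mirroring the mx.where(...) branch in apply_repetition_penalty.
logits[:, indices] = mx.where(selected < 0, selected * penalty, selected / penalty)
print(logits)  # tokens 0 and 3 are now less likely under softmax sampling
```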
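A hedged usage sketch of the new `generate` keyword arguments, assuming the package-level `load`/`generate` helpers re-exported from this module; the model repo, prompt, and parameter values below are placeholders, not taken from the patch:

```python
from mlx_lm import load, generate

# Placeholder model repo; substitute any converted MLX model.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.2-4bit")

text = generate(
    model,
    tokenizer,
    prompt="Write a short poem about the ocean.",
    max_tokens=100,
    temp=0.7,
    repetition_penalty=1.1,      # mildly penalize recently generated tokens
    repetition_context_size=20,  # only the last 20 tokens are considered
    verbose=True,
)
```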