# Copyright © 2023-2024 Apple Inc. import math from functools import partial from typing import Callable, Dict, Optional import mlx.core as mx def make_sampler( temp: float = 0.0, top_p: float = 0.0, min_p: float = 0.0, min_tokens_to_keep: int = 1, top_k: int = -1, ) -> Callable[mx.array, mx.array]: """ Make a sampler function for use with ``generate_step``. Args: temp (float): The temperature for sampling, if 0 the argmax is used. Default: ``0``. top_p (float, optional): Nulceus sampling, higher means model considers more less likely words. min_p (float, optional): The minimum value (scaled by the top token's probability) that a token probability must have to be considered. min_tokens_to_keep (int, optional): Minimum number of tokens that cannot be filtered by min_p sampling. top_k (int, optional): The top k tokens ranked by probability to constrain the sampling to. Returns: Callable[mx.array, mx.array]: A sampler which takes log-probabilities and returns tokens. """ if temp == 0: return lambda x: mx.argmax(x, axis=-1) elif top_p > 0 and top_p < 1.0: return lambda x: top_p_sampling(x, top_p, temp) elif min_p != 0.0: return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp) elif top_k > 0: return lambda x: top_k_sampling(x, top_k, temp) else: return lambda x: categorical_sampling(x, temp) def make_logits_processors( logit_bias: Optional[Dict[int, float]] = None, repetition_penalty: Optional[float] = None, repetition_context_size: Optional[int] = 20, ): """ Make logits processors for use with ``generate_step``. Args: repetition_penalty (float, optional): The penalty factor for repeating tokens. repetition_context_size (int, optional): The number of tokens to consider for repetition penalty. Default: ``20``. logit_bias (dictionary, optional): Additive logit bias. Returns: List[Callable[[mx.array, mx.array], mx.array]]: A list of logits processors. Each processor in the list is a callable which takes an array of tokens and an array of logits and returns the updated logits. """ logits_processors = [] if logit_bias: indices = mx.array(list(logit_bias.keys())) values = mx.array(list(logit_bias.values())) def logit_bias_processor(_, logits): logits[:, indices] += values return logits logits_processors.append(logit_bias_processor) if repetition_penalty and repetition_penalty != 0.0: logits_processors.append( make_repetition_penalty(repetition_penalty, repetition_context_size) ) return logits_processors @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) def top_k_sampling( logprobs: mx.array, top_k: int, ) -> mx.array: """ Sample from only the top K tokens ranked by probability. Args: logprobs: A vector of log probabilities. top_k (int): Top k tokens to sample from. """ vocab_size = logprobs.shape[-1] if not isinstance(top_k, int) or not (0 < top_k < vocab_size): raise ValueError( f"`top_k` has to be an integer in the (0, {vocab_size}] interval," f" but is {top_k}." ) mask_idx = mx.argpartition(-logprobs, kth=top_k - 1, axis=-1)[..., top_k:] masked_logprobs = mx.put_along_axis( logprobs, mask_idx, mx.array(-float("inf"), logprobs.dtype), axis=-1 ) return masked_logprobs @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) def min_p_sampling( logprobs: mx.array, min_p: float, min_tokens_to_keep: int = 1, ) -> mx.array: """ Apply min-p sampling to the logprobs. Min-p keeps all tokens that are above a minimum probability, scaled by the probability of the most likely token. As a result, the filter is more aggressive given a very high-probability token. Args: logprobs: A vector of log probabilities. min_p (float): Minimum token probability. Typical values are in the 0.01-0.2 range, comparably selective as setting `top_p` in the 0.99-0.8 range. min_tokens_to_keep (int, optional): Minimum number of tokens that cannot be filtered. Default: ``1``. """ if not (0 <= min_p <= 1.0): raise ValueError( f"`min_p` has to be a float in the [0, 1] interval, but is {min_p}" ) if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1): raise ValueError( f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}" ) # reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605 # Indices sorted in decreasing order sorted_indices = mx.argsort(-logprobs, axis=-1) sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1) # Top probability top_logprobs = sorted_logprobs[:, 0:1] # Calculate the min_p threshold scaled_min_p = top_logprobs + math.log(min_p) # Mask tokens that have a probability less than the scaled min_p tokens_to_remove = sorted_logprobs < scaled_min_p tokens_to_remove[..., :min_tokens_to_keep] = False # Create pool of tokens with probability less than scaled min_p selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs) # Create a mapping to rearrange back to original indices # Use argsort of sorted_indices to get the inverse permutation inverse_indices = mx.argsort(sorted_indices, axis=-1) # Rearrange selected_logprobs back to original order original_order_logprobs = mx.take_along_axis( selected_logprobs, inverse_indices, axis=-1 ) return original_order_logprobs @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) def top_p_sampling(logits: mx.array, top_p: float) -> mx.array: """ Apply top-p (nucleus) sampling to logits. Args: logits: The logits from the model's output. top_p: The cumulative probability threshold for top-p filtering. Returns: token selected based on the top-p criterion. """ # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460 probs = mx.softmax(logits, axis=-1) # sort probs in ascending order sorted_indices = mx.argsort(probs, axis=-1) sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1) cumulative_probs = mx.cumsum(sorted_probs, axis=-1) # select tokens with cumulative probs below threshold top_probs = mx.where( cumulative_probs > 1 - top_p, sorted_probs, 0, ) # Create a mapping to rearrange back to original indices # Use argsort of sorted_indices to get the inverse permutation inverse_indices = mx.argsort(sorted_indices, axis=-1) # Rearrange top_probs back to original order original_order_probs = mx.take_along_axis(top_probs, inverse_indices, axis=-1) # Convert back to logits and return return mx.log(mx.where(original_order_probs > 0, original_order_probs, 0)) @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) def categorical_sampling(logits, temp): return mx.random.categorical(logits * (1 / temp)) def make_repetition_penalty(penalty: float, context_size: int = 20): """ Make repetition penalty processor. Paper: https://arxiv.org/abs/1909.05858 Args: penalty (float): The repetition penalty factor to be applied. context_size (int): The number of previous tokens to use. Default: ``20``. Returns: Callable[[mx.array, List[int]], mx.array]: The repetition penalty processor. """ if penalty < 0 or not isinstance(penalty, (int, float)): raise ValueError(f"penalty must be a non-negative float, got {penalty}") def repetition_penalty_processor(tokens, logits): if len(tokens) > 0: tokens = tokens[-context_size:] selected_logits = logits[:, tokens] selected_logits = mx.where( selected_logits < 0, selected_logits * penalty, selected_logits / penalty, ) logits[:, tokens] = selected_logits return logits return repetition_penalty_processor