diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 57f080e2..eab5e792 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -8,6 +8,7 @@
 DEFAULT_MODEL_PATH = "mlx_model"
 DEFAULT_PROMPT = "hello"
 DEFAULT_MAX_TOKENS = 100
 DEFAULT_TEMP = 0.6
+DEFAULT_TOP_P = 1.0
 DEFAULT_SEED = 0

@@ -44,6 +45,9 @@ def setup_arg_parser():
     parser.add_argument(
         "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
     )
+    parser.add_argument(
+        "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
+    )
     parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
     parser.add_argument(
         "--ignore-chat-template",
@@ -109,7 +113,7 @@ def main(args):
     formatter = colorprint_by_t0 if args.colorize else None

     generate(
-        model, tokenizer, prompt, args.temp, args.max_tokens, True, formatter=formatter
+        model, tokenizer, prompt, args.temp, args.max_tokens, True, formatter=formatter, top_p=args.top_p
     )


diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index e10a8a08..814beca4 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -111,6 +111,7 @@ def generate_step(
     temp: float = 0.0,
     repetition_penalty: Optional[float] = None,
     repetition_context_size: Optional[int] = 20,
+    top_p: float = 1.0,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing text based on the given prompt from the model.
@@ -133,7 +134,26 @@
         if temp == 0:
             token = mx.argmax(logits, axis=-1)
         else:
-            token = mx.random.categorical(logits * (1 / temp))
+            if top_p > 0 and top_p < 1.0:
+                if (
+                    logits.dtype == mx.bfloat16
+                ):  # workaround for being unable to load the kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16
+                    logits = logits.astype(mx.float32)
+                probs = mx.softmax(logits / temp, axis=-1)
+
+                sorted_probs = mx.sort(probs)[::-1]
+                sorted_indices = mx.argsort(probs)[::-1]
+                cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
+
+                top_probs = mx.where(
+                    cumulative_probs > 1 - top_p,
+                    sorted_probs,
+                    mx.zeros_like(sorted_probs),
+                )
+                sorted_token = mx.random.categorical(mx.log(top_probs))
+                token = sorted_indices.squeeze(0)[sorted_token]
+            else:
+                token = mx.random.categorical(logits * (1 / temp))

         prob = softmax_logits[0, token]
         return token, prob
@@ -182,6 +202,7 @@ def generate(
     formatter: Callable = None,
     repetition_penalty: Optional[float] = None,
     repetition_context_size: Optional[int] = None,
+    top_p: float = 1.0,
 ) -> str:
     """
     Generate text from the model.
@@ -218,6 +239,7 @@
             temp,
             repetition_penalty,
             repetition_context_size,
+            top_p,
         ),
         range(max_tokens),
     ):
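
A note on the sampling hunk: mx.sort and mx.argsort sort ascending along the last axis, and since probs carries a unit batch axis the [::-1] only reverses that size-1 axis, so sorted_probs remains in ascending order. The mask cumulative_probs > 1 - top_p is written for that ascending order: it zeroes the low-probability prefix whose total mass is at most 1 - top_p and keeps the high-probability tail whose mass is at least top_p, which is exactly the nucleus. A minimal NumPy sketch of the same logic, assuming nothing beyond the diff above (top_p_sample is a hypothetical helper, not part of mlx_lm):

import numpy as np


def top_p_sample(logits: np.ndarray, temp: float, top_p: float) -> int:
    """Standalone sketch of the patch's top-p branch."""
    rng = np.random.default_rng()

    # Temperature-scaled softmax (shift by the max for numerical stability).
    scaled = logits / temp
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()

    # Ascending sort, mirroring mx.sort / mx.argsort in the patch.
    order = np.argsort(probs)
    sorted_probs = probs[order]
    cumulative = np.cumsum(sorted_probs)

    # Zero out the low-probability prefix; the kept tail has mass >= top_p.
    masked = np.where(cumulative > 1 - top_p, sorted_probs, 0.0)

    # Renormalize, sample within the sorted order, then map back to the
    # original vocabulary index.
    masked /= masked.sum()
    choice = rng.choice(len(masked), p=masked)
    return int(order[choice])


# With top_p=0.5, only the two most likely tokens (indices 3 and 4) survive.
logits = np.log(np.array([0.05, 0.1, 0.15, 0.3, 0.4]))
print(top_p_sample(logits, temp=1.0, top_p=0.5))

The patch itself skips the explicit renormalization by passing mx.log(top_probs) to mx.random.categorical, which accepts unnormalized logits: the zeroed entries become -inf and can never be drawn.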
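
End to end, the new value flows from the --top-p flag through main() into generate() and on to generate_step(). Assuming the script's existing --model and --prompt arguments are unchanged by this patch, a run with nucleus sampling enabled might look like:

python -m mlx_lm.generate --model mlx_model --prompt "hello" --temp 0.6 --top-p 0.9

With the default of 1.0 the condition top_p > 0 and top_p < 1.0 is false, so the new branch is skipped and sampling behaves exactly as before.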