From 2ba0e3668382d2c18ab6f691e2f662081596269f Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 12 Dec 2024 11:12:21 -0800 Subject: [PATCH] [mlx-lm] Use top p in server (#1144) * use top p in server * couple other fixes --- llms/mlx_lm/sample_utils.py | 2 +- llms/mlx_lm/server.py | 2 +- llms/mlx_lm/utils.py | 3 +++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py index f9868422..c77f056a 100644 --- a/llms/mlx_lm/sample_utils.py +++ b/llms/mlx_lm/sample_utils.py @@ -190,7 +190,7 @@ def make_repetition_penalty(penalty: float, context_size: int = 20): Callable[[mx.array, List[int]], mx.array]: The repetition penalty processor. """ - if penalty < 0 or not isinstance(penalty, float): + if penalty < 0 or not isinstance(penalty, (int, float)): raise ValueError(f"penalty must be a non-negative float, got {penalty}") def repetition_penalty_processor(tokens, logits): diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index ce09cf45..c12513ff 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -465,7 +465,7 @@ class APIHandler(BaseHTTPRequestHandler): text = "" tic = time.perf_counter() - sampler = make_sampler(self.temperature) + sampler = make_sampler(self.temperature, top_p=self.top_p) logits_processors = make_logits_processors( self.logit_bias, self.repetition_penalty, self.repetition_context_size ) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 493c1c42..b87f5a24 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -299,6 +299,9 @@ def generate_step( prompt_processed_tokens = 0 while y.size > prefill_step_size: model(y[:prefill_step_size][None], cache=prompt_cache) + maybe_quantize_kv_cache( + prompt_cache, quantized_kv_start, kv_group_size, kv_bits + ) mx.eval([c.state for c in prompt_cache]) prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens) prompt_processed_tokens += prefill_step_size