diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py
index f9868422..c77f056a 100644
--- a/llms/mlx_lm/sample_utils.py
+++ b/llms/mlx_lm/sample_utils.py
@@ -190,7 +190,7 @@ def make_repetition_penalty(penalty: float, context_size: int = 20):
         Callable[[mx.array, List[int]], mx.array]:
             The repetition penalty processor.
     """
-    if penalty < 0 or not isinstance(penalty, float):
+    if penalty < 0 or not isinstance(penalty, (int, float)):
         raise ValueError(f"penalty must be a non-negative float, got {penalty}")
 
     def repetition_penalty_processor(tokens, logits):
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index ce09cf45..c12513ff 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -465,7 +465,7 @@ class APIHandler(BaseHTTPRequestHandler):
 
         text = ""
         tic = time.perf_counter()
-        sampler = make_sampler(self.temperature)
+        sampler = make_sampler(self.temperature, top_p=self.top_p)
         logits_processors = make_logits_processors(
             self.logit_bias, self.repetition_penalty, self.repetition_context_size
         )
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 493c1c42..b87f5a24 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -299,6 +299,9 @@ def generate_step(
         prompt_processed_tokens = 0
         while y.size > prefill_step_size:
             model(y[:prefill_step_size][None], cache=prompt_cache)
+            maybe_quantize_kv_cache(
+                prompt_cache, quantized_kv_start, kv_group_size, kv_bits
+            )
             mx.eval([c.state for c in prompt_cache])
             prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
             prompt_processed_tokens += prefill_step_size
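
For reference, a minimal usage sketch of what the first two hunks change at the call sites (assuming mlx_lm is importable; the sampling values below are illustrative, not taken from the patch):

from mlx_lm.sample_utils import make_repetition_penalty, make_sampler

# An integer penalty such as 2 is now accepted by the relaxed isinstance check
# instead of raising ValueError.
processor = make_repetition_penalty(2, context_size=20)

# The server now forwards top_p to the sampler; the equivalent direct call.
sampler = make_sampler(0.7, top_p=0.9)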