From 2277033a24d4890847c3032b668a195dff9bd1b7 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sun, 8 Dec 2024 20:16:44 -0800 Subject: [PATCH] use top p in server --- llms/mlx_lm/sample_utils.py | 2 +- llms/mlx_lm/server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py index f9868422..c77f056a 100644 --- a/llms/mlx_lm/sample_utils.py +++ b/llms/mlx_lm/sample_utils.py @@ -190,7 +190,7 @@ def make_repetition_penalty(penalty: float, context_size: int = 20): Callable[[mx.array, List[int]], mx.array]: The repetition penalty processor. """ - if penalty < 0 or not isinstance(penalty, float): + if penalty < 0 or not isinstance(penalty, (int, float)): raise ValueError(f"penalty must be a non-negative float, got {penalty}") def repetition_penalty_processor(tokens, logits): diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index ce09cf45..c12513ff 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -465,7 +465,7 @@ class APIHandler(BaseHTTPRequestHandler): text = "" tic = time.perf_counter() - sampler = make_sampler(self.temperature) + sampler = make_sampler(self.temperature, top_p=self.top_p) logits_processors = make_logits_processors( self.logit_bias, self.repetition_penalty, self.repetition_context_size )