From 1484598de1f751c26c2166f9f0e38fb1f7f370b3 Mon Sep 17 00:00:00 2001
From: Karim Elmaaroufi
Date: Sun, 21 Apr 2024 06:53:56 -0700
Subject: [PATCH] Add support for logit bias (#697)

---
 llms/mlx_lm/server.py | 5 +++++
 llms/mlx_lm/utils.py  | 8 ++++++++
 2 files changed, 13 insertions(+)

diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index e765bb5e..0b5850a6 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -128,6 +128,10 @@ class APIHandler(BaseHTTPRequestHandler):
         self.top_p = self.body.get("top_p", 1.0)
         self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
         self.repetition_context_size = self.body.get("repetition_context_size", 20)
+        self.logit_bias = self.body.get("logit_bias", None)
+        if self.logit_bias is not None:
+            # JSON object keys are strings; generate_step expects int token ids
+            self.logit_bias = {int(k): v for k, v in self.logit_bias.items()}
 
         # Get stop id sequences, if provided
         stop_words = self.body.get("stop", [])
@@ -247,6 +251,7 @@ class APIHandler(BaseHTTPRequestHandler):
             top_p=self.top_p,
             repetition_penalty=self.repetition_penalty,
             repetition_context_size=self.repetition_context_size,
+            logit_bias=self.logit_bias,
         ),
         range(self.max_tokens),
     ):
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index e38a0277..2d0767c7 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -117,6 +117,7 @@ def generate_step(
     repetition_penalty: Optional[float] = None,
     repetition_context_size: Optional[int] = 20,
     top_p: float = 1.0,
+    logit_bias: Optional[Dict[int, float]] = None,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing text based on the given prompt from the model.
@@ -135,6 +136,11 @@ def generate_step(
     """
 
     def sample(logits: mx.array) -> Tuple[mx.array, float]:
+        if logit_bias:
+            # Add the per-token bias to the matching vocabulary logits
+            indices = mx.array(list(logit_bias.keys()))
+            values = mx.array(list(logit_bias.values()))
+            logits[:, indices] += values
         softmax_logits = mx.softmax(logits)
 
         if temp == 0:
@@ -203,6 +209,7 @@ def generate(
     repetition_penalty: Optional[float] = None,
     repetition_context_size: Optional[int] = None,
     top_p: float = 1.0,
+    logit_bias: Optional[Dict[int, float]] = None,
 ) -> str:
     """
     Generate text from the model.
@@ -241,6 +248,7 @@ def generate(
             repetition_penalty,
             repetition_context_size,
             top_p,
+            logit_bias,
         ),
         range(max_tokens),
     ):
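
Usage note: `logit_bias` maps vocabulary token ids to additive offsets applied to the raw logits before sampling, so a large positive value all but forces a token and a large negative value effectively bans it. Below is a minimal sketch of calling the patched `generate` directly; the model repo and token ids are illustrative, not part of the patch.

```python
from mlx_lm import load, generate

# Illustrative checkpoint; any mlx-community model loads the same way.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.2-4bit")

# Hypothetical token ids: +20.0 strongly favors a token at every step,
# -100.0 effectively removes it from consideration.
bias = {1234: 20.0, 5678: -100.0}

text = generate(
    model,
    tokenizer,
    prompt="The capital of France is",
    max_tokens=20,
    logit_bias=bias,
)
print(text)
```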
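On the server side, the field follows the OpenAI convention: a JSON object whose keys are token ids (as strings, which the handler above converts to ints) and whose values are bias amounts. A standard-library-only sketch of exercising the endpoint, assuming `python -m mlx_lm.server --model <repo>` is running on the default 127.0.0.1:8080 and returns an OpenAI-style response body:

```python
import json
from urllib import request

body = {
    "messages": [{"role": "user", "content": "Say hello"}],
    "max_tokens": 20,
    # Token ids as strings, per the OpenAI logit_bias convention.
    "logit_bias": {"1234": 20.0, "5678": -100.0},
}
req = request.Request(
    "http://127.0.0.1:8080/v1/chat/completions",
    data=json.dumps(body).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with request.urlopen(req) as resp:
    out = json.loads(resp.read())
print(out["choices"][0]["message"]["content"])
```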