mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Add support for logit bias (#697)
This commit is contained in:
parent
6abdbe3be8
commit
1484598de1
@ -128,6 +128,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
self.top_p = self.body.get("top_p", 1.0)
|
self.top_p = self.body.get("top_p", 1.0)
|
||||||
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
|
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
|
||||||
self.repetition_context_size = self.body.get("repetition_context_size", 20)
|
self.repetition_context_size = self.body.get("repetition_context_size", 20)
|
||||||
|
self.logit_bias = self.body.get("logit_bias", None)
|
||||||
|
|
||||||
# Get stop id sequences, if provided
|
# Get stop id sequences, if provided
|
||||||
stop_words = self.body.get("stop", [])
|
stop_words = self.body.get("stop", [])
|
||||||
@ -247,6 +248,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
top_p=self.top_p,
|
top_p=self.top_p,
|
||||||
repetition_penalty=self.repetition_penalty,
|
repetition_penalty=self.repetition_penalty,
|
||||||
repetition_context_size=self.repetition_context_size,
|
repetition_context_size=self.repetition_context_size,
|
||||||
|
logit_bias=self.logit_bias,
|
||||||
),
|
),
|
||||||
range(self.max_tokens),
|
range(self.max_tokens),
|
||||||
):
|
):
|
||||||
|
@ -117,6 +117,7 @@ def generate_step(
|
|||||||
repetition_penalty: Optional[float] = None,
|
repetition_penalty: Optional[float] = None,
|
||||||
repetition_context_size: Optional[int] = 20,
|
repetition_context_size: Optional[int] = 20,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
|
logit_bias: Optional[Dict[int, float]] = None,
|
||||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||||
"""
|
"""
|
||||||
A generator producing text based on the given prompt from the model.
|
A generator producing text based on the given prompt from the model.
|
||||||
@ -135,6 +136,7 @@ def generate_step(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def sample(logits: mx.array) -> Tuple[mx.array, float]:
|
def sample(logits: mx.array) -> Tuple[mx.array, float]:
|
||||||
|
logits = logits + logit_bias if logit_bias else logits
|
||||||
softmax_logits = mx.softmax(logits)
|
softmax_logits = mx.softmax(logits)
|
||||||
|
|
||||||
if temp == 0:
|
if temp == 0:
|
||||||
@ -203,6 +205,7 @@ def generate(
|
|||||||
repetition_penalty: Optional[float] = None,
|
repetition_penalty: Optional[float] = None,
|
||||||
repetition_context_size: Optional[int] = None,
|
repetition_context_size: Optional[int] = None,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
|
logit_bias: Optional[Dict[int, float]] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Generate text from the model.
|
Generate text from the model.
|
||||||
@ -241,6 +244,7 @@ def generate(
|
|||||||
repetition_penalty,
|
repetition_penalty,
|
||||||
repetition_context_size,
|
repetition_context_size,
|
||||||
top_p,
|
top_p,
|
||||||
|
logit_bias,
|
||||||
),
|
),
|
||||||
range(max_tokens),
|
range(max_tokens),
|
||||||
):
|
):
|
||||||
|
Loading…
Reference in New Issue
Block a user