mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Add support for logit bias (#697)
This commit is contained in:
@@ -117,6 +117,7 @@ def generate_step(
|
||||
repetition_penalty: Optional[float] = None,
|
||||
repetition_context_size: Optional[int] = 20,
|
||||
top_p: float = 1.0,
|
||||
logit_bias: Optional[Dict[int, float]] = None,
|
||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||
"""
|
||||
A generator producing text based on the given prompt from the model.
|
||||
@@ -135,6 +136,7 @@ def generate_step(
|
||||
"""
|
||||
|
||||
def sample(logits: mx.array) -> Tuple[mx.array, float]:
|
||||
logits = logits + logit_bias if logit_bias else logits
|
||||
softmax_logits = mx.softmax(logits)
|
||||
|
||||
if temp == 0:
|
||||
@@ -203,6 +205,7 @@ def generate(
|
||||
repetition_penalty: Optional[float] = None,
|
||||
repetition_context_size: Optional[int] = None,
|
||||
top_p: float = 1.0,
|
||||
logit_bias: Optional[Dict[int, float]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate text from the model.
|
||||
@@ -241,6 +244,7 @@ def generate(
|
||||
repetition_penalty,
|
||||
repetition_context_size,
|
||||
top_p,
|
||||
logit_bias,
|
||||
),
|
||||
range(max_tokens),
|
||||
):
|
||||
|
Reference in New Issue
Block a user