Add support for logit bias (#697)

This commit is contained in:
Karim Elmaaroufi
2024-04-21 06:53:56 -07:00
committed by GitHub
parent 6abdbe3be8
commit 1484598de1
2 changed files with 6 additions and 0 deletions

View File

@@ -117,6 +117,7 @@ def generate_step(
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = 20,
top_p: float = 1.0,
logit_bias: Optional[Dict[int, float]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing text based on the given prompt from the model.
@@ -135,6 +136,7 @@ def generate_step(
"""
def sample(logits: mx.array) -> Tuple[mx.array, float]:
logits = logits + logit_bias if logit_bias else logits
softmax_logits = mx.softmax(logits)
if temp == 0:
@@ -203,6 +205,7 @@ def generate(
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = None,
top_p: float = 1.0,
logit_bias: Optional[Dict[int, float]] = None,
) -> str:
"""
Generate text from the model.
@@ -241,6 +244,7 @@ def generate(
repetition_penalty,
repetition_context_size,
top_p,
logit_bias,
),
range(max_tokens),
):