Validate server params & fix logit bias bug (#731)

* Bug fix in logit bias

* Add parameter validations

* Fix typo

* Update docstrings to match MLX styling

* Black style + fix a validation bug
This commit is contained in:
Karim Elmaaroufi
2024-04-30 07:27:40 -07:00
committed by GitHub
parent 7c0962f4e2
commit 4bf2eb17f2
3 changed files with 55 additions and 8 deletions

View File

@@ -136,7 +136,10 @@ def generate_step(
"""
def sample(logits: mx.array) -> Tuple[mx.array, float]:
logits = logits + logit_bias if logit_bias else logits
if logit_bias:
indices = mx.array(list(logit_bias.keys()))
values = mx.array(list(logit_bias.values()))
logits[:, indices] += values
softmax_logits = mx.softmax(logits)
if temp == 0: