mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-15 09:48:54 +08:00
Validate server params & fix logit bias bug (#731)
* Bug fix in logit bias * Add parameter validations * Fix typo * Update docstrings to match MLX styling * Black style + fix a validation bug
This commit is contained in:
@@ -136,7 +136,10 @@ def generate_step(
|
||||
"""
|
||||
|
||||
def sample(logits: mx.array) -> Tuple[mx.array, float]:
|
||||
logits = logits + logit_bias if logit_bias else logits
|
||||
if logit_bias:
|
||||
indices = mx.array(list(logit_bias.keys()))
|
||||
values = mx.array(list(logit_bias.values()))
|
||||
logits[:, indices] += values
|
||||
softmax_logits = mx.softmax(logits)
|
||||
|
||||
if temp == 0:
|
||||
|
||||
Reference in New Issue
Block a user