mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
nits
This commit is contained in:
@@ -101,7 +101,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
|
||||
return model_path
|
||||
|
||||
|
||||
def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: float):
|
||||
def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float):
|
||||
"""
|
||||
Apply repetition penalty to specific logits based on the given context.
|
||||
|
||||
@@ -109,19 +109,18 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f
|
||||
|
||||
Args:
|
||||
logits (mx.array): The logits produced by the language model.
|
||||
generated_tokens (any): A list of N previous tokens.
|
||||
tokens (mx.array): A list of N previous tokens.
|
||||
penalty (float): The repetition penalty factor to be applied.
|
||||
|
||||
Returns:
|
||||
logits (mx.array): Logits with repetition penalty applied to generated tokens.
|
||||
"""
|
||||
if len(generated_tokens) > 0:
|
||||
indices = mx.array([token for token in generated_tokens])
|
||||
selected_logits = logits[:, indices]
|
||||
if len(tokens) > 0:
|
||||
selected_logits = logits[:, tokens]
|
||||
selected_logits = mx.where(
|
||||
selected_logits < 0, selected_logits * penalty, selected_logits / penalty
|
||||
)
|
||||
logits[:, indices] = selected_logits
|
||||
logits[:, tokens] = selected_logits
|
||||
return logits
|
||||
|
||||
|
||||
@@ -158,7 +157,7 @@ def generate_step(
|
||||
max_kv_size: Optional[int] = None,
|
||||
cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
|
||||
logit_bias: Optional[Dict[int, float]] = None,
|
||||
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = [],
|
||||
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||
"""
|
||||
A generator producing token ids based on the given prompt from the model.
|
||||
@@ -184,7 +183,7 @@ def generate_step(
|
||||
logit_bias (dictionary, optional): Additive logit bias.
|
||||
logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
||||
A list of functions that take tokens and logits and return the processed
|
||||
logits. Default: ``[]``.
|
||||
logits. Default: ``None``.
|
||||
|
||||
Yields:
|
||||
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
|
||||
@@ -212,18 +211,26 @@ def generate_step(
|
||||
raise ValueError(
|
||||
f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
|
||||
)
|
||||
|
||||
|
||||
logits_processor = logits_processor or []
|
||||
|
||||
if repetition_penalty:
|
||||
def repetition_penalty_processor(tokens: mx.array, logits: mx.array) -> mx.array:
|
||||
return apply_repetition_penalty(logits, tokens[-repetition_context_size:], repetition_penalty)
|
||||
|
||||
def repetition_penalty_processor(tokens, logits):
|
||||
return apply_repetition_penalty(
|
||||
logits, tokens[-repetition_context_size:], repetition_penalty
|
||||
)
|
||||
|
||||
logits_processor.append(repetition_penalty_processor)
|
||||
|
||||
if logit_bias:
|
||||
def logit_bias_processor(_: mx.array, logits: mx.array) -> mx.array:
|
||||
indices = mx.array(list(logit_bias.keys()))
|
||||
values = mx.array(list(logit_bias.values()))
|
||||
indices = mx.array(list(logit_bias.keys()))
|
||||
values = mx.array(list(logit_bias.values()))
|
||||
|
||||
def logit_bias_processor(_, logits):
|
||||
logits[:, indices] += values
|
||||
return logits
|
||||
|
||||
logits_processor.append(logit_bias_processor)
|
||||
|
||||
y = prompt
|
||||
@@ -249,7 +256,7 @@ def generate_step(
|
||||
if logits_processor:
|
||||
nonlocal tokens
|
||||
tokens = mx.concat([tokens, y]) if tokens is not None else y
|
||||
|
||||
|
||||
for processor in logits_processor:
|
||||
logits = processor(tokens, logits)
|
||||
|
||||
|
Reference in New Issue
Block a user