mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Add logits_processor option to generate_step function (#983)
* Add logits_processor option for the generation as in huggingface transformers library * concatenation correction * Rename the tokens variable for clarity * remove the logit_bias argument from generate_step method * fix the variable name * nits + test * test * add back logit bias + test --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
d812516d3d
commit
ace2bb5890
@ -154,10 +154,11 @@ def generate_step(
|
|||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
min_p: float = 0.0,
|
min_p: float = 0.0,
|
||||||
min_tokens_to_keep: int = 1,
|
min_tokens_to_keep: int = 1,
|
||||||
logit_bias: Optional[Dict[int, float]] = None,
|
|
||||||
prefill_step_size: int = 512,
|
prefill_step_size: int = 512,
|
||||||
max_kv_size: Optional[int] = None,
|
max_kv_size: Optional[int] = None,
|
||||||
cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
|
cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
|
||||||
|
logit_bias: Optional[Dict[int, float]] = None,
|
||||||
|
logits_processor: Optional[Callable[[mx.array, mx.array], mx.array]] = None,
|
||||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||||
"""
|
"""
|
||||||
A generator producing token ids based on the given prompt from the model.
|
A generator producing token ids based on the given prompt from the model.
|
||||||
@ -177,10 +178,13 @@ def generate_step(
|
|||||||
probability) that a token probability must have to be considered.
|
probability) that a token probability must have to be considered.
|
||||||
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
|
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
|
||||||
be filtered by min_p sampling.
|
be filtered by min_p sampling.
|
||||||
logit_bias (dictionary, optional): Additive logit bias.
|
|
||||||
prefill_step_size (int): Step size for processing the prompt.
|
prefill_step_size (int): Step size for processing the prompt.
|
||||||
max_kv_size (int, optional): Maximum size of the key-value cache. Old
|
max_kv_size (int, optional): Maximum size of the key-value cache. Old
|
||||||
entries (except the first 4 tokens) will be overwritten.
|
entries (except the first 4 tokens) will be overwritten.
|
||||||
|
logit_bias (dictionary, optional): Additive logit bias.
|
||||||
|
logits_processor (Callable[[mx.array, mx.array], mx.array], optional):
|
||||||
|
A function that takes tokens and logits and returns the processed
|
||||||
|
logits. Default: ``None``.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
|
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
|
||||||
@ -188,10 +192,6 @@ def generate_step(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def sample(logits: mx.array) -> Tuple[mx.array, float]:
|
def sample(logits: mx.array) -> Tuple[mx.array, float]:
|
||||||
if logit_bias:
|
|
||||||
indices = mx.array(list(logit_bias.keys()))
|
|
||||||
values = mx.array(list(logit_bias.values()))
|
|
||||||
logits[:, indices] += values
|
|
||||||
logprobs = logits - mx.logsumexp(logits)
|
logprobs = logits - mx.logsumexp(logits)
|
||||||
|
|
||||||
if temp == 0:
|
if temp == 0:
|
||||||
@ -214,6 +214,7 @@ def generate_step(
|
|||||||
)
|
)
|
||||||
|
|
||||||
y = prompt
|
y = prompt
|
||||||
|
tokens = None
|
||||||
|
|
||||||
# Create the KV cache for generation
|
# Create the KV cache for generation
|
||||||
cache = make_kv_caches(model, max_kv_size)
|
cache = make_kv_caches(model, max_kv_size)
|
||||||
@ -233,11 +234,23 @@ def generate_step(
|
|||||||
if repetition_context_size:
|
if repetition_context_size:
|
||||||
repetition_context = repetition_context[-repetition_context_size:]
|
repetition_context = repetition_context[-repetition_context_size:]
|
||||||
|
|
||||||
|
if logit_bias:
|
||||||
|
indices = mx.array(list(logit_bias.keys()))
|
||||||
|
values = mx.array(list(logit_bias.values()))
|
||||||
|
|
||||||
def _step(y):
|
def _step(y):
|
||||||
nonlocal repetition_context
|
nonlocal repetition_context
|
||||||
logits = model(y[None], cache=cache)
|
logits = model(y[None], cache=cache)
|
||||||
logits = logits[:, -1, :]
|
logits = logits[:, -1, :]
|
||||||
|
|
||||||
|
if logits_processor:
|
||||||
|
nonlocal tokens
|
||||||
|
tokens = mx.concat([tokens, y]) if tokens is not None else y
|
||||||
|
logits = logits_processor(tokens, logits)
|
||||||
|
|
||||||
|
if logit_bias:
|
||||||
|
logits[:, indices] += values
|
||||||
|
|
||||||
if repetition_penalty:
|
if repetition_penalty:
|
||||||
logits = apply_repetition_penalty(
|
logits = apply_repetition_penalty(
|
||||||
logits, repetition_context, repetition_penalty
|
logits, repetition_context, repetition_penalty
|
||||||
|
55
llms/tests/test_generate.py
Normal file
55
llms/tests/test_generate.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from mlx_lm.utils import generate, load
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenerate(unittest.TestCase):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
||||||
|
cls.model, cls.tokenizer = load(HF_MODEL_PATH)
|
||||||
|
|
||||||
|
def test_generate(self):
|
||||||
|
# Simple test that generation runs
|
||||||
|
text = generate(
|
||||||
|
self.model, self.tokenizer, "hello", max_tokens=5, verbose=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_generate_with_logit_bias(self):
|
||||||
|
logit_bias = {0: 2000.0, 1: -20.0}
|
||||||
|
text = generate(
|
||||||
|
self.model,
|
||||||
|
self.tokenizer,
|
||||||
|
"hello",
|
||||||
|
max_tokens=5,
|
||||||
|
verbose=False,
|
||||||
|
logit_bias=logit_bias,
|
||||||
|
)
|
||||||
|
self.assertEqual(text, "!!!!!")
|
||||||
|
|
||||||
|
def test_generate_with_processor(self):
|
||||||
|
init_toks = self.tokenizer.encode("hello")
|
||||||
|
|
||||||
|
all_toks = None
|
||||||
|
|
||||||
|
def logits_processor(toks, logits):
|
||||||
|
nonlocal all_toks
|
||||||
|
all_toks = toks
|
||||||
|
return logits
|
||||||
|
|
||||||
|
generate(
|
||||||
|
self.model,
|
||||||
|
self.tokenizer,
|
||||||
|
"hello",
|
||||||
|
max_tokens=5,
|
||||||
|
verbose=False,
|
||||||
|
logits_processor=logits_processor,
|
||||||
|
)
|
||||||
|
self.assertEqual(len(all_toks), len(init_toks) + 5)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
Loading…
Reference in New Issue
Block a user