Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-04 07:44:34 +08:00)
Merge branch 'main' into feat/batch_generate
@@ -18,7 +18,7 @@ from mlx.utils import tree_flatten
 from transformers import PreTrainedTokenizer

 # Local imports
-from .models.base import KVCache, RotatingKVCache
+from .models import base, cache
 from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import dequantize as dequantize_model
@@ -124,26 +124,6 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float)
     return logits


-def make_kv_caches(
-    model: nn.Module, max_kv_size: Optional[int] = None
-) -> List[Union[KVCache, RotatingKVCache]]:
-    if hasattr(model, "make_cache"):
-        return model.make_cache()
-
-    kv_heads = (
-        [model.n_kv_heads] * len(model.layers)
-        if isinstance(model.n_kv_heads, int)
-        else model.n_kv_heads
-    )
-    if max_kv_size is not None:
-        return [
-            RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
-            for n in kv_heads
-        ]
-    else:
-        return [KVCache(model.head_dim, n) for n in kv_heads]
-
-
 def generate_step(
     prompts: mx.array,
     model: nn.Module,
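Note: make_kv_caches is removed here; the per-layer cache list it produced is now built by the cache module, which the generate_step hunk further down calls. A minimal sketch of the replacement path for external callers, assuming the public mlx_lm package layout and a placeholder checkpoint name:

from mlx_lm import load
from mlx_lm.models import cache

# Placeholder model; any mlx_lm-compatible checkpoint works here.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# One cache object per transformer layer; passing max_kv_size selects a
# rotating cache that overwrites old entries once the limit is reached.
prompt_cache = cache.make_prompt_cache(model, max_kv_size=1024)
assert len(prompt_cache) == len(model.layers)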
@@ -155,7 +135,7 @@ def generate_step(
     min_tokens_to_keep: int = 1,
     prefill_step_size: int = 512,
     max_kv_size: Optional[int] = None,
-    cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
+    prompt_cache: Optional[Any] = None,
     logit_bias: Optional[Dict[int, float]] = None,
     logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
@@ -180,6 +160,8 @@ def generate_step(
         prefill_step_size (int): Step size for processing the prompt.
         max_kv_size (int, optional): Maximum size of the key-value cache. Old
           entries (except the first 4 tokens) will be overwritten.
+        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
+          provided, the cache will be updated in place.
         logit_bias (dictionary, optional): Additive logit bias.
         logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
           A list of functions that take tokens and logits and return the processed
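Note: a rough usage sketch for the new prompt_cache argument, under the signature above. The checkpoint name and prompt are placeholders, and mlx_lm.utils.generate_step / mlx_lm.models.cache are assumed to be the import paths for the functions in this file:

import mlx.core as mx
from mlx_lm import load
from mlx_lm.models import cache
from mlx_lm.utils import generate_step

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# Build the cache up front and pass it in; generate_step updates it in place,
# so the processed prefix can be reused by a later call with the same cache.
prompt_cache = cache.make_prompt_cache(model)
prompt = mx.array(tokenizer.encode("You are a terse assistant."))[None]

steps = generate_step(prompt, model, prompt_cache=prompt_cache)
for _, (tokens, logprobs) in zip(range(32), steps):
    pass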
@@ -243,20 +225,13 @@ def generate_step(
     tokens = None

-    # Create the KV cache for generation
-    cache = make_kv_caches(model, max_kv_size)
-
-    if cache_history is not None:
-        if len(cache_history) != len(cache):
-            raise ValueError("Wrong number of layers in the cache history")
-
-        # Set the history in the cache objects and evaluate them to prepare for
-        # generation.
-        for c, h in zip(cache, cache_history):
-            c.update_and_fetch(h[0], h[1])
-        mx.eval([c.state for c in cache])
+    if prompt_cache is None:
+        prompt_cache = cache.make_prompt_cache(model, max_kv_size)
+    elif len(prompt_cache) != len(model.layers):
+        raise ValueError("Wrong number of layers in the prompt cache.")

     def _step(y):
-        logits = model(y, cache=cache)
+        logits = model(y, cache=prompt_cache)
         logits = logits[:, -1, :]

         if logits_processor:
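Note: for readers skimming the hunk, this is roughly what _step does once it is wired to prompt_cache: one forward pass through the cached model, keep only the last position, apply any logits processors, then sample. A simplified sketch with greedy sampling standing in for this file's temperature/top-p logic; everything beyond the model(y, cache=...) call convention is illustrative:

import mlx.core as mx

def step_sketch(model, y, prompt_cache, logits_processor=None, tokens=None):
    logits = model(y, cache=prompt_cache)  # (batch, seq, vocab)
    logits = logits[:, -1, :]              # keep next-token logits only
    if logits_processor:
        for processor in logits_processor:
            logits = processor(tokens, logits)
    logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
    y = mx.argmax(logprobs, axis=-1)       # greedy stand-in for sampling
    return y, logprobs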
@@ -270,7 +245,7 @@ def generate_step(
         return y, logprobs

     while y.shape[1] > prefill_step_size:
-        model(y[:, :prefill_step_size], cache=cache)
+        model(y[:, :prefill_step_size], cache=prompt_cache)
         mx.eval([c.state for c in cache])
         y = y[:, prefill_step_size:]

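Note: the prefill loop above pushes the prompt through the model in chunks so the KV cache fills without materializing logits for the entire prompt at once. A standalone sketch of the same idea, assuming the per-layer cache list is named prompt_cache:

import mlx.core as mx

def prefill(model, y, prompt_cache, prefill_step_size=512):
    # Process the prompt prefill_step_size tokens at a time.
    while y.shape[1] > prefill_step_size:
        model(y[:, :prefill_step_size], cache=prompt_cache)
        mx.eval([c.state for c in prompt_cache])  # force the lazy cache updates
        y = y[:, prefill_step_size:]
    return y  # the remaining tokens feed the first sampling step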
@@ -312,9 +287,9 @@ def stream_generate(
     detokenizer = tokenizer.detokenizer

     detokenizer.reset()
-    for (token, _), n in zip(
-        generate_step(prompt_tokens[None], model, **kwargs),
+    for _, (token, _) in zip(
         range(max_tokens),
+        generate_step(prompt_tokens, model, **kwargs),
     ):
         token = token.item()
         if token == tokenizer.eos_token_id:
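Note: putting range(max_tokens) first in the zip is not just cosmetic. zip draws from its iterables left to right, so once the range is exhausted iteration stops without pulling (and computing) one extra token from generate_step; with the generator first, an extra step would run and be discarded. A tiny self-contained illustration of the pattern:

def counter():
    n = 0
    while True:
        n += 1
        yield n

gen = counter()
taken = [item for _, item in zip(range(3), gen)]
assert taken == [1, 2, 3]
assert next(gen) == 4  # the generator advanced exactly three times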
@@ -365,9 +340,9 @@ def generate(
     tic = time.perf_counter()
     detokenizer.reset()

-    for (token, logprobs), n in zip(
-        generate_step(prompt_tokens[None], model, **kwargs),
+    for n, (token, logprobs) in zip(
         range(max_tokens),
+        generate_step(prompt_tokens[None], model, **kwargs),
     ):
         token = token.item()
         if n == 0:
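Note: the prompt_tokens[None] indexing reflects that generate_step now expects a batch of prompts shaped (batch, seq_len); a single encoded prompt is promoted to a batch of one. A minimal shape check with placeholder token ids:

import mlx.core as mx

prompt_tokens = mx.array([1, 2, 3, 4])  # shape (seq_len,)
batched = prompt_tokens[None]           # shape (1, seq_len): a batch of one
assert tuple(batched.shape) == (1, 4)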