Merge branch 'main' into feat/batch_generate

Author: Lllvvuu
Date:   2024-10-09 15:19:22 -04:00
48 changed files with 1152 additions and 802 deletions


@@ -18,7 +18,7 @@ from mlx.utils import tree_flatten
 from transformers import PreTrainedTokenizer

 # Local imports
-from .models.base import KVCache, RotatingKVCache
+from .models import base, cache
 from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import dequantize as dequantize_model
@@ -124,26 +124,6 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float)
     return logits


-def make_kv_caches(
-    model: nn.Module, max_kv_size: Optional[int] = None
-) -> List[Union[KVCache, RotatingKVCache]]:
-    if hasattr(model, "make_cache"):
-        return model.make_cache()
-    kv_heads = (
-        [model.n_kv_heads] * len(model.layers)
-        if isinstance(model.n_kv_heads, int)
-        else model.n_kv_heads
-    )
-    if max_kv_size is not None:
-        return [
-            RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
-            for n in kv_heads
-        ]
-    else:
-        return [KVCache(model.head_dim, n) for n in kv_heads]
 def generate_step(
     prompts: mx.array,
     model: nn.Module,
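The helper removed above is superseded by the `cache` module imported in the first hunk; the `generate_step` changes further down call `cache.make_prompt_cache(model, max_kv_size)` instead. A minimal sketch of the replacement call, assuming an installed `mlx_lm` package and its `load` convenience function; the model path is a placeholder:

```python
from mlx_lm import load              # assumed convenience loader
from mlx_lm.models import cache      # module imported in the hunk above

model, tokenizer = load("path/to/mlx-model")  # placeholder path

# Unbounded per-layer caches (the old KVCache branch of make_kv_caches):
prompt_cache = cache.make_prompt_cache(model, None)

# Size-bounded caches that overwrite old entries while keeping the first
# 4 tokens (the old RotatingKVCache branch with keep=4):
rotating_cache = cache.make_prompt_cache(model, 1024)
```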
@@ -155,7 +135,7 @@ def generate_step(
     min_tokens_to_keep: int = 1,
     prefill_step_size: int = 512,
     max_kv_size: Optional[int] = None,
-    cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
+    prompt_cache: Optional[Any] = None,
     logit_bias: Optional[Dict[int, float]] = None,
     logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
@@ -180,6 +160,8 @@ def generate_step(
         prefill_step_size (int): Step size for processing the prompt.
         max_kv_size (int, optional): Maximum size of the key-value cache. Old
           entries (except the first 4 tokens) will be overwritten.
+        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
+          provided, the cache will be updated in place.
         logit_bias (dictionary, optional): Additive logit bias.
         logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
           A list of functions that take tokens and logits and return the processed
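Because the cache is updated in place (per the new docstring lines), a caller can build it once and keep reusing it across `generate_step` calls, so later calls can build on already-processed context. A minimal sketch under the same assumptions as above (installed `mlx_lm`, placeholder model path); the exact batching behavior of this branch's `generate_step` is not shown in this diff and may differ:

```python
import mlx.core as mx
from mlx_lm import load                  # assumed loader
from mlx_lm.models import cache
from mlx_lm.utils import generate_step   # function changed in this diff

model, tokenizer = load("path/to/mlx-model")   # placeholder path
prompt_cache = cache.make_prompt_cache(model, None)

prompt = mx.array(tokenizer.encode("Hello"))[None]  # add a batch dimension
for _, (tokens, _) in zip(range(32), generate_step(prompt, model, prompt_cache=prompt_cache)):
    pass  # detokenize / collect tokens here

# generate_step has filled prompt_cache in place, so a follow-up call with only
# the new tokens can reuse the cached context instead of re-running the prefill.
```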
@@ -243,20 +225,13 @@ def generate_step(
     tokens = None
-    # Create the KV cache for generation
-    cache = make_kv_caches(model, max_kv_size)
-    if cache_history is not None:
-        if len(cache_history) != len(cache):
-            raise ValueError("Wrong number of layers in the cache history")
-        # Set the history in the cache objects and evaluate them to prepare for
-        # generation.
-        for c, h in zip(cache, cache_history):
-            c.update_and_fetch(h[0], h[1])
-        mx.eval([c.state for c in cache])
+    if prompt_cache is None:
+        prompt_cache = cache.make_prompt_cache(model, max_kv_size)
+    elif len(prompt_cache) != len(model.layers):
+        raise ValueError("Wrong number of layers in the prompt cache.")

     def _step(y):
-        logits = model(y, cache=cache)
+        logits = model(y, cache=prompt_cache)
         logits = logits[:, -1, :]
         if logits_processor:
@@ -270,7 +245,7 @@ def generate_step(
         return y, logprobs

     while y.shape[1] > prefill_step_size:
-        model(y[:, :prefill_step_size], cache=cache)
+        model(y[:, :prefill_step_size], cache=prompt_cache)
         mx.eval([c.state for c in cache])
         y = y[:, prefill_step_size:]
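For context on the loop above: the prompt is pushed through the model in `prefill_step_size`-sized slices, so each forward pass stays bounded while the KV cache accumulates the full prompt (the real loop also forces evaluation of the cache state between chunks, as the `mx.eval` line shows). A framework-free sketch of that pattern:

```python
def chunked_prefill(model, y, prompt_cache, prefill_step_size=512):
    """Feed a (batch, length) token array through `model` chunk by chunk.

    `model` is assumed to write key/value entries into `prompt_cache` as a
    side effect; the leftover tail (shorter than one chunk) is returned so
    the caller's first generation step can process it.
    """
    while y.shape[1] > prefill_step_size:
        model(y[:, :prefill_step_size], cache=prompt_cache)
        y = y[:, prefill_step_size:]
    return y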
@@ -312,9 +287,9 @@ def stream_generate(
     detokenizer = tokenizer.detokenizer
     detokenizer.reset()
-    for (token, _), n in zip(
-        generate_step(prompt_tokens[None], model, **kwargs),
+    for _, (token, _) in zip(
         range(max_tokens),
+        generate_step(prompt_tokens, model, **kwargs),
     ):
         token = token.item()
         if token == tokenizer.eos_token_id:
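One concrete effect of swapping the `zip` arguments above: `zip` draws from its iterables left to right, so when `range(max_tokens)` comes first it is exhausted before the generator is asked for a surplus item; with the generator first, one extra generation step would be computed and then discarded. A self-contained illustration:

```python
def fake_generate_step():
    """Stands in for generate_step: an endless stream of (token, logprobs)."""
    t = 0
    while True:
        yield t, None
        t += 1

max_tokens = 4

# range() first: the generator is advanced exactly max_tokens times.
for _, (token, _) in zip(range(max_tokens), fake_generate_step()):
    print(token)   # 0, 1, 2, 3
```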
@@ -365,9 +340,9 @@ def generate(
     tic = time.perf_counter()
     detokenizer.reset()
-    for (token, logprobs), n in zip(
-        generate_step(prompt_tokens[None], model, **kwargs),
+    for n, (token, logprobs) in zip(
         range(max_tokens),
+        generate_step(prompt_tokens[None], model, **kwargs),
     ):
         token = token.item()
         if n == 0:
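The `n == 0` branch (its body is not shown in this hunk) is the natural point to split timing: everything from `tic` up to the first yielded token covers prompt processing, and the rest covers token generation. A hedged sketch of that pattern with illustrative names (`prompt_time` and `generation_time` are not from this diff):

```python
import time

def timed_loop(step_iter, max_tokens):
    prompt_time = 0.0
    tic = time.perf_counter()
    for n, (token, logprobs) in zip(range(max_tokens), step_iter):
        if n == 0:
            # First token out: the elapsed time so far covers the prompt.
            prompt_time = time.perf_counter() - tic
            tic = time.perf_counter()  # restart the clock for generation
        # ... detokenize / collect token here ...
    generation_time = time.perf_counter() - tic
    return prompt_time, generation_time
```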