some cleanup, warnings, tests

Awni Hannun 2024-11-22 16:51:22 -08:00
parent 9986787303
commit f82e49aad9
9 changed files with 57 additions and 61 deletions

View File

@@ -100,8 +100,9 @@ To see a description of all the arguments you can do:
#### Streaming
For streaming generation, use the `stream_generate` function. This returns a
generator object which streams the output text, token, and log probabilities.
For streaming generation, use the `stream_generate` function. This yields
a generation response object.
For example,
```python
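# A minimal sketch of the streaming loop after this change; the model repo
# and prompt handling below are illustrative assumptions, not part of the diff.
from mlx_lm import load, stream_generate

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

messages = [{"role": "user", "content": "Write a story about Einstein"}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

for response in stream_generate(model, tokenizer, prompt, max_tokens=512):
    # Each yielded response object carries the newly generated text segment.
    print(response.text, end="", flush=True)
print()
```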

View File

@@ -6,6 +6,7 @@ import json
import mlx.core as mx
from .models.cache import make_prompt_cache
from .sample_utils import make_sampler
from .utils import load, stream_generate
DEFAULT_TEMP = 0.0
@@ -79,8 +80,7 @@ def main():
tokenizer,
prompt,
args.max_tokens,
temp=args.temp,
top_p=args.top_p,
sampler=make_sampler(args.temp, args.top_p),
prompt_cache=prompt_cache,
):
print(response.text, flush=True, end="")
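For reference, a minimal standalone sketch of the calling convention used above: sampling settings are bundled into a sampler via `make_sampler` and handed to `stream_generate` together with a reusable prompt cache. The package-level import paths and the model repo are assumptions based on the relative imports shown in this diff.

```python
from mlx_lm import load, stream_generate
from mlx_lm.models.cache import make_prompt_cache
from mlx_lm.sample_utils import make_sampler

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# A reusable KV cache, as in the chat command above (max_kv_size left at its default).
prompt_cache = make_prompt_cache(model)

# Temperature and top-p now live in the sampler rather than being keyword arguments.
sampler = make_sampler(0.7, 0.9)

messages = [{"role": "user", "content": "Tell me a joke."}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

for response in stream_generate(
    model,
    tokenizer,
    prompt,
    max_tokens=256,
    sampler=sampler,
    prompt_cache=prompt_cache,
):
    print(response.text, flush=True, end="")
print()
```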

View File

@@ -42,7 +42,6 @@ response = generate(
tokenizer,
prompt=prompt,
verbose=True,
temp=0.0,
prompt_cache=prompt_cache,
)
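Since `generate` no longer takes `temp` (greedy decoding is the default), a non-zero temperature would be restored by passing a sampler instead. A hedged fragment, reusing `model`, `tokenizer`, `prompt`, and `prompt_cache` from the surrounding example; the import path is assumed from this diff.

```python
from mlx_lm.sample_utils import make_sampler

# make_sampler accepts just a temperature, as the server code in this commit does.
response = generate(
    model,
    tokenizer,
    prompt=prompt,
    verbose=True,
    sampler=make_sampler(0.7),
    prompt_cache=prompt_cache,
)
```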

View File

@@ -23,14 +23,6 @@ max_tokens = 1_000
# Specify if tokens and timing information will be printed
verbose = True
# Some optional arguments for causal language model generation
generation_args = {
"temp": 0.7,
"repetition_penalty": 1.2,
"repetition_context_size": 20,
"top_p": 0.95,
}
# Generate a response with the specified settings
response = generate(
model=model,
@@ -38,5 +30,4 @@ response = generate(
prompt=prompt,
max_tokens=max_tokens,
verbose=verbose,
**generation_args,
)
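The removed `generation_args` keywords map onto the new helpers: temperature and top-p go through `make_sampler`, while the repetition penalty becomes a logits processor. A hedged sketch of a roughly equivalent call; that `generate` forwards `sampler` and `logits_processors` is an assumption based on the new `generate_step` signature later in this commit.

```python
from mlx_lm.sample_utils import make_logits_processors, make_sampler

response = generate(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    max_tokens=max_tokens,
    verbose=verbose,
    sampler=make_sampler(0.7, 0.95),  # temp, top_p
    logits_processors=make_logits_processors(None, 1.2, 20),  # no logit bias, repetition penalty, context size
)
```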

View File

@@ -7,6 +7,7 @@ import sys
import mlx.core as mx
from .models.cache import QuantizedKVCache, load_prompt_cache
from .sample_utils import make_sampler
from .utils import generate, load
DEFAULT_PROMPT = "hello"
@@ -218,16 +219,14 @@ def main():
else:
prompt = args.prompt
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
response = generate(
model,
tokenizer,
prompt,
max_tokens=args.max_tokens,
verbose=args.verbose,
temp=args.temp,
top_p=args.top_p,
min_p=args.min_p,
min_tokens_to_keep=args.min_tokens_to_keep,
sampler=sampler,
max_kv_size=args.max_kv_size,
prompt_cache=prompt_cache if using_cache else None,
kv_bits=args.kv_bits,
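For context, the sampler built above takes its arguments positionally; their meanings below are taken from the `generate_step` docstring removed later in this commit, and the concrete values are only illustrative.

```python
from mlx_lm.sample_utils import make_sampler

sampler = make_sampler(
    0.6,   # temp: 0 means argmax (greedy) decoding
    0.9,   # top_p: nucleus sampling threshold
    0.05,  # min_p: minimum probability, scaled by the top token's probability
    1,     # min_tokens_to_keep: tokens that min_p sampling may not filter out
)
```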

View File

@@ -27,6 +27,7 @@ from huggingface_hub import scan_cache_dir
from ._version import __version__
from .models.cache import make_prompt_cache
from .sample_utils import make_logits_processors, make_sampler
from .utils import load, stream_generate
@@ -464,15 +465,17 @@ class APIHandler(BaseHTTPRequestHandler):
text = ""
tic = time.perf_counter()
sampler = make_sampler(self.temperature)
logits_processors = make_logits_processors(
self.logit_bias, self.repetition_penalty, self.repetition_context_size
)
for gen_response in stream_generate(
model=self.model,
tokenizer=self.tokenizer,
prompt=prompt,
max_tokens=self.max_tokens,
temp=self.temperature,
repetition_penalty=self.repetition_penalty,
repetition_context_size=self.repetition_context_size,
logit_bias=self.logit_bias,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=self.prompt_cache.cache,
):
segment = gen_response.text
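A logits processor is just a callable that takes the tokens generated so far and the current logits and returns modified logits, per the `generate_step` annotation later in this commit. A hedged sketch of a hand-written one that could be appended to the list built by `make_logits_processors`; the helper name and the token id are illustrative.

```python
import mlx.core as mx

def make_block_token_processor(token_id: int):
    """Return a processor that forbids `token_id` by masking out its logit."""

    def processor(tokens: mx.array, logits: mx.array) -> mx.array:
        mask = mx.arange(logits.shape[-1]) == token_id
        return mx.where(mask, float("-inf"), logits)

    return processor

# e.g. logits_processors = make_logits_processors(None, 1.5, 20) + [make_block_token_processor(13)]
```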

View File

@@ -73,16 +73,16 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
def reset(self):
self.offset = 0
self._tokens = []
self.tokens = []
self._text = ""
self._current_tokens = []
self._current_text = ""
def add_token(self, token):
self._current_tokens.append(token)
self.tokens.append(token)
def finalize(self):
self._tokens.extend(self._current_tokens)
self._text += self._tokenizer.decode(self._current_tokens)
self._current_tokens = []
self._current_text = ""
@@ -97,16 +97,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
):
self._current_text = self._current_text[:-1]
if self._current_text and self._current_text[-1] == "\n":
self._tokens.extend(self._current_tokens)
self._text += self._current_text
self._current_tokens.clear()
self._current_text = ""
return self._text + self._current_text
@property
def tokens(self):
return self._tokens
class SPMStreamingDetokenizer(StreamingDetokenizer):
"""A streaming detokenizer for SPM models.
@@ -143,6 +138,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
self.text += text
def add_token(self, token):
self.tokens.append(token)
v = self.tokenmap[token]
if v.startswith(self._sep):
self._flush()
@@ -200,6 +196,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
return current_text
def add_token(self, token):
self.tokens.append(token)
v = self.tokenmap[token]
is_added = token in self._added_ids
if is_added or self._byte_decoder[v[0]] == 32:
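With this change every streaming detokenizer keeps `tokens` as a plain list that `add_token` appends to, instead of the naive implementation exposing it through a property. A small usage sketch; the model repo is an assumption, and the access pattern mirrors the updated test at the end of this commit.

```python
from mlx_lm import load

_, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

detokenizer = tokenizer.detokenizer
detokenizer.reset()

token_ids = tokenizer.encode("hello world")
for t in token_ids:
    detokenizer.add_token(t)
    # Text that became final with this token; may be empty for partial words.
    print(detokenizer.last_segment, end="")

detokenizer.finalize()
print(detokenizer.last_segment)

# After this commit the raw ids are available directly on any detokenizer.
assert detokenizer.tokens == token_ids
```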

View File

@@ -182,20 +182,21 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
def generate_step(
prompt: mx.array,
model: nn.Module,
temp: float = 0.0,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = 20,
top_p: float = 1.0,
min_p: float = 0.0,
min_tokens_to_keep: int = 1,
prefill_step_size: int = 512,
*,
sampler: Optional[Callable[mx.array, mx.array]] = None,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
max_kv_size: Optional[int] = None,
prompt_cache: Optional[Any] = None,
logit_bias: Optional[Dict[int, float]] = None,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
prefill_step_size: int = 512,
kv_bits: Optional[int] = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
temp: Optional[float] = None,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = None,
top_p: Optional[float] = None,
min_p: Optional[float] = None,
min_tokens_to_keep: Optional[int] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
@@ -203,32 +204,21 @@ def generate_step(
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
temp (float): The temperature for sampling, if 0 the argmax is used.
Default: ``0``.
repetition_penalty (float, optional): The penalty factor for repeating
tokens.
repetition_context_size (int, optional): The number of tokens to
consider for repetition penalty. Default: ``20``.
top_p (float, optional): Nulceus sampling, higher means model considers
more less likely words.
min_p (float, optional): The minimum value (scaled by the top token's
probability) that a token probability must have to be considered.
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered by min_p sampling.
prefill_step_size (int): Step size for processing the prompt.
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place.
logit_bias (dictionary, optional): Additive logit bias.
sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
token from a vector of log probabilities. Default: ``None``.
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
kv_bits (int, optional): Number of bits to use for KV cache quantization.
None implies no cache quantization. Default: ``None``.
None implies no cache quantization. Default: ``None``.
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
quantized_kv_start (int): Step to begin using a quantized KV cache.
when ``kv_bits`` is non-None. Default: ``0``.
when ``kv_bits`` is non-None. Default: ``0``.
Yields:
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
@@ -246,10 +236,22 @@ def generate_step(
elif len(prompt_cache) != len(model.layers):
raise ValueError("Wrong number of layers in the prompt cache.")
sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
logits_processors = logits_processors or []
logits_processors.extend(
make_logits_processors(logit_bias, repetition_penalty, repetition_context_size)
if temp is not None or top_p is not None or min_tokens_to_keep is not None:
print(
"[Warning] Specifying sampling arguments to ``generate_step`` is "
"deprecated. Pass in a ``sampler`` instead."
)
if repetition_penalty is not None:
print(
"[Warning] Specifying ``repetition_penalty`` is deprecated. "
"Pass in ``logits_processors`` instead."
)
sampler = sampler or make_sampler(
temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
)
logits_processors = logits_processors or make_logits_processors(
None, repetition_penalty, repetition_context_size or 20
)
def _step(y):
@@ -385,7 +387,10 @@ def generate(
See :func:`stream_generate` for more details.
"""
if formatter is not None:
print("Text formatting is deprecated and will be removed in the next version.")
print(
"[Warning] Text formatting is deprecated and no longer used. "
"The argument will be removed in a future version."
)
if verbose:
print("=" * 10)
print("Prompt:", prompt)

View File

@@ -34,10 +34,11 @@ class TestTokenizers(unittest.TestCase):
detokenizer = tokenizer.detokenizer
detokenizer.reset()
text = ""
for t in tokens:
for e, t in enumerate(tokens):
detokenizer.add_token(t)
seg = detokenizer.last_segment
text += seg
self.assertEqual(detokenizer.tokens, tokens[: e + 1])
detokenizer.finalize()
text += detokenizer.last_segment
self.assertEqual(text, expected_text)