some cleanup, warnings, tests

This commit is contained in:
Awni Hannun 2024-11-22 16:51:22 -08:00
parent 9986787303
commit f82e49aad9
9 changed files with 57 additions and 61 deletions

View File

@ -100,8 +100,9 @@ To see a description of all the arguments you can do:
#### Streaming #### Streaming
For streaming generation, use the `stream_generate` function. This returns a For streaming generation, use the `stream_generate` function. This yields
generator object which streams the output text, token, and log probabilities. a generation response object.
For example, For example,
```python ```python

View File

@ -6,6 +6,7 @@ import json
import mlx.core as mx import mlx.core as mx
from .models.cache import make_prompt_cache from .models.cache import make_prompt_cache
from .sample_utils import make_sampler
from .utils import load, stream_generate from .utils import load, stream_generate
DEFAULT_TEMP = 0.0 DEFAULT_TEMP = 0.0
@ -79,8 +80,7 @@ def main():
tokenizer, tokenizer,
prompt, prompt,
args.max_tokens, args.max_tokens,
temp=args.temp, sampler=make_sampler(args.temp, args.top_p),
top_p=args.top_p,
prompt_cache=prompt_cache, prompt_cache=prompt_cache,
): ):
print(response.text, flush=True, end="") print(response.text, flush=True, end="")

View File

@ -42,7 +42,6 @@ response = generate(
tokenizer, tokenizer,
prompt=prompt, prompt=prompt,
verbose=True, verbose=True,
temp=0.0,
prompt_cache=prompt_cache, prompt_cache=prompt_cache,
) )

View File

@ -23,14 +23,6 @@ max_tokens = 1_000
# Specify if tokens and timing information will be printed # Specify if tokens and timing information will be printed
verbose = True verbose = True
# Some optional arguments for causal language model generation
generation_args = {
"temp": 0.7,
"repetition_penalty": 1.2,
"repetition_context_size": 20,
"top_p": 0.95,
}
# Generate a response with the specified settings # Generate a response with the specified settings
response = generate( response = generate(
model=model, model=model,
@ -38,5 +30,4 @@ response = generate(
prompt=prompt, prompt=prompt,
max_tokens=max_tokens, max_tokens=max_tokens,
verbose=verbose, verbose=verbose,
**generation_args,
) )

View File

@ -7,6 +7,7 @@ import sys
import mlx.core as mx import mlx.core as mx
from .models.cache import QuantizedKVCache, load_prompt_cache from .models.cache import QuantizedKVCache, load_prompt_cache
from .sample_utils import make_sampler
from .utils import generate, load from .utils import generate, load
DEFAULT_PROMPT = "hello" DEFAULT_PROMPT = "hello"
@ -218,16 +219,14 @@ def main():
else: else:
prompt = args.prompt prompt = args.prompt
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
response = generate( response = generate(
model, model,
tokenizer, tokenizer,
prompt, prompt,
max_tokens=args.max_tokens, max_tokens=args.max_tokens,
verbose=args.verbose, verbose=args.verbose,
temp=args.temp, sampler=sampler,
top_p=args.top_p,
min_p=args.min_p,
min_tokens_to_keep=args.min_tokens_to_keep,
max_kv_size=args.max_kv_size, max_kv_size=args.max_kv_size,
prompt_cache=prompt_cache if using_cache else None, prompt_cache=prompt_cache if using_cache else None,
kv_bits=args.kv_bits, kv_bits=args.kv_bits,

View File

@ -27,6 +27,7 @@ from huggingface_hub import scan_cache_dir
from ._version import __version__ from ._version import __version__
from .models.cache import make_prompt_cache from .models.cache import make_prompt_cache
from .sample_utils import make_logits_processors, make_sampler
from .utils import load, stream_generate from .utils import load, stream_generate
@ -464,15 +465,17 @@ class APIHandler(BaseHTTPRequestHandler):
text = "" text = ""
tic = time.perf_counter() tic = time.perf_counter()
sampler = make_sampler(self.temperature)
logits_processors = make_logits_processors(
self.logit_bias, self.repetition_penalty, self.repetition_context_size
)
for gen_response in stream_generate( for gen_response in stream_generate(
model=self.model, model=self.model,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
prompt=prompt, prompt=prompt,
max_tokens=self.max_tokens, max_tokens=self.max_tokens,
temp=self.temperature, sampler=sampler,
repetition_penalty=self.repetition_penalty, logits_processors=logits_processors,
repetition_context_size=self.repetition_context_size,
logit_bias=self.logit_bias,
prompt_cache=self.prompt_cache.cache, prompt_cache=self.prompt_cache.cache,
): ):
segment = gen_response.text segment = gen_response.text

View File

@ -73,16 +73,16 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
def reset(self): def reset(self):
self.offset = 0 self.offset = 0
self._tokens = [] self.tokens = []
self._text = "" self._text = ""
self._current_tokens = [] self._current_tokens = []
self._current_text = "" self._current_text = ""
def add_token(self, token): def add_token(self, token):
self._current_tokens.append(token) self._current_tokens.append(token)
self.tokens.append(token)
def finalize(self): def finalize(self):
self._tokens.extend(self._current_tokens)
self._text += self._tokenizer.decode(self._current_tokens) self._text += self._tokenizer.decode(self._current_tokens)
self._current_tokens = [] self._current_tokens = []
self._current_text = "" self._current_text = ""
@ -97,16 +97,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
): ):
self._current_text = self._current_text[:-1] self._current_text = self._current_text[:-1]
if self._current_text and self._current_text[-1] == "\n": if self._current_text and self._current_text[-1] == "\n":
self._tokens.extend(self._current_tokens)
self._text += self._current_text self._text += self._current_text
self._current_tokens.clear() self._current_tokens.clear()
self._current_text = "" self._current_text = ""
return self._text + self._current_text return self._text + self._current_text
@property
def tokens(self):
return self._tokens
class SPMStreamingDetokenizer(StreamingDetokenizer): class SPMStreamingDetokenizer(StreamingDetokenizer):
"""A streaming detokenizer for SPM models. """A streaming detokenizer for SPM models.
@ -143,6 +138,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
self.text += text self.text += text
def add_token(self, token): def add_token(self, token):
self.tokens.append(token)
v = self.tokenmap[token] v = self.tokenmap[token]
if v.startswith(self._sep): if v.startswith(self._sep):
self._flush() self._flush()
@ -200,6 +196,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
return current_text return current_text
def add_token(self, token): def add_token(self, token):
self.tokens.append(token)
v = self.tokenmap[token] v = self.tokenmap[token]
is_added = token in self._added_ids is_added = token in self._added_ids
if is_added or self._byte_decoder[v[0]] == 32: if is_added or self._byte_decoder[v[0]] == 32:

View File

@ -182,20 +182,21 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
def generate_step( def generate_step(
prompt: mx.array, prompt: mx.array,
model: nn.Module, model: nn.Module,
temp: float = 0.0, *,
repetition_penalty: Optional[float] = None, sampler: Optional[Callable[mx.array, mx.array]] = None,
repetition_context_size: Optional[int] = 20, logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
top_p: float = 1.0,
min_p: float = 0.0,
min_tokens_to_keep: int = 1,
prefill_step_size: int = 512,
max_kv_size: Optional[int] = None, max_kv_size: Optional[int] = None,
prompt_cache: Optional[Any] = None, prompt_cache: Optional[Any] = None,
logit_bias: Optional[Dict[int, float]] = None, prefill_step_size: int = 512,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
kv_bits: Optional[int] = None, kv_bits: Optional[int] = None,
kv_group_size: int = 64, kv_group_size: int = 64,
quantized_kv_start: int = 0, quantized_kv_start: int = 0,
temp: Optional[float] = None,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = None,
top_p: Optional[float] = None,
min_p: Optional[float] = None,
min_tokens_to_keep: Optional[int] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]: ) -> Generator[Tuple[mx.array, mx.array], None, None]:
""" """
A generator producing token ids based on the given prompt from the model. A generator producing token ids based on the given prompt from the model.
@ -203,32 +204,21 @@ def generate_step(
Args: Args:
prompt (mx.array): The input prompt. prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation. model (nn.Module): The model to use for generation.
temp (float): The temperature for sampling, if 0 the argmax is used.
Default: ``0``.
repetition_penalty (float, optional): The penalty factor for repeating
tokens.
repetition_context_size (int, optional): The number of tokens to
consider for repetition penalty. Default: ``20``.
top_p (float, optional): Nucleus sampling, higher means model considers
more less likely words.
min_p (float, optional): The minimum value (scaled by the top token's
probability) that a token probability must have to be considered.
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered by min_p sampling.
prefill_step_size (int): Step size for processing the prompt. prefill_step_size (int): Step size for processing the prompt.
max_kv_size (int, optional): Maximum size of the key-value cache. Old max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten. entries (except the first 4 tokens) will be overwritten.
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place. provided, the cache will be updated in place.
logit_bias (dictionary, optional): Additive logit bias. sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
token from a vector of log probabilities. Default: ``None``.
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional): logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed A list of functions that take tokens and logits and return the processed
logits. Default: ``None``. logits. Default: ``None``.
kv_bits (int, optional): Number of bits to use for KV cache quantization. kv_bits (int, optional): Number of bits to use for KV cache quantization.
None implies no cache quantization. Default: ``None``. None implies no cache quantization. Default: ``None``.
kv_group_size (int): Group size for KV cache quantization. Default: ``64``. kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
quantized_kv_start (int): Step to begin using a quantized KV cache. quantized_kv_start (int): Step to begin using a quantized KV cache.
when ``kv_bits`` is non-None. Default: ``0``. when ``kv_bits`` is non-None. Default: ``0``.
Yields: Yields:
Tuple[mx.array, mx.array]: One token and a vector of log probabilities. Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
@ -246,10 +236,22 @@ def generate_step(
elif len(prompt_cache) != len(model.layers): elif len(prompt_cache) != len(model.layers):
raise ValueError("Wrong number of layers in the prompt cache.") raise ValueError("Wrong number of layers in the prompt cache.")
sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep) if temp is not None or top_p is not None or min_tokens_to_keep is not None:
logits_processors = logits_processors or [] print(
logits_processors.extend( "[Warning] Specifying sampling arguments to ``generate_step`` is "
make_logits_processors(logit_bias, repetition_penalty, repetition_context_size) "deprecated. Pass in a ``sampler`` instead."
)
if repetition_penalty is not None:
print(
"[Warning] Specifying ``repetition_penalty`` is deprecated. "
"Pass in ``logits_processors`` instead."
)
sampler = sampler or make_sampler(
temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
)
logits_processors = logits_processors or make_logits_processors(
None, repetition_penalty, repetition_context_size or 20
) )
def _step(y): def _step(y):
@ -385,7 +387,10 @@ def generate(
See :func:`stream_generate` for more details. See :func:`stream_generate` for more details.
""" """
if formatter is not None: if formatter is not None:
print("Text formatting is deprecated and will be removed in the next version.") print(
"[Warning] Text formatting is deprecated and no longer used. "
"The argument will be removed in a future version."
)
if verbose: if verbose:
print("=" * 10) print("=" * 10)
print("Prompt:", prompt) print("Prompt:", prompt)

View File

@ -34,10 +34,11 @@ class TestTokenizers(unittest.TestCase):
detokenizer = tokenizer.detokenizer detokenizer = tokenizer.detokenizer
detokenizer.reset() detokenizer.reset()
text = "" text = ""
for t in tokens: for e, t in enumerate(tokens):
detokenizer.add_token(t) detokenizer.add_token(t)
seg = detokenizer.last_segment seg = detokenizer.last_segment
text += seg text += seg
self.assertEqual(detokenizer.tokens, tokens[: e + 1])
detokenizer.finalize() detokenizer.finalize()
text += detokenizer.last_segment text += detokenizer.last_segment
self.assertEqual(text, expected_text) self.assertEqual(text, expected_text)