mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
some cleanup, warnings, tests
This commit is contained in:
parent
9986787303
commit
f82e49aad9
@ -100,8 +100,9 @@ To see a description of all the arguments you can do:
|
||||
|
||||
#### Streaming
|
||||
|
||||
For streaming generation, use the `stream_generate` function. This returns a
|
||||
generator object which streams the output text, token, and log probabilities.
|
||||
For streaming generation, use the `stream_generate` function. This yields
|
||||
a generation response object.
|
||||
|
||||
For example,
|
||||
|
||||
```python
|
||||
|
@ -6,6 +6,7 @@ import json
|
||||
import mlx.core as mx
|
||||
|
||||
from .models.cache import make_prompt_cache
|
||||
from .sample_utils import make_sampler
|
||||
from .utils import load, stream_generate
|
||||
|
||||
DEFAULT_TEMP = 0.0
|
||||
@ -79,8 +80,7 @@ def main():
|
||||
tokenizer,
|
||||
prompt,
|
||||
args.max_tokens,
|
||||
temp=args.temp,
|
||||
top_p=args.top_p,
|
||||
sampler=make_sampler(args.temp, args.top_p),
|
||||
prompt_cache=prompt_cache,
|
||||
):
|
||||
print(response.text, flush=True, end="")
|
||||
|
@ -42,7 +42,6 @@ response = generate(
|
||||
tokenizer,
|
||||
prompt=prompt,
|
||||
verbose=True,
|
||||
temp=0.0,
|
||||
prompt_cache=prompt_cache,
|
||||
)
|
||||
|
||||
|
@ -23,14 +23,6 @@ max_tokens = 1_000
|
||||
# Specify if tokens and timing information will be printed
|
||||
verbose = True
|
||||
|
||||
# Some optional arguments for causal language model generation
|
||||
generation_args = {
|
||||
"temp": 0.7,
|
||||
"repetition_penalty": 1.2,
|
||||
"repetition_context_size": 20,
|
||||
"top_p": 0.95,
|
||||
}
|
||||
|
||||
# Generate a response with the specified settings
|
||||
response = generate(
|
||||
model=model,
|
||||
@ -38,5 +30,4 @@ response = generate(
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
verbose=verbose,
|
||||
**generation_args,
|
||||
)
|
||||
|
@ -7,6 +7,7 @@ import sys
|
||||
import mlx.core as mx
|
||||
|
||||
from .models.cache import QuantizedKVCache, load_prompt_cache
|
||||
from .sample_utils import make_sampler
|
||||
from .utils import generate, load
|
||||
|
||||
DEFAULT_PROMPT = "hello"
|
||||
@ -218,16 +219,14 @@ def main():
|
||||
else:
|
||||
prompt = args.prompt
|
||||
|
||||
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
|
||||
response = generate(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt,
|
||||
max_tokens=args.max_tokens,
|
||||
verbose=args.verbose,
|
||||
temp=args.temp,
|
||||
top_p=args.top_p,
|
||||
min_p=args.min_p,
|
||||
min_tokens_to_keep=args.min_tokens_to_keep,
|
||||
sampler=sampler,
|
||||
max_kv_size=args.max_kv_size,
|
||||
prompt_cache=prompt_cache if using_cache else None,
|
||||
kv_bits=args.kv_bits,
|
||||
|
@ -27,6 +27,7 @@ from huggingface_hub import scan_cache_dir
|
||||
|
||||
from ._version import __version__
|
||||
from .models.cache import make_prompt_cache
|
||||
from .sample_utils import make_logits_processors, make_sampler
|
||||
from .utils import load, stream_generate
|
||||
|
||||
|
||||
@ -464,15 +465,17 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
|
||||
text = ""
|
||||
tic = time.perf_counter()
|
||||
sampler = make_sampler(self.temperature)
|
||||
logits_processors = make_logits_processors(
|
||||
self.logit_bias, self.repetition_penalty, self.repetition_context_size
|
||||
)
|
||||
for gen_response in stream_generate(
|
||||
model=self.model,
|
||||
tokenizer=self.tokenizer,
|
||||
prompt=prompt,
|
||||
max_tokens=self.max_tokens,
|
||||
temp=self.temperature,
|
||||
repetition_penalty=self.repetition_penalty,
|
||||
repetition_context_size=self.repetition_context_size,
|
||||
logit_bias=self.logit_bias,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=self.prompt_cache.cache,
|
||||
):
|
||||
segment = gen_response.text
|
||||
|
@ -73,16 +73,16 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
|
||||
|
||||
def reset(self):
|
||||
self.offset = 0
|
||||
self._tokens = []
|
||||
self.tokens = []
|
||||
self._text = ""
|
||||
self._current_tokens = []
|
||||
self._current_text = ""
|
||||
|
||||
def add_token(self, token):
|
||||
self._current_tokens.append(token)
|
||||
self.tokens.append(token)
|
||||
|
||||
def finalize(self):
|
||||
self._tokens.extend(self._current_tokens)
|
||||
self._text += self._tokenizer.decode(self._current_tokens)
|
||||
self._current_tokens = []
|
||||
self._current_text = ""
|
||||
@ -97,16 +97,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
|
||||
):
|
||||
self._current_text = self._current_text[:-1]
|
||||
if self._current_text and self._current_text[-1] == "\n":
|
||||
self._tokens.extend(self._current_tokens)
|
||||
self._text += self._current_text
|
||||
self._current_tokens.clear()
|
||||
self._current_text = ""
|
||||
return self._text + self._current_text
|
||||
|
||||
@property
|
||||
def tokens(self):
|
||||
return self._tokens
|
||||
|
||||
|
||||
class SPMStreamingDetokenizer(StreamingDetokenizer):
|
||||
"""A streaming detokenizer for SPM models.
|
||||
@ -143,6 +138,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
|
||||
self.text += text
|
||||
|
||||
def add_token(self, token):
|
||||
self.tokens.append(token)
|
||||
v = self.tokenmap[token]
|
||||
if v.startswith(self._sep):
|
||||
self._flush()
|
||||
@ -200,6 +196,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
|
||||
return current_text
|
||||
|
||||
def add_token(self, token):
|
||||
self.tokens.append(token)
|
||||
v = self.tokenmap[token]
|
||||
is_added = token in self._added_ids
|
||||
if is_added or self._byte_decoder[v[0]] == 32:
|
||||
|
@ -182,20 +182,21 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
|
||||
def generate_step(
|
||||
prompt: mx.array,
|
||||
model: nn.Module,
|
||||
temp: float = 0.0,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
repetition_context_size: Optional[int] = 20,
|
||||
top_p: float = 1.0,
|
||||
min_p: float = 0.0,
|
||||
min_tokens_to_keep: int = 1,
|
||||
prefill_step_size: int = 512,
|
||||
*,
|
||||
sampler: Optional[Callable[mx.array, mx.array]] = None,
|
||||
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
||||
max_kv_size: Optional[int] = None,
|
||||
prompt_cache: Optional[Any] = None,
|
||||
logit_bias: Optional[Dict[int, float]] = None,
|
||||
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
||||
prefill_step_size: int = 512,
|
||||
kv_bits: Optional[int] = None,
|
||||
kv_group_size: int = 64,
|
||||
quantized_kv_start: int = 0,
|
||||
temp: Optional[float] = None,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
repetition_context_size: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
min_p: Optional[float] = None,
|
||||
min_tokens_to_keep: Optional[int] = None,
|
||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||
"""
|
||||
A generator producing token ids based on the given prompt from the model.
|
||||
@ -203,32 +204,21 @@ def generate_step(
|
||||
Args:
|
||||
prompt (mx.array): The input prompt.
|
||||
model (nn.Module): The model to use for generation.
|
||||
temp (float): The temperature for sampling, if 0 the argmax is used.
|
||||
Default: ``0``.
|
||||
repetition_penalty (float, optional): The penalty factor for repeating
|
||||
tokens.
|
||||
repetition_context_size (int, optional): The number of tokens to
|
||||
consider for repetition penalty. Default: ``20``.
|
||||
top_p (float, optional): Nulceus sampling, higher means model considers
|
||||
more less likely words.
|
||||
min_p (float, optional): The minimum value (scaled by the top token's
|
||||
probability) that a token probability must have to be considered.
|
||||
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
|
||||
be filtered by min_p sampling.
|
||||
prefill_step_size (int): Step size for processing the prompt.
|
||||
max_kv_size (int, optional): Maximum size of the key-value cache. Old
|
||||
entries (except the first 4 tokens) will be overwritten.
|
||||
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
|
||||
provided, the cache will be updated in place.
|
||||
logit_bias (dictionary, optional): Additive logit bias.
|
||||
sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
|
||||
token from a vector of log probabilities. Default: ``None``.
|
||||
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
||||
A list of functions that take tokens and logits and return the processed
|
||||
logits. Default: ``None``.
|
||||
A list of functions that take tokens and logits and return the processed
|
||||
logits. Default: ``None``.
|
||||
kv_bits (int, optional): Number of bits to use for KV cache quantization.
|
||||
None implies no cache quantization. Default: ``None``.
|
||||
None implies no cache quantization. Default: ``None``.
|
||||
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
|
||||
quantized_kv_start (int): Step to begin using a quantized KV cache.
|
||||
when ``kv_bits`` is non-None. Default: ``0``.
|
||||
when ``kv_bits`` is non-None. Default: ``0``.
|
||||
|
||||
Yields:
|
||||
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
|
||||
@ -246,10 +236,22 @@ def generate_step(
|
||||
elif len(prompt_cache) != len(model.layers):
|
||||
raise ValueError("Wrong number of layers in the prompt cache.")
|
||||
|
||||
sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
|
||||
logits_processors = logits_processors or []
|
||||
logits_processors.extend(
|
||||
make_logits_processors(logit_bias, repetition_penalty, repetition_context_size)
|
||||
if temp is not None or top_p is not None or min_tokens_to_keep is not None:
|
||||
print(
|
||||
"[Warning] Specifying sampling arguments to ``generate_step`` is "
|
||||
"deprecated. Pass in a ``sampler`` instead."
|
||||
)
|
||||
if repetition_penalty is not None:
|
||||
print(
|
||||
"[Warning] Specifying ``repetition_penalty`` is deprecated. "
|
||||
"Pass in ``logits_processors`` instead."
|
||||
)
|
||||
|
||||
sampler = sampler or make_sampler(
|
||||
temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
|
||||
)
|
||||
logits_processors = logits_processors or make_logits_processors(
|
||||
None, repetition_penalty, repetition_context_size or 20
|
||||
)
|
||||
|
||||
def _step(y):
|
||||
@ -385,7 +387,10 @@ def generate(
|
||||
See :func:`stream_generate` for more details.
|
||||
"""
|
||||
if formatter is not None:
|
||||
print("Text formatting is deprecated and will be removed in the next version.")
|
||||
print(
|
||||
"[Warning] Text formatting is deprecated and no longer used. "
|
||||
"The argument will be removed in a future version."
|
||||
)
|
||||
if verbose:
|
||||
print("=" * 10)
|
||||
print("Prompt:", prompt)
|
||||
|
@ -34,10 +34,11 @@ class TestTokenizers(unittest.TestCase):
|
||||
detokenizer = tokenizer.detokenizer
|
||||
detokenizer.reset()
|
||||
text = ""
|
||||
for t in tokens:
|
||||
for e, t in enumerate(tokens):
|
||||
detokenizer.add_token(t)
|
||||
seg = detokenizer.last_segment
|
||||
text += seg
|
||||
self.assertEqual(detokenizer.tokens, tokens[: e + 1])
|
||||
detokenizer.finalize()
|
||||
text += detokenizer.last_segment
|
||||
self.assertEqual(text, expected_text)
|
||||
|
Loading…
Reference in New Issue
Block a user