Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 09:21:18 +08:00
Generation refactor: part 2 (#1099)
* unify with stream_generate
* fixes
* nit
* some cleanup, warnings, tests
* fix test + faster min p + test
* version
This commit is contained in:
parent 004eb4cc9d
commit 0f135396ae
@@ -61,7 +61,7 @@ prompt = tokenizer.apply_chat_template(
     messages, tokenize=False, add_generation_prompt=True
 )

-response = generate(model, tokenizer, prompt=prompt, verbose=True)
+text = generate(model, tokenizer, prompt=prompt, verbose=True)
 ```

 To see a description of all the arguments you can do:
@@ -100,8 +100,9 @@ To see a description of all the arguments you can do:

 #### Streaming

-For streaming generation, use the `stream_generate` function. This returns a
-generator object which streams the output text, token, and log probabilities.
+For streaming generation, use the `stream_generate` function. This yields
+a generation response object.
+
 For example,

 ```python
@@ -117,8 +118,8 @@ prompt = tokenizer.apply_chat_template(
     messages, tokenize=False, add_generation_prompt=True
 )

-for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512):
-    print(t, end="", flush=True)
+for response in stream_generate(model, tokenizer, prompt, max_tokens=512):
+    print(response.text, end="", flush=True)
 print()
 ```

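For reference, a minimal sketch (not part of the commit) of how the streaming API reads after this change, assuming `load` and `stream_generate` are importable from `mlx_lm`; the model repo name is only a placeholder:

```python
from mlx_lm import load, stream_generate

# Placeholder model repo; any MLX-converted chat model works.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Write a haiku about the ocean."}],
    tokenize=False,
    add_generation_prompt=True,
)

for response in stream_generate(model, tokenizer, prompt, max_tokens=256):
    print(response.text, end="", flush=True)
print()

# The last yielded response carries the run statistics added in this refactor.
print(
    f"{response.generation_tokens} tokens, "
    f"{response.generation_tps:.1f} tokens-per-sec, "
    f"peak memory {response.peak_memory:.2f} GB"
)
```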
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.

-__version__ = "0.19.3"
+__version__ = "0.20.0"
@@ -5,7 +5,8 @@ import json

 import mlx.core as mx

-from .models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
+from .models.cache import make_prompt_cache
+from .sample_utils import make_sampler
 from .utils import load, stream_generate

 DEFAULT_TEMP = 0.0
@@ -74,16 +75,15 @@ def main():
         prompt = tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True
         )
-        for response, *_ in stream_generate(
+        for response in stream_generate(
             model,
             tokenizer,
             prompt,
             args.max_tokens,
-            temp=args.temp,
-            top_p=args.top_p,
+            sampler=make_sampler(args.temp, args.top_p),
             prompt_cache=prompt_cache,
         ):
-            print(response, flush=True, end="")
+            print(response.text, flush=True, end="")
         print()


@@ -42,7 +42,6 @@ response = generate(
     tokenizer,
     prompt=prompt,
     verbose=True,
-    temp=0.0,
     prompt_cache=prompt_cache,
 )

@@ -23,14 +23,6 @@ max_tokens = 1_000
 # Specify if tokens and timing information will be printed
 verbose = True

-# Some optional arguments for causal language model generation
-generation_args = {
-    "temp": 0.7,
-    "repetition_penalty": 1.2,
-    "repetition_context_size": 20,
-    "top_p": 0.95,
-}
-
 # Generate a response with the specified settings
 response = generate(
     model=model,
@@ -38,5 +30,4 @@ response = generate(
     prompt=prompt,
     max_tokens=max_tokens,
     verbose=verbose,
-    **generation_args,
 )
@@ -7,6 +7,7 @@ import sys
 import mlx.core as mx

 from .models.cache import QuantizedKVCache, load_prompt_cache
+from .sample_utils import make_sampler
 from .utils import generate, load

 DEFAULT_PROMPT = "hello"
@@ -97,11 +98,6 @@ def setup_arg_parser():
         default=True,
         help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'",
     )
-    parser.add_argument(
-        "--colorize",
-        action="store_true",
-        help="Colorize output based on T[0] probability",
-    )
     parser.add_argument(
         "--max-kv-size",
         type=int,
@@ -137,33 +133,6 @@ def setup_arg_parser():
     return parser


-def colorprint(color, s):
-    color_codes = {
-        "black": 30,
-        "red": 31,
-        "green": 32,
-        "yellow": 33,
-        "blue": 34,
-        "magenta": 35,
-        "cyan": 36,
-        "white": 39,
-    }
-    ccode = color_codes.get(color, 30)
-    print(f"\033[1m\033[{ccode}m{s}\033[0m", end="", flush=True)
-
-
-def colorprint_by_t0(s, t0):
-    if t0 > 0.95:
-        color = "white"
-    elif t0 > 0.70:
-        color = "green"
-    elif t0 > 0.30:
-        color = "yellow"
-    else:
-        color = "red"
-    colorprint(color, s)
-
-
 def main():
     parser = setup_arg_parser()
     args = parser.parse_args()
@@ -250,21 +219,14 @@ def main():
     else:
         prompt = args.prompt

-    if args.colorize and not args.verbose:
-        raise ValueError("Cannot use --colorize with --verbose=False")
-    formatter = colorprint_by_t0 if args.colorize else None
-
+    sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
     response = generate(
         model,
         tokenizer,
         prompt,
-        args.max_tokens,
+        max_tokens=args.max_tokens,
         verbose=args.verbose,
-        formatter=formatter,
-        temp=args.temp,
-        top_p=args.top_p,
-        min_p=args.min_p,
-        min_tokens_to_keep=args.min_tokens_to_keep,
+        sampler=sampler,
         max_kv_size=args.max_kv_size,
         prompt_cache=prompt_cache if using_cache else None,
         kv_bits=args.kv_bits,
@@ -1,5 +1,6 @@
 # Copyright © 2023-2024 Apple Inc.

+import math
 from functools import partial
 from typing import Callable, Dict, Optional

@@ -80,7 +81,7 @@ def make_logits_processors(

 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
 def min_p_sampling(
-    logits: mx.array,
+    logprobs: mx.array,
     min_p: float,
     min_tokens_to_keep: int = 1,
     temperature=1.0,
@@ -93,7 +94,7 @@
       aggressive given a very high-probability token.

     Args:
-        logits: The logits from the model's output.
+        logprobs: A vector of log probabilities.
         min_p (float): Minimum token probability. Typical values are in the
           0.01-0.2 range, comparably selective as setting `top_p` in the
           0.99-0.8 range.
@@ -111,28 +112,27 @@
         )
     # reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605

-    # Softmax probabilities
-    probs = mx.softmax(logits * (1 / temperature), axis=-1)
+    logprobs = logprobs * (1 / temperature)

     # Indices sorted in decreasing order
-    sorted_indices = mx.argsort(-logits).squeeze(0)
-    sorted_probs = probs[..., sorted_indices]
+    sorted_indices = mx.argsort(-logprobs).squeeze(0)
+    sorted_logprobs = logprobs[..., sorted_indices]

     # Top probability
-    top_probs = probs[..., sorted_indices[0]]
+    top_logprobs = logprobs[..., sorted_indices[0]]

     # Calculate the min_p threshold
-    scaled_min_p = min_p * top_probs
+    scaled_min_p = top_logprobs + math.log(min_p)

     # Mask tokens that have a probability less than the scaled min_p
-    tokens_to_remove = sorted_probs < scaled_min_p
+    tokens_to_remove = sorted_logprobs < scaled_min_p
     tokens_to_remove[..., :min_tokens_to_keep] = False

     # Create pool of tokens with probability less than scaled min_p
-    selected_probs = mx.where(tokens_to_remove, 0, sorted_probs)
+    selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)

     # Return sampled token
-    sorted_token = mx.random.categorical(mx.log(selected_probs))
+    sorted_token = mx.random.categorical(selected_logprobs)
     return sorted_indices[sorted_token]


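The rewritten `min_p_sampling` above filters directly in log space: p >= min_p * p_max is equivalent to log p >= log p_max + log(min_p), and rejected tokens are set to -inf so `mx.random.categorical` gives them zero mass. A standalone sketch of that equivalence, with made-up values:

```python
import math

import mlx.core as mx

logprobs = mx.log(mx.array([0.6, 0.25, 0.1, 0.05]))
min_p = 0.2

# Probability-space rule: keep token i when p_i >= min_p * p_max.
# Log-space equivalent:   keep token i when log p_i >= log p_max + log(min_p).
threshold = logprobs.max() + math.log(min_p)
keep = logprobs >= threshold

# Rejected tokens get -inf, so categorical sampling never selects them.
filtered = mx.where(keep, logprobs, -float("inf"))
token = mx.random.categorical(filtered)
print(keep.tolist(), token.item())
```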
@@ -27,6 +27,7 @@ from huggingface_hub import scan_cache_dir

 from ._version import __version__
 from .models.cache import make_prompt_cache
+from .sample_utils import make_logits_processors, make_sampler
 from .utils import load, stream_generate


@@ -464,25 +465,24 @@ class APIHandler(BaseHTTPRequestHandler):

         text = ""
         tic = time.perf_counter()
-        for n, (segment, token, logprobs) in enumerate(
-            stream_generate(
+        sampler = make_sampler(self.temperature)
+        logits_processors = make_logits_processors(
+            self.logit_bias, self.repetition_penalty, self.repetition_context_size
+        )
+        for gen_response in stream_generate(
             model=self.model,
             tokenizer=self.tokenizer,
             prompt=prompt,
             max_tokens=self.max_tokens,
-            temp=self.temperature,
-            repetition_penalty=self.repetition_penalty,
-            repetition_context_size=self.repetition_context_size,
-            logit_bias=self.logit_bias,
+            sampler=sampler,
+            logits_processors=logits_processors,
             prompt_cache=self.prompt_cache.cache,
-            ),
         ):
-            if n == 0:
-                prompt_time = time.perf_counter() - tic
-                tic = time.perf_counter()
-
+            segment = gen_response.text
             text += segment
             logging.debug(text)
+            token = gen_response.token
+            logprobs = gen_response.logprobs
             tokens.append(token)

             if self.logprobs > 0:
@@ -523,13 +523,9 @@ class APIHandler(BaseHTTPRequestHandler):

         self.prompt_cache.tokens.extend(tokens)

-        gen_time = time.perf_counter() - tic
-        prompt_tps = len(prompt) / prompt_time
-        gen_tps = len(tokens) / gen_time
-        peak_mem = mx.metal.get_peak_memory() / 1e9
-        logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
-        logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec")
-        logging.debug(f"Peak memory: {peak_mem:.3f} GB")
+        logging.debug(f"Prompt: {gen_response.prompt_tps:.3f} tokens-per-sec")
+        logging.debug(f"Generation: {gen_response.generation_tps:.3f} tokens-per-sec")
+        logging.debug(f"Peak memory: {gen_response.peak_memory:.3f} GB")

         if self.stream:
             response = self.generate_response(segment, finish_reason)
@@ -73,16 +73,16 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):

     def reset(self):
         self.offset = 0
-        self._tokens = []
+        self.tokens = []
         self._text = ""
         self._current_tokens = []
         self._current_text = ""

     def add_token(self, token):
         self._current_tokens.append(token)
+        self.tokens.append(token)

     def finalize(self):
-        self._tokens.extend(self._current_tokens)
         self._text += self._tokenizer.decode(self._current_tokens)
         self._current_tokens = []
         self._current_text = ""
@@ -97,16 +97,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
         ):
             self._current_text = self._current_text[:-1]
         if self._current_text and self._current_text[-1] == "\n":
-            self._tokens.extend(self._current_tokens)
             self._text += self._current_text
             self._current_tokens.clear()
             self._current_text = ""
         return self._text + self._current_text

-    @property
-    def tokens(self):
-        return self._tokens
-

 class SPMStreamingDetokenizer(StreamingDetokenizer):
     """A streaming detokenizer for SPM models.
@@ -143,6 +138,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
         self.text += text

     def add_token(self, token):
+        self.tokens.append(token)
         v = self.tokenmap[token]
         if v.startswith(self._sep):
             self._flush()
@@ -200,6 +196,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
         return current_text

     def add_token(self, token):
+        self.tokens.append(token)
         v = self.tokenmap[token]
         is_added = token in self._added_ids
         if is_added or self._byte_decoder[v[0]] == 32:
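An illustrative sketch (not from the diff) of the streaming-detokenizer loop these `add_token` changes support; after this commit every detokenizer also accumulates the raw ids on `detokenizer.tokens`. The model repo name is a placeholder:

```python
from mlx_lm import load

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # placeholder repo
detokenizer = tokenizer.detokenizer

ids = tokenizer.encode("streaming detokenization test")

detokenizer.reset()
for t in ids:
    detokenizer.add_token(t)
    # Emit whatever text became unambiguous with this token.
    print(detokenizer.last_segment, end="", flush=True)
detokenizer.finalize()
print(detokenizer.last_segment)

# The raw ids are now tracked on the detokenizer itself.
assert detokenizer.tokens == ids
```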
@@ -8,6 +8,7 @@ import json
 import logging
 import shutil
 import time
+from dataclasses import dataclass
 from pathlib import Path
 from textwrap import dedent
 from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
@@ -44,6 +45,32 @@ class ModelNotFoundError(Exception):
         super().__init__(self.message)


+@dataclass
+class GenerationResponse:
+    """
+    The output of :func:`stream_generate`.
+
+    Args:
+        text (str): The next segment of decoded text. This can be an empty string.
+        token (int): The next token.
+        logprobs (mx.array): A vector of log probabilities.
+        prompt_tokens (int): The number of tokens in the prompt.
+        prompt_tps (float): The prompt processing tokens-per-second.
+        generation_tokens (int): The number of generated tokens.
+        generation_tps (float): The tokens-per-second for generation.
+        peak_memory (float): The peak memory used so far in GB.
+    """
+
+    text: str
+    token: int
+    logprobs: mx.array
+    prompt_tokens: int
+    prompt_tps: float
+    generation_tokens: int
+    generation_tps: float
+    peak_memory: float
+
+
 @contextlib.contextmanager
 def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
     """
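A small usage sketch (an assumption, not part of the diff) showing how the per-token probability that the old `formatter` path computed can be read back from these fields; the model repo name is a placeholder:

```python
import mlx.core as mx

from mlx_lm import load, stream_generate

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # placeholder repo

for r in stream_generate(model, tokenizer, "Hello", max_tokens=20):
    # r.logprobs is the full log-probability vector; r.token indexes into it.
    prob = mx.exp(r.logprobs[r.token]).item()
    print(f"{r.text!r} p={prob:.3f}")
```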
@@ -155,20 +182,21 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
 def generate_step(
     prompt: mx.array,
     model: nn.Module,
-    temp: float = 0.0,
-    repetition_penalty: Optional[float] = None,
-    repetition_context_size: Optional[int] = 20,
-    top_p: float = 1.0,
-    min_p: float = 0.0,
-    min_tokens_to_keep: int = 1,
-    prefill_step_size: int = 512,
+    *,
+    sampler: Optional[Callable[mx.array, mx.array]] = None,
+    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
     max_kv_size: Optional[int] = None,
     prompt_cache: Optional[Any] = None,
-    logit_bias: Optional[Dict[int, float]] = None,
-    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
+    prefill_step_size: int = 512,
     kv_bits: Optional[int] = None,
     kv_group_size: int = 64,
     quantized_kv_start: int = 0,
+    temp: Optional[float] = None,
+    repetition_penalty: Optional[float] = None,
+    repetition_context_size: Optional[int] = None,
+    top_p: Optional[float] = None,
+    min_p: Optional[float] = None,
+    min_tokens_to_keep: Optional[int] = None,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.
@@ -176,24 +204,13 @@ def generate_step(
     Args:
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
-        temp (float): The temperature for sampling, if 0 the argmax is used.
-          Default: ``0``.
-        repetition_penalty (float, optional): The penalty factor for repeating
-          tokens.
-        repetition_context_size (int, optional): The number of tokens to
-          consider for repetition penalty. Default: ``20``.
-        top_p (float, optional): Nulceus sampling, higher means model considers
-          more less likely words.
-        min_p (float, optional): The minimum value (scaled by the top token's
-          probability) that a token probability must have to be considered.
-        min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
-          be filtered by min_p sampling.
         prefill_step_size (int): Step size for processing the prompt.
         max_kv_size (int, optional): Maximum size of the key-value cache. Old
           entries (except the first 4 tokens) will be overwritten.
         prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
           provided, the cache will be updated in place.
-        logit_bias (dictionary, optional): Additive logit bias.
+        sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
+          token from a vector of log probabilities. Default: ``None``.
         logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
           A list of functions that take tokens and logits and return the processed
           logits. Default: ``None``.
@@ -219,10 +236,22 @@ def generate_step(
     elif len(prompt_cache) != len(model.layers):
         raise ValueError("Wrong number of layers in the prompt cache.")

-    sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
-    logits_processors = logits_processors or []
-    logits_processors.extend(
-        make_logits_processors(logit_bias, repetition_penalty, repetition_context_size)
+    if temp is not None or top_p is not None or min_tokens_to_keep is not None:
+        print(
+            "[Warning] Specifying sampling arguments to ``generate_step`` is "
+            "deprecated. Pass in a ``sampler`` instead."
+        )
+    if repetition_penalty is not None:
+        print(
+            "[Warning] Specifying ``repetition_penalty`` is deprecated. "
+            "Pass in ``logits_processors`` instead."
+        )
+
+    sampler = sampler or make_sampler(
+        temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
     )
+    logits_processors = logits_processors or make_logits_processors(
+        None, repetition_penalty, repetition_context_size or 20
+    )

     def _step(y):
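A hedged sketch of the calling convention these warnings steer users toward: build the sampling behaviour once with `make_sampler` / `make_logits_processors` and pass it through, instead of the deprecated `temp` / `top_p` / `repetition_penalty` keyword arguments. The model repo and prompt are placeholders:

```python
from mlx_lm import generate, load
from mlx_lm.sample_utils import make_logits_processors, make_sampler

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # placeholder repo

sampler = make_sampler(0.7, 0.9)  # (temp, top_p, ...)
logits_processors = make_logits_processors(None, 1.1, 20)  # (logit_bias, repetition_penalty, context size)

text = generate(
    model,
    tokenizer,
    prompt="Explain the change in one sentence.",
    max_tokens=128,
    sampler=sampler,
    logits_processors=logits_processors,
    verbose=True,
)
```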
@@ -290,17 +319,20 @@ def stream_generate(
     if not isinstance(tokenizer, TokenizerWrapper):
         tokenizer = TokenizerWrapper(tokenizer)

-    prompt_tokens = mx.array(
-        prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
-    )
+    prompt = mx.array(prompt if isinstance(prompt, list) else tokenizer.encode(prompt))
     detokenizer = tokenizer.detokenizer

     with wired_limit(model, [generation_stream]):
         detokenizer.reset()
-        for n, (token, logits) in zip(
+        tic = time.perf_counter()
+        for n, (token, logprobs) in zip(
             range(max_tokens),
-            generate_step(prompt_tokens, model, **kwargs),
+            generate_step(prompt, model, **kwargs),
         ):
+            if n == 0:
+                prompt_time = time.perf_counter() - tic
+                prompt_tps = prompt.size / prompt_time
+                tic = time.perf_counter()
             if token == tokenizer.eos_token_id:
                 break

@@ -309,17 +341,34 @@ def stream_generate(
             if n == (max_tokens - 1):
                 break

-            yield detokenizer.last_segment, token, logits
+            yield GenerationResponse(
+                text=detokenizer.last_segment,
+                token=token,
+                logprobs=logprobs,
+                prompt_tokens=prompt.size,
+                prompt_tps=prompt_tps,
+                generation_tokens=n + 1,
+                generation_tps=(n + 1) / (time.perf_counter() - tic),
+                peak_memory=mx.metal.get_peak_memory() / 1e9,
+            )

         detokenizer.finalize()
-        yield detokenizer.last_segment, token, logits
+        yield GenerationResponse(
+            text=detokenizer.last_segment,
+            token=token,
+            logprobs=logprobs,
+            prompt_tokens=prompt.size,
+            prompt_tps=prompt_tps,
+            generation_tokens=n + 1,
+            generation_tps=(n + 1) / (time.perf_counter() - tic),
+            peak_memory=mx.metal.get_peak_memory() / 1e9,
+        )


 def generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
     prompt: str,
-    max_tokens: int = 100,
     verbose: bool = False,
     formatter: Optional[Callable] = None,
     **kwargs,
@@ -334,64 +383,40 @@ def generate(
-        max_tokens (int): The maximum number of tokens. Default: ``100``.
         verbose (bool): If ``True``, print tokens and timing information.
           Default: ``False``.
-        formatter (Optional[Callable]): A function which takes a token and a
-          probability and displays it.
-        kwargs: The remaining options get passed to :func:`generate_step`.
-          See :func:`generate_step` for more details.
+        kwargs: The remaining options get passed to :func:`stream_generate`.
+          See :func:`stream_generate` for more details.
     """
-    if not isinstance(tokenizer, TokenizerWrapper):
-        tokenizer = TokenizerWrapper(tokenizer)
-
+    if formatter is not None:
+        print(
+            "[Warning] Text formatting is deprecated and no longer used. "
+            "The argument will be removed in a future version."
+        )
     if verbose:
         print("=" * 10)
         print("Prompt:", prompt)

-    prompt_tokens = mx.array(tokenizer.encode(prompt))
-    detokenizer = tokenizer.detokenizer
-
-    with wired_limit(model, [generation_stream]):
-        tic = time.perf_counter()
-        detokenizer.reset()
-        for n, (token, logprobs) in zip(
-            range(max_tokens),
-            generate_step(prompt_tokens, model, **kwargs),
-        ):
-            if n == 0:
-                prompt_time = time.perf_counter() - tic
-                tic = time.perf_counter()
-            if token == tokenizer.eos_token_id:
-                break
-            detokenizer.add_token(token)
+    text = ""
+    for response in stream_generate(model, tokenizer, prompt, **kwargs):
+        if verbose:
+            print(response.text, end="", flush=True)
+        text += response.text

     if verbose:
-                if formatter:
-                    # We have to finalize so that the prob corresponds to the last segment
-                    detokenizer.finalize()
-                    prob = mx.exp(logprobs[token]).item()
-                    formatter(detokenizer.last_segment, prob)
-                else:
-                    print(detokenizer.last_segment, end="", flush=True)
-
-        token_count = n + 1
-        detokenizer.finalize()
-
-        if verbose:
-            gen_time = time.perf_counter() - tic
-            print(detokenizer.last_segment, flush=True)
+        print()
         print("=" * 10)
-            if token_count == 0:
-                print("No tokens generated for this prompt")
+        if len(text) == 0:
+            print("No text generated for this prompt")
             return
-            prompt_tps = prompt_tokens.size / prompt_time
-            gen_tps = (token_count - 1) / gen_time
         print(
-                f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
+            f"Prompt: {response.prompt_tokens} tokens, "
+            f"{response.prompt_tps:.3f} tokens-per-sec"
         )
-            print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
-            peak_mem = mx.metal.get_peak_memory() / 1e9
-            print(f"Peak memory: {peak_mem:.3f} GB")

-        return detokenizer.text
+        print(
+            f"Generation: {response.generation_tokens} tokens, "
+            f"{response.generation_tps:.3f} tokens-per-sec"
+        )
+        print(f"Peak memory: {response.peak_memory:.3f} GB")
+    return text


 def load_config(model_path: Path) -> dict:
@@ -2,6 +2,7 @@

 import unittest

+from mlx_lm.sample_utils import make_logits_processors
 from mlx_lm.utils import generate, load


@@ -25,8 +26,8 @@ class TestGenerate(unittest.TestCase):
             self.tokenizer,
             "hello",
             max_tokens=5,
+            logits_processors=make_logits_processors(logit_bias),
             verbose=False,
-            logit_bias=logit_bias,
         )
         self.assertEqual(text, "!!!!!")

@@ -1,10 +1,10 @@
 import unittest

 import mlx.core as mx
-from mlx_lm.sample_utils import top_p_sampling
+from mlx_lm.sample_utils import min_p_sampling, top_p_sampling


-class TestSamplingUtils(unittest.TestCase):
+class TestSampleUtils(unittest.TestCase):
     def test_top_p_sampling(self):
         probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
         logits = mx.log(probs)
@@ -28,6 +28,20 @@ class TestSamplingUtils(unittest.TestCase):
         token = top_p_sampling(logits, 0.95, temperature).item()
         self.assertTrue(token in (1, 2, 3))

+    def test_min_p_sampling(self):
+        probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
+        logits = mx.log(probs)
+        temperature = 1.0
+        token = min_p_sampling(logits, 0.8)
+        self.assertEqual(token, 0)
+
+        probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
+        logits = mx.log(probs)
+        temperature = 1.0
+        for _ in range(5):
+            token = min_p_sampling(logits, 0.05)
+            self.assertTrue(token in (0, 3))
+

 if __name__ == "__main__":
     unittest.main()
@@ -34,10 +34,11 @@ class TestTokenizers(unittest.TestCase):
         detokenizer = tokenizer.detokenizer
         detokenizer.reset()
         text = ""
-        for t in tokens:
+        for e, t in enumerate(tokens):
             detokenizer.add_token(t)
             seg = detokenizer.last_segment
             text += seg
+            self.assertEqual(detokenizer.tokens, tokens[: e + 1])
         detokenizer.finalize()
         text += detokenizer.last_segment
         self.assertEqual(text, expected_text)