Generation refactor: part 2 (#1099)

* unify with stream_generate

* fixes

* nit

* some cleanup, warnings, tests

* fix test + faster min p + test

* version
Awni Hannun 2024-11-23 11:47:06 -08:00 committed by GitHub
parent 004eb4cc9d
commit 0f135396ae
13 changed files with 184 additions and 197 deletions

View File

@@ -61,7 +61,7 @@ prompt = tokenizer.apply_chat_template(
     messages, tokenize=False, add_generation_prompt=True
 )
-response = generate(model, tokenizer, prompt=prompt, verbose=True)
+text = generate(model, tokenizer, prompt=prompt, verbose=True)
 ```

 To see a description of all the arguments you can do:
@@ -100,8 +100,9 @@ To see a description of all the arguments you can do:
 #### Streaming

-For streaming generation, use the `stream_generate` function. This returns a
-generator object which streams the output text, token, and log probabilities.
+For streaming generation, use the `stream_generate` function. This yields
+a generation response object.
 For example,

 ```python
@@ -117,8 +118,8 @@ prompt = tokenizer.apply_chat_template(
     messages, tokenize=False, add_generation_prompt=True
 )

-for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512):
-    print(t, end="", flush=True)
+for response in stream_generate(model, tokenizer, prompt, max_tokens=512):
+    print(response.text, end="", flush=True)

 print()
 ```
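A minimal sketch of a full streaming loop under the new API, assuming the `GenerationResponse` fields added later in this diff; the model name is illustrative only:

```python
from mlx_lm.utils import load, stream_generate

# Illustrative model; any mlx-community chat model should behave similarly.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Write a haiku about the ocean."}],
    tokenize=False,
    add_generation_prompt=True,
)

for response in stream_generate(model, tokenizer, prompt, max_tokens=256):
    # Each item is a GenerationResponse, not a (text, token, logprobs) tuple.
    print(response.text, end="", flush=True)

# The final response also carries aggregate statistics.
print(
    f"\n{response.generation_tokens} tokens, "
    f"{response.generation_tps:.1f} tokens-per-sec, "
    f"{response.peak_memory:.2f} GB peak memory"
)
```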

View File

@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.

-__version__ = "0.19.3"
+__version__ = "0.20.0"

View File

@@ -5,7 +5,8 @@ import json
 import mlx.core as mx

-from .models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
+from .models.cache import make_prompt_cache
+from .sample_utils import make_sampler
 from .utils import load, stream_generate

 DEFAULT_TEMP = 0.0
@@ -74,16 +75,15 @@ def main():
         prompt = tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True
         )
-        for response, *_ in stream_generate(
+        for response in stream_generate(
             model,
             tokenizer,
             prompt,
             args.max_tokens,
-            temp=args.temp,
-            top_p=args.top_p,
+            sampler=make_sampler(args.temp, args.top_p),
             prompt_cache=prompt_cache,
         ):
-            print(response, flush=True, end="")
+            print(response.text, flush=True, end="")
         print()

View File

@@ -42,7 +42,6 @@ response = generate(
     tokenizer,
     prompt=prompt,
     verbose=True,
-    temp=0.0,
     prompt_cache=prompt_cache,
 )

View File

@@ -23,14 +23,6 @@ max_tokens = 1_000
 # Specify if tokens and timing information will be printed
 verbose = True

-# Some optional arguments for causal language model generation
-generation_args = {
-    "temp": 0.7,
-    "repetition_penalty": 1.2,
-    "repetition_context_size": 20,
-    "top_p": 0.95,
-}
-
 # Generate a response with the specified settings
 response = generate(
     model=model,
@@ -38,5 +30,4 @@ response = generate(
     prompt=prompt,
     max_tokens=max_tokens,
     verbose=verbose,
-    **generation_args,
 )
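Since `generate` no longer accepts `temp`, `top_p`, or repetition settings directly, here is a sketch of how the removed `generation_args` might be expressed after this refactor; the model and prompt are illustrative, only the argument names come from this diff:

```python
from mlx_lm.sample_utils import make_logits_processors, make_sampler
from mlx_lm.utils import generate, load

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # illustrative
prompt = "Why is the sky blue?"  # illustrative

# Sampling behaviour now lives in a sampler callable ...
sampler = make_sampler(temp=0.7, top_p=0.95)
# ... and repetition handling in a list of logits processors.
logits_processors = make_logits_processors(
    repetition_penalty=1.2, repetition_context_size=20
)

response = generate(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    max_tokens=1_000,
    verbose=True,
    sampler=sampler,
    logits_processors=logits_processors,
)
```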

View File

@ -7,6 +7,7 @@ import sys
import mlx.core as mx import mlx.core as mx
from .models.cache import QuantizedKVCache, load_prompt_cache from .models.cache import QuantizedKVCache, load_prompt_cache
from .sample_utils import make_sampler
from .utils import generate, load from .utils import generate, load
DEFAULT_PROMPT = "hello" DEFAULT_PROMPT = "hello"
@ -97,11 +98,6 @@ def setup_arg_parser():
default=True, default=True,
help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'", help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'",
) )
parser.add_argument(
"--colorize",
action="store_true",
help="Colorize output based on T[0] probability",
)
parser.add_argument( parser.add_argument(
"--max-kv-size", "--max-kv-size",
type=int, type=int,
@ -137,33 +133,6 @@ def setup_arg_parser():
return parser return parser
def colorprint(color, s):
color_codes = {
"black": 30,
"red": 31,
"green": 32,
"yellow": 33,
"blue": 34,
"magenta": 35,
"cyan": 36,
"white": 39,
}
ccode = color_codes.get(color, 30)
print(f"\033[1m\033[{ccode}m{s}\033[0m", end="", flush=True)
def colorprint_by_t0(s, t0):
if t0 > 0.95:
color = "white"
elif t0 > 0.70:
color = "green"
elif t0 > 0.30:
color = "yellow"
else:
color = "red"
colorprint(color, s)
def main(): def main():
parser = setup_arg_parser() parser = setup_arg_parser()
args = parser.parse_args() args = parser.parse_args()
@ -250,21 +219,14 @@ def main():
else: else:
prompt = args.prompt prompt = args.prompt
if args.colorize and not args.verbose: sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
raise ValueError("Cannot use --colorize with --verbose=False")
formatter = colorprint_by_t0 if args.colorize else None
response = generate( response = generate(
model, model,
tokenizer, tokenizer,
prompt, prompt,
args.max_tokens, max_tokens=args.max_tokens,
verbose=args.verbose, verbose=args.verbose,
formatter=formatter, sampler=sampler,
temp=args.temp,
top_p=args.top_p,
min_p=args.min_p,
min_tokens_to_keep=args.min_tokens_to_keep,
max_kv_size=args.max_kv_size, max_kv_size=args.max_kv_size,
prompt_cache=prompt_cache if using_cache else None, prompt_cache=prompt_cache if using_cache else None,
kv_bits=args.kv_bits, kv_bits=args.kv_bits,

View File

@@ -1,5 +1,6 @@
 # Copyright © 2023-2024 Apple Inc.

+import math
 from functools import partial
 from typing import Callable, Dict, Optional
@@ -80,7 +81,7 @@ def make_logits_processors(
 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
 def min_p_sampling(
-    logits: mx.array,
+    logprobs: mx.array,
     min_p: float,
     min_tokens_to_keep: int = 1,
     temperature=1.0,
@@ -93,7 +94,7 @@ def min_p_sampling(
       aggressive given a very high-probability token.

     Args:
-        logits: The logits from the model's output.
+        logprobs: A vector of log probabilities.
         min_p (float): Minimum token probability. Typical values are in the
             0.01-0.2 range, comparably selective as setting `top_p` in the
             0.99-0.8 range.
@@ -111,28 +112,27 @@ def min_p_sampling(
     )
     # reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605
-    # Softmax probabilities
-    probs = mx.softmax(logits * (1 / temperature), axis=-1)
+    logprobs = logprobs * (1 / temperature)

     # Indices sorted in decreasing order
-    sorted_indices = mx.argsort(-logits).squeeze(0)
-    sorted_probs = probs[..., sorted_indices]
+    sorted_indices = mx.argsort(-logprobs).squeeze(0)
+    sorted_logprobs = logprobs[..., sorted_indices]

     # Top probability
-    top_probs = probs[..., sorted_indices[0]]
+    top_logprobs = logprobs[..., sorted_indices[0]]

     # Calculate the min_p threshold
-    scaled_min_p = min_p * top_probs
+    scaled_min_p = top_logprobs + math.log(min_p)

     # Mask tokens that have a probability less than the scaled min_p
-    tokens_to_remove = sorted_probs < scaled_min_p
+    tokens_to_remove = sorted_logprobs < scaled_min_p
     tokens_to_remove[..., :min_tokens_to_keep] = False

     # Create pool of tokens with probability less than scaled min_p
-    selected_probs = mx.where(tokens_to_remove, 0, sorted_probs)
+    selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)

     # Return sampled token
-    sorted_token = mx.random.categorical(mx.log(selected_probs))
+    sorted_token = mx.random.categorical(selected_logprobs)
     return sorted_indices[sorted_token]
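The rewrite keeps the same min-p rule but applies it in log space, so no softmax is needed (the "faster min p" item in the commit message): keeping tokens with `p >= min_p * p_top` is equivalent to keeping those with `log p >= log(min_p) + log(p_top)`. A small self-contained sanity check of that equivalence at temperature 1:

```python
import math

# Toy distribution over four tokens.
probs = [0.9, 0.05, 0.04, 0.01]
logprobs = [math.log(p) for p in probs]
min_p = 0.1

top_prob = max(probs)
top_logprob = max(logprobs)

# Probability-space rule (old implementation).
keep_prob = [p >= min_p * top_prob for p in probs]
# Log-space rule (new implementation).
keep_log = [lp >= top_logprob + math.log(min_p) for lp in logprobs]

assert keep_prob == keep_log  # the same tokens survive either way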

View File

@ -27,6 +27,7 @@ from huggingface_hub import scan_cache_dir
from ._version import __version__ from ._version import __version__
from .models.cache import make_prompt_cache from .models.cache import make_prompt_cache
from .sample_utils import make_logits_processors, make_sampler
from .utils import load, stream_generate from .utils import load, stream_generate
@ -464,25 +465,24 @@ class APIHandler(BaseHTTPRequestHandler):
text = "" text = ""
tic = time.perf_counter() tic = time.perf_counter()
for n, (segment, token, logprobs) in enumerate( sampler = make_sampler(self.temperature)
stream_generate( logits_processors = make_logits_processors(
model=self.model, self.logit_bias, self.repetition_penalty, self.repetition_context_size
tokenizer=self.tokenizer, )
prompt=prompt, for gen_response in stream_generate(
max_tokens=self.max_tokens, model=self.model,
temp=self.temperature, tokenizer=self.tokenizer,
repetition_penalty=self.repetition_penalty, prompt=prompt,
repetition_context_size=self.repetition_context_size, max_tokens=self.max_tokens,
logit_bias=self.logit_bias, sampler=sampler,
prompt_cache=self.prompt_cache.cache, logits_processors=logits_processors,
), prompt_cache=self.prompt_cache.cache,
): ):
if n == 0: segment = gen_response.text
prompt_time = time.perf_counter() - tic
tic = time.perf_counter()
text += segment text += segment
logging.debug(text) logging.debug(text)
token = gen_response.token
logprobs = gen_response.logprobs
tokens.append(token) tokens.append(token)
if self.logprobs > 0: if self.logprobs > 0:
@ -523,13 +523,9 @@ class APIHandler(BaseHTTPRequestHandler):
self.prompt_cache.tokens.extend(tokens) self.prompt_cache.tokens.extend(tokens)
gen_time = time.perf_counter() - tic logging.debug(f"Prompt: {gen_response.prompt_tps:.3f} tokens-per-sec")
prompt_tps = len(prompt) / prompt_time logging.debug(f"Generation: {gen_response.generation_tps:.3f} tokens-per-sec")
gen_tps = len(tokens) / gen_time logging.debug(f"Peak memory: {gen_response.peak_memory:.3f} GB")
peak_mem = mx.metal.get_peak_memory() / 1e9
logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec")
logging.debug(f"Peak memory: {peak_mem:.3f} GB")
if self.stream: if self.stream:
response = self.generate_response(segment, finish_reason) response = self.generate_response(segment, finish_reason)

View File

@ -73,16 +73,16 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
def reset(self): def reset(self):
self.offset = 0 self.offset = 0
self._tokens = [] self.tokens = []
self._text = "" self._text = ""
self._current_tokens = [] self._current_tokens = []
self._current_text = "" self._current_text = ""
def add_token(self, token): def add_token(self, token):
self._current_tokens.append(token) self._current_tokens.append(token)
self.tokens.append(token)
def finalize(self): def finalize(self):
self._tokens.extend(self._current_tokens)
self._text += self._tokenizer.decode(self._current_tokens) self._text += self._tokenizer.decode(self._current_tokens)
self._current_tokens = [] self._current_tokens = []
self._current_text = "" self._current_text = ""
@ -97,16 +97,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
): ):
self._current_text = self._current_text[:-1] self._current_text = self._current_text[:-1]
if self._current_text and self._current_text[-1] == "\n": if self._current_text and self._current_text[-1] == "\n":
self._tokens.extend(self._current_tokens)
self._text += self._current_text self._text += self._current_text
self._current_tokens.clear() self._current_tokens.clear()
self._current_text = "" self._current_text = ""
return self._text + self._current_text return self._text + self._current_text
@property
def tokens(self):
return self._tokens
class SPMStreamingDetokenizer(StreamingDetokenizer): class SPMStreamingDetokenizer(StreamingDetokenizer):
"""A streaming detokenizer for SPM models. """A streaming detokenizer for SPM models.
@ -143,6 +138,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
self.text += text self.text += text
def add_token(self, token): def add_token(self, token):
self.tokens.append(token)
v = self.tokenmap[token] v = self.tokenmap[token]
if v.startswith(self._sep): if v.startswith(self._sep):
self._flush() self._flush()
@ -200,6 +196,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
return current_text return current_text
def add_token(self, token): def add_token(self, token):
self.tokens.append(token)
v = self.tokenmap[token] v = self.tokenmap[token]
is_added = token in self._added_ids is_added = token in self._added_ids
if is_added or self._byte_decoder[v[0]] == 32: if is_added or self._byte_decoder[v[0]] == 32:
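With `tokens` now a plain list that `add_token` appends to, intermediate token ids are visible while streaming instead of only after `finalize`. A sketch of driving a streaming detokenizer by hand, assuming a loaded tokenizer; the model name is illustrative:

```python
from mlx_lm.utils import load

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # illustrative

detokenizer = tokenizer.detokenizer
detokenizer.reset()

token_ids = tokenizer.encode("Hello world")
for tok in token_ids:
    detokenizer.add_token(tok)
    # Text that is safe to emit so far.
    print(detokenizer.last_segment, end="")

detokenizer.finalize()
print(detokenizer.last_segment)

# Every added token is reflected immediately in `tokens`.
assert detokenizer.tokens == token_ids
```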

View File

@ -8,6 +8,7 @@ import json
import logging import logging
import shutil import shutil
import time import time
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from textwrap import dedent from textwrap import dedent
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
@ -44,6 +45,32 @@ class ModelNotFoundError(Exception):
super().__init__(self.message) super().__init__(self.message)
@dataclass
class GenerationResponse:
"""
The output of :func:`stream_generate`.
Args:
text (str): The next segment of decoded text. This can be an empty string.
token (int): The next token.
logprobs (mx.array): A vector of log probabilities.
prompt_tokens (int): The number of tokens in the prompt.
prompt_tps (float): The prompt processing tokens-per-second.
generation_tokens (int): The number of generated tokens.
generation_tps (float): The tokens-per-second for generation.
peak_memory (float): The peak memory used so far in GB.
"""
text: str
token: int
logprobs: mx.array
prompt_tokens: int
prompt_tps: float
generation_tokens: int
generation_tps: float
peak_memory: float
@contextlib.contextmanager @contextlib.contextmanager
def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None): def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
""" """
@ -155,20 +182,21 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
def generate_step( def generate_step(
prompt: mx.array, prompt: mx.array,
model: nn.Module, model: nn.Module,
temp: float = 0.0, *,
repetition_penalty: Optional[float] = None, sampler: Optional[Callable[mx.array, mx.array]] = None,
repetition_context_size: Optional[int] = 20, logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
top_p: float = 1.0,
min_p: float = 0.0,
min_tokens_to_keep: int = 1,
prefill_step_size: int = 512,
max_kv_size: Optional[int] = None, max_kv_size: Optional[int] = None,
prompt_cache: Optional[Any] = None, prompt_cache: Optional[Any] = None,
logit_bias: Optional[Dict[int, float]] = None, prefill_step_size: int = 512,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
kv_bits: Optional[int] = None, kv_bits: Optional[int] = None,
kv_group_size: int = 64, kv_group_size: int = 64,
quantized_kv_start: int = 0, quantized_kv_start: int = 0,
temp: Optional[float] = None,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = None,
top_p: Optional[float] = None,
min_p: Optional[float] = None,
min_tokens_to_keep: Optional[int] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]: ) -> Generator[Tuple[mx.array, mx.array], None, None]:
""" """
A generator producing token ids based on the given prompt from the model. A generator producing token ids based on the given prompt from the model.
@ -176,32 +204,21 @@ def generate_step(
Args: Args:
prompt (mx.array): The input prompt. prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation. model (nn.Module): The model to use for generation.
temp (float): The temperature for sampling, if 0 the argmax is used.
Default: ``0``.
repetition_penalty (float, optional): The penalty factor for repeating
tokens.
repetition_context_size (int, optional): The number of tokens to
consider for repetition penalty. Default: ``20``.
top_p (float, optional): Nulceus sampling, higher means model considers
more less likely words.
min_p (float, optional): The minimum value (scaled by the top token's
probability) that a token probability must have to be considered.
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered by min_p sampling.
prefill_step_size (int): Step size for processing the prompt. prefill_step_size (int): Step size for processing the prompt.
max_kv_size (int, optional): Maximum size of the key-value cache. Old max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten. entries (except the first 4 tokens) will be overwritten.
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place. provided, the cache will be updated in place.
logit_bias (dictionary, optional): Additive logit bias. sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
token from a vector of log probabilities. Default: ``None``.
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional): logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed A list of functions that take tokens and logits and return the processed
logits. Default: ``None``. logits. Default: ``None``.
kv_bits (int, optional): Number of bits to use for KV cache quantization. kv_bits (int, optional): Number of bits to use for KV cache quantization.
None implies no cache quantization. Default: ``None``. None implies no cache quantization. Default: ``None``.
kv_group_size (int): Group size for KV cache quantization. Default: ``64``. kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
quantized_kv_start (int): Step to begin using a quantized KV cache. quantized_kv_start (int): Step to begin using a quantized KV cache.
when ``kv_bits`` is non-None. Default: ``0``. when ``kv_bits`` is non-None. Default: ``0``.
Yields: Yields:
Tuple[mx.array, mx.array]: One token and a vector of log probabilities. Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
@ -219,10 +236,22 @@ def generate_step(
elif len(prompt_cache) != len(model.layers): elif len(prompt_cache) != len(model.layers):
raise ValueError("Wrong number of layers in the prompt cache.") raise ValueError("Wrong number of layers in the prompt cache.")
sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep) if temp is not None or top_p is not None or min_tokens_to_keep is not None:
logits_processors = logits_processors or [] print(
logits_processors.extend( "[Warning] Specifying sampling arguments to ``generate_step`` is "
make_logits_processors(logit_bias, repetition_penalty, repetition_context_size) "deprecated. Pass in a ``sampler`` instead."
)
if repetition_penalty is not None:
print(
"[Warning] Specifying ``repetition_penalty`` is deprecated. "
"Pass in ``logits_processors`` instead."
)
sampler = sampler or make_sampler(
temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
)
logits_processors = logits_processors or make_logits_processors(
None, repetition_penalty, repetition_context_size or 20
) )
def _step(y): def _step(y):
@ -290,17 +319,20 @@ def stream_generate(
if not isinstance(tokenizer, TokenizerWrapper): if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer) tokenizer = TokenizerWrapper(tokenizer)
prompt_tokens = mx.array( prompt = mx.array(prompt if isinstance(prompt, list) else tokenizer.encode(prompt))
prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
)
detokenizer = tokenizer.detokenizer detokenizer = tokenizer.detokenizer
with wired_limit(model, [generation_stream]): with wired_limit(model, [generation_stream]):
detokenizer.reset() detokenizer.reset()
for n, (token, logits) in zip( tic = time.perf_counter()
for n, (token, logprobs) in zip(
range(max_tokens), range(max_tokens),
generate_step(prompt_tokens, model, **kwargs), generate_step(prompt, model, **kwargs),
): ):
if n == 0:
prompt_time = time.perf_counter() - tic
prompt_tps = prompt.size / prompt_time
tic = time.perf_counter()
if token == tokenizer.eos_token_id: if token == tokenizer.eos_token_id:
break break
@ -309,17 +341,34 @@ def stream_generate(
if n == (max_tokens - 1): if n == (max_tokens - 1):
break break
yield detokenizer.last_segment, token, logits yield GenerationResponse(
text=detokenizer.last_segment,
token=token,
logprobs=logprobs,
prompt_tokens=prompt.size,
prompt_tps=prompt_tps,
generation_tokens=n + 1,
generation_tps=(n + 1) / (time.perf_counter() - tic),
peak_memory=mx.metal.get_peak_memory() / 1e9,
)
detokenizer.finalize() detokenizer.finalize()
yield detokenizer.last_segment, token, logits yield GenerationResponse(
text=detokenizer.last_segment,
token=token,
logprobs=logprobs,
prompt_tokens=prompt.size,
prompt_tps=prompt_tps,
generation_tokens=n + 1,
generation_tps=(n + 1) / (time.perf_counter() - tic),
peak_memory=mx.metal.get_peak_memory() / 1e9,
)
def generate( def generate(
model: nn.Module, model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: str, prompt: str,
max_tokens: int = 100,
verbose: bool = False, verbose: bool = False,
formatter: Optional[Callable] = None, formatter: Optional[Callable] = None,
**kwargs, **kwargs,
@ -334,64 +383,40 @@ def generate(
max_tokens (int): The maximum number of tokens. Default: ``100``. max_tokens (int): The maximum number of tokens. Default: ``100``.
verbose (bool): If ``True``, print tokens and timing information. verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``. Default: ``False``.
formatter (Optional[Callable]): A function which takes a token and a kwargs: The remaining options get passed to :func:`stream_generate`.
probability and displays it. See :func:`stream_generate` for more details.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
""" """
if not isinstance(tokenizer, TokenizerWrapper): if formatter is not None:
tokenizer = TokenizerWrapper(tokenizer) print(
"[Warning] Text formatting is deprecated and no longer used. "
"The argument will be removed in a future version."
)
if verbose: if verbose:
print("=" * 10) print("=" * 10)
print("Prompt:", prompt) print("Prompt:", prompt)
prompt_tokens = mx.array(tokenizer.encode(prompt)) text = ""
detokenizer = tokenizer.detokenizer for response in stream_generate(model, tokenizer, prompt, **kwargs):
with wired_limit(model, [generation_stream]):
tic = time.perf_counter()
detokenizer.reset()
for n, (token, logprobs) in zip(
range(max_tokens),
generate_step(prompt_tokens, model, **kwargs),
):
if n == 0:
prompt_time = time.perf_counter() - tic
tic = time.perf_counter()
if token == tokenizer.eos_token_id:
break
detokenizer.add_token(token)
if verbose:
if formatter:
# We have to finalize so that the prob corresponds to the last segment
detokenizer.finalize()
prob = mx.exp(logprobs[token]).item()
formatter(detokenizer.last_segment, prob)
else:
print(detokenizer.last_segment, end="", flush=True)
token_count = n + 1
detokenizer.finalize()
if verbose: if verbose:
gen_time = time.perf_counter() - tic print(response.text, end="", flush=True)
print(detokenizer.last_segment, flush=True) text += response.text
print("=" * 10)
if token_count == 0:
print("No tokens generated for this prompt")
return
prompt_tps = prompt_tokens.size / prompt_time
gen_tps = (token_count - 1) / gen_time
print(
f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
)
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
peak_mem = mx.metal.get_peak_memory() / 1e9
print(f"Peak memory: {peak_mem:.3f} GB")
return detokenizer.text if verbose:
print()
print("=" * 10)
if len(text) == 0:
print("No text generated for this prompt")
return
print(
f"Prompt: {response.prompt_tokens} tokens, "
f"{response.prompt_tps:.3f} tokens-per-sec"
)
print(
f"Generation: {response.generation_tokens} tokens, "
f"{response.generation_tps:.3f} tokens-per-sec"
)
print(f"Peak memory: {response.peak_memory:.3f} GB")
return text
def load_config(model_path: Path) -> dict: def load_config(model_path: Path) -> dict:
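`generate_step` now takes a keyword-only `sampler` (and optional `logits_processors`); the old sampling keywords still run but only print deprecation warnings. A minimal sketch of the low-level loop under the new signature, mirroring the `stream_generate` body above; the model name is illustrative:

```python
import mlx.core as mx

from mlx_lm.sample_utils import make_sampler
from mlx_lm.utils import generate_step, load

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # illustrative
prompt = mx.array(tokenizer.encode("Why is the sky blue?"))

# Previously: generate_step(prompt, model, temp=0.6, top_p=0.9) -> now warns.
sampler = make_sampler(temp=0.6, top_p=0.9)

detokenizer = tokenizer.detokenizer
detokenizer.reset()
for n, (token, logprobs) in zip(
    range(100), generate_step(prompt, model, sampler=sampler)
):
    if token == tokenizer.eos_token_id:
        break
    detokenizer.add_token(token)

detokenizer.finalize()
print(detokenizer.text)
```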

View File

@@ -2,6 +2,7 @@
 import unittest

+from mlx_lm.sample_utils import make_logits_processors
 from mlx_lm.utils import generate, load
@@ -25,8 +26,8 @@ class TestGenerate(unittest.TestCase):
             self.tokenizer,
             "hello",
             max_tokens=5,
+            logits_processors=make_logits_processors(logit_bias),
             verbose=False,
-            logit_bias=logit_bias,
         )
         self.assertEqual(text, "!!!!!")
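`logit_bias` is no longer a keyword of `generate`; it is the first argument of `make_logits_processors`. A sketch of the same idea outside the test suite; the model and the bias value are illustrative:

```python
from mlx_lm.sample_utils import make_logits_processors
from mlx_lm.utils import generate, load

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # illustrative

# Favour the "!" token, assuming it encodes to a single trailing id.
logit_bias = {tokenizer.encode("!")[-1]: 20.0}

text = generate(
    model,
    tokenizer,
    "hello",
    max_tokens=5,
    logits_processors=make_logits_processors(logit_bias),
    verbose=False,
)
print(text)
```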

View File

@@ -1,10 +1,10 @@
 import unittest

 import mlx.core as mx

-from mlx_lm.sample_utils import top_p_sampling
+from mlx_lm.sample_utils import min_p_sampling, top_p_sampling


-class TestSamplingUtils(unittest.TestCase):
+class TestSampleUtils(unittest.TestCase):
     def test_top_p_sampling(self):
         probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
         logits = mx.log(probs)
@@ -28,6 +28,20 @@ class TestSamplingUtils(unittest.TestCase):
         token = top_p_sampling(logits, 0.95, temperature).item()
         self.assertTrue(token in (1, 2, 3))

+    def test_min_p_sampling(self):
+        probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
+        logits = mx.log(probs)
+        temperature = 1.0
+        token = min_p_sampling(logits, 0.8)
+        self.assertEqual(token, 0)
+
+        probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
+        logits = mx.log(probs)
+        temperature = 1.0
+        for _ in range(5):
+            token = min_p_sampling(logits, 0.05)
+            self.assertTrue(token in (0, 3))
+

 if __name__ == "__main__":
     unittest.main()

View File

@@ -34,10 +34,11 @@ class TestTokenizers(unittest.TestCase):
             detokenizer = tokenizer.detokenizer
             detokenizer.reset()
             text = ""
-            for t in tokens:
+            for e, t in enumerate(tokens):
                 detokenizer.add_token(t)
                 seg = detokenizer.last_segment
                 text += seg
+                self.assertEqual(detokenizer.tokens, tokens[: e + 1])
             detokenizer.finalize()
             text += detokenizer.last_segment
             self.assertEqual(text, expected_text)