Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 09:21:18 +08:00)

Generation refactor: part 2 (#1099)

* unify with stream_generate
* fixes
* nit
* some cleanup, warnings, tests
* fix test + faster min p + test
* version

This commit is contained in:
parent 004eb4cc9d
commit 0f135396ae
@@ -61,7 +61,7 @@ prompt = tokenizer.apply_chat_template(
     messages, tokenize=False, add_generation_prompt=True
 )
 
-response = generate(model, tokenizer, prompt=prompt, verbose=True)
+text = generate(model, tokenizer, prompt=prompt, verbose=True)
 ```
 
 To see a description of all the arguments you can do:
@@ -100,8 +100,9 @@ To see a description of all the arguments you can do:
 
 #### Streaming
 
-For streaming generation, use the `stream_generate` function. This returns a
-generator object which streams the output text, token, and log probabilities.
+For streaming generation, use the `stream_generate` function. This yields
+a generation response object.
 
 For example,
 
 ```python
@@ -117,8 +118,8 @@ prompt = tokenizer.apply_chat_template(
     messages, tokenize=False, add_generation_prompt=True
 )
 
-for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512):
-    print(t, end="", flush=True)
+for response in stream_generate(model, tokenizer, prompt, max_tokens=512):
+    print(response.text, end="", flush=True)
 print()
 ```
 
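Note: each item yielded by `stream_generate` is now the `GenerationResponse` dataclass added later in this diff, so the streamed text and the timing statistics live on one object. A minimal sketch of reading those fields; the model name is the illustrative one the README already uses:

```python
from mlx_lm import load, stream_generate

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

prompt = "Write a haiku about the ocean."
for response in stream_generate(model, tokenizer, prompt, max_tokens=128):
    # Each response carries text, token, logprobs, and running statistics.
    print(response.text, end="", flush=True)

# The last response also holds the aggregate throughput numbers.
print()
print(f"{response.generation_tokens} tokens, {response.generation_tps:.1f} tokens-per-sec")
```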
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.19.3"
+__version__ = "0.20.0"
@@ -5,7 +5,8 @@ import json
 
 import mlx.core as mx
 
-from .models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
+from .models.cache import make_prompt_cache
+from .sample_utils import make_sampler
 from .utils import load, stream_generate
 
 DEFAULT_TEMP = 0.0
@@ -74,16 +75,15 @@ def main():
         prompt = tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True
         )
-        for response, *_ in stream_generate(
+        for response in stream_generate(
             model,
             tokenizer,
             prompt,
             args.max_tokens,
-            temp=args.temp,
-            top_p=args.top_p,
+            sampler=make_sampler(args.temp, args.top_p),
             prompt_cache=prompt_cache,
         ):
-            print(response, flush=True, end="")
+            print(response.text, flush=True, end="")
         print()
 
 
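Note: sampling options no longer travel as loose keyword arguments; they are packaged into a sampler via `make_sampler`, as the hunk above does. A sketch of the migration, assuming `model`, `tokenizer`, and `prompt` are set up as in this script:

```python
from mlx_lm.sample_utils import make_sampler

# Before this change: stream_generate(..., temp=0.7, top_p=0.9)
# After: build the sampler once and pass it in.
sampler = make_sampler(0.7, 0.9)  # (temperature, top_p)
for response in stream_generate(model, tokenizer, prompt, max_tokens=256, sampler=sampler):
    print(response.text, end="", flush=True)
```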
@@ -42,7 +42,6 @@ response = generate(
     tokenizer,
     prompt=prompt,
     verbose=True,
-    temp=0.0,
     prompt_cache=prompt_cache,
 )
 
@@ -23,14 +23,6 @@ max_tokens = 1_000
 # Specify if tokens and timing information will be printed
 verbose = True
 
-# Some optional arguments for causal language model generation
-generation_args = {
-    "temp": 0.7,
-    "repetition_penalty": 1.2,
-    "repetition_context_size": 20,
-    "top_p": 0.95,
-}
-
 # Generate a response with the specified settings
 response = generate(
     model=model,
@@ -38,5 +30,4 @@ response = generate(
     prompt=prompt,
     max_tokens=max_tokens,
     verbose=verbose,
-    **generation_args,
 )
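Note: the removed `generation_args` options still exist, but they are now expressed through `make_sampler` and `make_logits_processors` (the same helpers the server uses later in this diff). A sketch of the equivalent call under the new API, assuming the same `model`, `tokenizer`, `prompt`, `max_tokens`, and `verbose` variables as this example:

```python
from mlx_lm.sample_utils import make_logits_processors, make_sampler

sampler = make_sampler(0.7, 0.95)  # temperature, top_p
logits_processors = make_logits_processors(
    None, 1.2, 20  # logit_bias, repetition_penalty, repetition_context_size
)

response = generate(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    max_tokens=max_tokens,
    verbose=verbose,
    sampler=sampler,
    logits_processors=logits_processors,
)
```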
@@ -7,6 +7,7 @@ import sys
 import mlx.core as mx
 
 from .models.cache import QuantizedKVCache, load_prompt_cache
+from .sample_utils import make_sampler
 from .utils import generate, load
 
 DEFAULT_PROMPT = "hello"
@@ -97,11 +98,6 @@ def setup_arg_parser():
         default=True,
         help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'",
     )
-    parser.add_argument(
-        "--colorize",
-        action="store_true",
-        help="Colorize output based on T[0] probability",
-    )
     parser.add_argument(
         "--max-kv-size",
         type=int,
@@ -137,33 +133,6 @@ def setup_arg_parser():
     return parser
 
 
-def colorprint(color, s):
-    color_codes = {
-        "black": 30,
-        "red": 31,
-        "green": 32,
-        "yellow": 33,
-        "blue": 34,
-        "magenta": 35,
-        "cyan": 36,
-        "white": 39,
-    }
-    ccode = color_codes.get(color, 30)
-    print(f"\033[1m\033[{ccode}m{s}\033[0m", end="", flush=True)
-
-
-def colorprint_by_t0(s, t0):
-    if t0 > 0.95:
-        color = "white"
-    elif t0 > 0.70:
-        color = "green"
-    elif t0 > 0.30:
-        color = "yellow"
-    else:
-        color = "red"
-    colorprint(color, s)
-
-
 def main():
     parser = setup_arg_parser()
     args = parser.parse_args()
@@ -250,21 +219,14 @@ def main():
     else:
         prompt = args.prompt
 
-    if args.colorize and not args.verbose:
-        raise ValueError("Cannot use --colorize with --verbose=False")
-    formatter = colorprint_by_t0 if args.colorize else None
-
+    sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
     response = generate(
         model,
         tokenizer,
         prompt,
-        args.max_tokens,
+        max_tokens=args.max_tokens,
         verbose=args.verbose,
-        formatter=formatter,
-        temp=args.temp,
-        top_p=args.top_p,
-        min_p=args.min_p,
-        min_tokens_to_keep=args.min_tokens_to_keep,
+        sampler=sampler,
         max_kv_size=args.max_kv_size,
         prompt_cache=prompt_cache if using_cache else None,
         kv_bits=args.kv_bits,
@@ -1,5 +1,6 @@
 # Copyright © 2023-2024 Apple Inc.
 
+import math
 from functools import partial
 from typing import Callable, Dict, Optional
 
@@ -80,7 +81,7 @@ def make_logits_processors(
 
 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
 def min_p_sampling(
-    logits: mx.array,
+    logprobs: mx.array,
     min_p: float,
     min_tokens_to_keep: int = 1,
     temperature=1.0,
@@ -93,7 +94,7 @@ def min_p_sampling(
     aggressive given a very high-probability token.
 
     Args:
-        logits: The logits from the model's output.
+        logprobs: A vector of log probabilities.
         min_p (float): Minimum token probability. Typical values are in the
             0.01-0.2 range, comparably selective as setting `top_p` in the
             0.99-0.8 range.
@@ -111,28 +112,27 @@ def min_p_sampling(
     )
     # reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605
 
-    # Softmax probabilities
-    probs = mx.softmax(logits * (1 / temperature), axis=-1)
+    logprobs = logprobs * (1 / temperature)
 
     # Indices sorted in decreasing order
-    sorted_indices = mx.argsort(-logits).squeeze(0)
-    sorted_probs = probs[..., sorted_indices]
+    sorted_indices = mx.argsort(-logprobs).squeeze(0)
+    sorted_logprobs = logprobs[..., sorted_indices]
 
     # Top probability
-    top_probs = probs[..., sorted_indices[0]]
+    top_logprobs = logprobs[..., sorted_indices[0]]
 
     # Calculate the min_p threshold
-    scaled_min_p = min_p * top_probs
+    scaled_min_p = top_logprobs + math.log(min_p)
 
     # Mask tokens that have a probability less than the scaled min_p
-    tokens_to_remove = sorted_probs < scaled_min_p
+    tokens_to_remove = sorted_logprobs < scaled_min_p
     tokens_to_remove[..., :min_tokens_to_keep] = False
 
     # Create pool of tokens with probability less than scaled min_p
-    selected_probs = mx.where(tokens_to_remove, 0, sorted_probs)
+    selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)
 
     # Return sampled token
-    sorted_token = mx.random.categorical(mx.log(selected_probs))
+    sorted_token = mx.random.categorical(selected_logprobs)
     return sorted_indices[sorted_token]
 
 
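Note: this hunk moves min-p filtering from probability space to log space. Because the logarithm is monotonic, `p < min_p * p_top` is equivalent to `log p < log p_top + log(min_p)`, which removes the softmax and the extra `mx.log` before `mx.random.categorical`. A small check of that equivalence (not the actual sampler, which also sorts and honors `min_tokens_to_keep`):

```python
import math
import mlx.core as mx

min_p = 0.1
probs = mx.array([0.6, 0.25, 0.1, 0.05])
logprobs = mx.log(probs)

old_mask = probs < min_p * probs.max()                  # probability-space threshold
new_mask = logprobs < logprobs.max() + math.log(min_p)  # log-space threshold
print(old_mask.tolist())  # [False, False, False, True]
print(new_mask.tolist())  # same tokens are filtered either way
```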
@@ -27,6 +27,7 @@ from huggingface_hub import scan_cache_dir
 
 from ._version import __version__
 from .models.cache import make_prompt_cache
+from .sample_utils import make_logits_processors, make_sampler
 from .utils import load, stream_generate
 
 
@@ -464,25 +465,24 @@ class APIHandler(BaseHTTPRequestHandler):
 
         text = ""
         tic = time.perf_counter()
-        for n, (segment, token, logprobs) in enumerate(
-            stream_generate(
-                model=self.model,
-                tokenizer=self.tokenizer,
-                prompt=prompt,
-                max_tokens=self.max_tokens,
-                temp=self.temperature,
-                repetition_penalty=self.repetition_penalty,
-                repetition_context_size=self.repetition_context_size,
-                logit_bias=self.logit_bias,
-                prompt_cache=self.prompt_cache.cache,
-            ),
+        sampler = make_sampler(self.temperature)
+        logits_processors = make_logits_processors(
+            self.logit_bias, self.repetition_penalty, self.repetition_context_size
+        )
+        for gen_response in stream_generate(
+            model=self.model,
+            tokenizer=self.tokenizer,
+            prompt=prompt,
+            max_tokens=self.max_tokens,
+            sampler=sampler,
+            logits_processors=logits_processors,
+            prompt_cache=self.prompt_cache.cache,
         ):
-            if n == 0:
-                prompt_time = time.perf_counter() - tic
-                tic = time.perf_counter()
-
+            segment = gen_response.text
             text += segment
             logging.debug(text)
+            token = gen_response.token
+            logprobs = gen_response.logprobs
             tokens.append(token)
 
             if self.logprobs > 0:
@@ -523,13 +523,9 @@ class APIHandler(BaseHTTPRequestHandler):
 
         self.prompt_cache.tokens.extend(tokens)
 
-        gen_time = time.perf_counter() - tic
-        prompt_tps = len(prompt) / prompt_time
-        gen_tps = len(tokens) / gen_time
-        peak_mem = mx.metal.get_peak_memory() / 1e9
-        logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
-        logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec")
-        logging.debug(f"Peak memory: {peak_mem:.3f} GB")
+        logging.debug(f"Prompt: {gen_response.prompt_tps:.3f} tokens-per-sec")
+        logging.debug(f"Generation: {gen_response.generation_tps:.3f} tokens-per-sec")
+        logging.debug(f"Peak memory: {gen_response.peak_memory:.3f} GB")
 
         if self.stream:
             response = self.generate_response(segment, finish_reason)
@@ -73,16 +73,16 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
 
     def reset(self):
         self.offset = 0
-        self._tokens = []
+        self.tokens = []
         self._text = ""
         self._current_tokens = []
         self._current_text = ""
 
     def add_token(self, token):
         self._current_tokens.append(token)
+        self.tokens.append(token)
 
     def finalize(self):
-        self._tokens.extend(self._current_tokens)
         self._text += self._tokenizer.decode(self._current_tokens)
         self._current_tokens = []
         self._current_text = ""
@@ -97,16 +97,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
             ):
                 self._current_text = self._current_text[:-1]
             if self._current_text and self._current_text[-1] == "\n":
-                self._tokens.extend(self._current_tokens)
                 self._text += self._current_text
                 self._current_tokens.clear()
                 self._current_text = ""
         return self._text + self._current_text
 
-    @property
-    def tokens(self):
-        return self._tokens
-
 
 class SPMStreamingDetokenizer(StreamingDetokenizer):
     """A streaming detokenizer for SPM models.
@@ -143,6 +138,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
         self.text += text
 
     def add_token(self, token):
+        self.tokens.append(token)
         v = self.tokenmap[token]
         if v.startswith(self._sep):
             self._flush()
@@ -200,6 +196,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
         return current_text
 
     def add_token(self, token):
+        self.tokens.append(token)
         v = self.tokenmap[token]
         is_added = token in self._added_ids
         if is_added or self._byte_decoder[v[0]] == 32:
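Note: `tokens` is now a plain attribute that every streaming detokenizer keeps up to date in `add_token`, instead of a property backed by `_tokens` on the naive detokenizer only. A sketch of the resulting usage, assuming a tokenizer loaded through `mlx_lm.load`:

```python
detokenizer = tokenizer.detokenizer
detokenizer.reset()
for t in tokenizer.encode("hello world"):
    detokenizer.add_token(t)
    # The raw token ids seen so far are always available now.
    print(detokenizer.tokens, repr(detokenizer.last_segment))
detokenizer.finalize()
print(detokenizer.text)
```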
@@ -8,6 +8,7 @@ import json
 import logging
 import shutil
 import time
+from dataclasses import dataclass
 from pathlib import Path
 from textwrap import dedent
 from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
@@ -44,6 +45,32 @@ class ModelNotFoundError(Exception):
         super().__init__(self.message)
 
 
+@dataclass
+class GenerationResponse:
+    """
+    The output of :func:`stream_generate`.
+
+    Args:
+        text (str): The next segment of decoded text. This can be an empty string.
+        token (int): The next token.
+        logprobs (mx.array): A vector of log probabilities.
+        prompt_tokens (int): The number of tokens in the prompt.
+        prompt_tps (float): The prompt processing tokens-per-second.
+        generation_tokens (int): The number of generated tokens.
+        generation_tps (float): The tokens-per-second for generation.
+        peak_memory (float): The peak memory used so far in GB.
+    """
+
+    text: str
+    token: int
+    logprobs: mx.array
+    prompt_tokens: int
+    prompt_tps: float
+    generation_tokens: int
+    generation_tps: float
+    peak_memory: float
+
+
 @contextlib.contextmanager
 def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
     """
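Note: these fields make per-token probabilities easy to recover from the stream, which is what the removed `--colorize` path used to compute internally. A sketch, assuming `model`, `tokenizer`, and `prompt` are already loaded:

```python
import mlx.core as mx
from mlx_lm import stream_generate

for response in stream_generate(model, tokenizer, prompt, max_tokens=64):
    # Probability of the token that was actually sampled at this step.
    prob = mx.exp(response.logprobs[response.token]).item()
    print(f"{response.text!r} p={prob:.2f}")
```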
@@ -155,20 +182,21 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
 def generate_step(
     prompt: mx.array,
     model: nn.Module,
-    temp: float = 0.0,
-    repetition_penalty: Optional[float] = None,
-    repetition_context_size: Optional[int] = 20,
-    top_p: float = 1.0,
-    min_p: float = 0.0,
-    min_tokens_to_keep: int = 1,
-    prefill_step_size: int = 512,
+    *,
+    sampler: Optional[Callable[mx.array, mx.array]] = None,
+    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
     max_kv_size: Optional[int] = None,
     prompt_cache: Optional[Any] = None,
-    logit_bias: Optional[Dict[int, float]] = None,
-    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
+    prefill_step_size: int = 512,
     kv_bits: Optional[int] = None,
     kv_group_size: int = 64,
     quantized_kv_start: int = 0,
+    temp: Optional[float] = None,
+    repetition_penalty: Optional[float] = None,
+    repetition_context_size: Optional[int] = None,
+    top_p: Optional[float] = None,
+    min_p: Optional[float] = None,
+    min_tokens_to_keep: Optional[int] = None,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.
@@ -176,32 +204,21 @@ def generate_step(
     Args:
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
-        temp (float): The temperature for sampling, if 0 the argmax is used.
-            Default: ``0``.
-        repetition_penalty (float, optional): The penalty factor for repeating
-            tokens.
-        repetition_context_size (int, optional): The number of tokens to
-            consider for repetition penalty. Default: ``20``.
-        top_p (float, optional): Nulceus sampling, higher means model considers
-            more less likely words.
-        min_p (float, optional): The minimum value (scaled by the top token's
-            probability) that a token probability must have to be considered.
-        min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
-            be filtered by min_p sampling.
         prefill_step_size (int): Step size for processing the prompt.
         max_kv_size (int, optional): Maximum size of the key-value cache. Old
             entries (except the first 4 tokens) will be overwritten.
         prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
             provided, the cache will be updated in place.
-        logit_bias (dictionary, optional): Additive logit bias.
+        sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
+            token from a vector of log probabilities. Default: ``None``.
         logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
             A list of functions that take tokens and logits and return the processed
             logits. Default: ``None``.
         kv_bits (int, optional): Number of bits to use for KV cache quantization.
             None implies no cache quantization. Default: ``None``.
         kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
         quantized_kv_start (int): Step to begin using a quantized KV cache.
            when ``kv_bits`` is non-None. Default: ``0``.
 
     Yields:
         Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
@@ -219,10 +236,22 @@ def generate_step(
     elif len(prompt_cache) != len(model.layers):
         raise ValueError("Wrong number of layers in the prompt cache.")
 
-    sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
-    logits_processors = logits_processors or []
-    logits_processors.extend(
-        make_logits_processors(logit_bias, repetition_penalty, repetition_context_size)
+    if temp is not None or top_p is not None or min_tokens_to_keep is not None:
+        print(
+            "[Warning] Specifying sampling arguments to ``generate_step`` is "
+            "deprecated. Pass in a ``sampler`` instead."
+        )
+    if repetition_penalty is not None:
+        print(
+            "[Warning] Specifying ``repetition_penalty`` is deprecated. "
+            "Pass in ``logits_processors`` instead."
+        )
+
+    sampler = sampler or make_sampler(
+        temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
+    )
+    logits_processors = logits_processors or make_logits_processors(
+        None, repetition_penalty, repetition_context_size or 20
     )
 
     def _step(y):
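Note: with the block above, loose sampling arguments to `generate_step` only trigger a deprecation warning; the supported path is to pass pre-built callables. A sketch of the new-style call, assuming `model`, `tokenizer`, `prompt`, and `max_tokens` exist:

```python
import mlx.core as mx
from mlx_lm.sample_utils import make_logits_processors, make_sampler
from mlx_lm.utils import generate_step

prompt_ids = mx.array(tokenizer.encode(prompt))
sampler = make_sampler(0.8, 0.95)                   # temperature, top_p
processors = make_logits_processors(None, 1.1, 20)  # logit_bias, penalty, context size

for n, (token, logprobs) in zip(
    range(max_tokens),
    generate_step(prompt_ids, model, sampler=sampler, logits_processors=processors),
):
    if token == tokenizer.eos_token_id:
        break
```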
@@ -290,17 +319,20 @@ def stream_generate(
     if not isinstance(tokenizer, TokenizerWrapper):
         tokenizer = TokenizerWrapper(tokenizer)
 
-    prompt_tokens = mx.array(
-        prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
-    )
+    prompt = mx.array(prompt if isinstance(prompt, list) else tokenizer.encode(prompt))
     detokenizer = tokenizer.detokenizer
 
     with wired_limit(model, [generation_stream]):
         detokenizer.reset()
-        for n, (token, logits) in zip(
+        tic = time.perf_counter()
+        for n, (token, logprobs) in zip(
             range(max_tokens),
-            generate_step(prompt_tokens, model, **kwargs),
+            generate_step(prompt, model, **kwargs),
         ):
+            if n == 0:
+                prompt_time = time.perf_counter() - tic
+                prompt_tps = prompt.size / prompt_time
+                tic = time.perf_counter()
             if token == tokenizer.eos_token_id:
                 break
 
@@ -309,17 +341,34 @@ def stream_generate(
             if n == (max_tokens - 1):
                 break
 
-            yield detokenizer.last_segment, token, logits
+            yield GenerationResponse(
+                text=detokenizer.last_segment,
+                token=token,
+                logprobs=logprobs,
+                prompt_tokens=prompt.size,
+                prompt_tps=prompt_tps,
+                generation_tokens=n + 1,
+                generation_tps=(n + 1) / (time.perf_counter() - tic),
+                peak_memory=mx.metal.get_peak_memory() / 1e9,
+            )
 
         detokenizer.finalize()
-        yield detokenizer.last_segment, token, logits
+        yield GenerationResponse(
+            text=detokenizer.last_segment,
+            token=token,
+            logprobs=logprobs,
+            prompt_tokens=prompt.size,
+            prompt_tps=prompt_tps,
+            generation_tokens=n + 1,
+            generation_tps=(n + 1) / (time.perf_counter() - tic),
+            peak_memory=mx.metal.get_peak_memory() / 1e9,
+        )
 
 
 def generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
     prompt: str,
-    max_tokens: int = 100,
     verbose: bool = False,
     formatter: Optional[Callable] = None,
     **kwargs,
@@ -334,64 +383,40 @@ def generate(
         max_tokens (int): The maximum number of tokens. Default: ``100``.
         verbose (bool): If ``True``, print tokens and timing information.
             Default: ``False``.
-        formatter (Optional[Callable]): A function which takes a token and a
-            probability and displays it.
-        kwargs: The remaining options get passed to :func:`generate_step`.
-            See :func:`generate_step` for more details.
+        kwargs: The remaining options get passed to :func:`stream_generate`.
+            See :func:`stream_generate` for more details.
     """
-    if not isinstance(tokenizer, TokenizerWrapper):
-        tokenizer = TokenizerWrapper(tokenizer)
+    if formatter is not None:
+        print(
+            "[Warning] Text formatting is deprecated and no longer used. "
+            "The argument will be removed in a future version."
+        )
     if verbose:
         print("=" * 10)
         print("Prompt:", prompt)
 
-    prompt_tokens = mx.array(tokenizer.encode(prompt))
-    detokenizer = tokenizer.detokenizer
-
-    with wired_limit(model, [generation_stream]):
-        tic = time.perf_counter()
-        detokenizer.reset()
-        for n, (token, logprobs) in zip(
-            range(max_tokens),
-            generate_step(prompt_tokens, model, **kwargs),
-        ):
-            if n == 0:
-                prompt_time = time.perf_counter() - tic
-                tic = time.perf_counter()
-            if token == tokenizer.eos_token_id:
-                break
-            detokenizer.add_token(token)
-
-            if verbose:
-                if formatter:
-                    # We have to finalize so that the prob corresponds to the last segment
-                    detokenizer.finalize()
-                    prob = mx.exp(logprobs[token]).item()
-                    formatter(detokenizer.last_segment, prob)
-                else:
-                    print(detokenizer.last_segment, end="", flush=True)
-
-        token_count = n + 1
-        detokenizer.finalize()
-
+    text = ""
+    for response in stream_generate(model, tokenizer, prompt, **kwargs):
         if verbose:
-            gen_time = time.perf_counter() - tic
-            print(detokenizer.last_segment, flush=True)
-            print("=" * 10)
-            if token_count == 0:
-                print("No tokens generated for this prompt")
-                return
-            prompt_tps = prompt_tokens.size / prompt_time
-            gen_tps = (token_count - 1) / gen_time
-            print(
-                f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
-            )
-            print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
-            peak_mem = mx.metal.get_peak_memory() / 1e9
-            print(f"Peak memory: {peak_mem:.3f} GB")
+            print(response.text, end="", flush=True)
+        text += response.text
 
-    return detokenizer.text
+    if verbose:
+        print()
+        print("=" * 10)
+        if len(text) == 0:
+            print("No text generated for this prompt")
+            return
+        print(
+            f"Prompt: {response.prompt_tokens} tokens, "
+            f"{response.prompt_tps:.3f} tokens-per-sec"
+        )
+        print(
+            f"Generation: {response.generation_tokens} tokens, "
+            f"{response.generation_tps:.3f} tokens-per-sec"
+        )
+        print(f"Peak memory: {response.peak_memory:.3f} GB")
+    return text
 
 
 def load_config(model_path: Path) -> dict:
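Note: `generate` is now a thin wrapper that accumulates `response.text` from `stream_generate`, so under the default greedy sampling the two entry points should agree. A sketch assuming a loaded `model` and `tokenizer` and a `prompt` string:

```python
text = generate(model, tokenizer, prompt=prompt, max_tokens=100)
streamed = "".join(
    r.text for r in stream_generate(model, tokenizer, prompt, max_tokens=100)
)
# With the default (argmax) sampler both paths are expected to match.
assert text == streamed
```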
@@ -2,6 +2,7 @@
 
 import unittest
 
+from mlx_lm.sample_utils import make_logits_processors
 from mlx_lm.utils import generate, load
 
 
@@ -25,8 +26,8 @@ class TestGenerate(unittest.TestCase):
             self.tokenizer,
             "hello",
             max_tokens=5,
+            logits_processors=make_logits_processors(logit_bias),
             verbose=False,
-            logit_bias=logit_bias,
         )
         self.assertEqual(text, "!!!!!")
 
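Note: `logit_bias` is no longer a `generate` keyword; the test now routes it through `make_logits_processors`, which is the pattern downstream code would follow. A sketch with an illustrative token id and bias, assuming a loaded `model` and `tokenizer`:

```python
from mlx_lm.sample_utils import make_logits_processors

logit_bias = {0: 20.0}  # illustrative: strongly favor token id 0
text = generate(
    model,
    tokenizer,
    "hello",
    max_tokens=5,
    logits_processors=make_logits_processors(logit_bias),
)
```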
@@ -1,10 +1,10 @@
 import unittest
 
 import mlx.core as mx
-from mlx_lm.sample_utils import top_p_sampling
+from mlx_lm.sample_utils import min_p_sampling, top_p_sampling
 
 
-class TestSamplingUtils(unittest.TestCase):
+class TestSampleUtils(unittest.TestCase):
     def test_top_p_sampling(self):
         probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
         logits = mx.log(probs)
@@ -28,6 +28,20 @@ class TestSamplingUtils(unittest.TestCase):
         token = top_p_sampling(logits, 0.95, temperature).item()
         self.assertTrue(token in (1, 2, 3))
 
+    def test_min_p_sampling(self):
+        probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
+        logits = mx.log(probs)
+        temperature = 1.0
+        token = min_p_sampling(logits, 0.8)
+        self.assertEqual(token, 0)
+
+        probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
+        logits = mx.log(probs)
+        temperature = 1.0
+        for _ in range(5):
+            token = min_p_sampling(logits, 0.05)
+            self.assertTrue(token in (0, 3))
+
 
 if __name__ == "__main__":
     unittest.main()
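Note: the expected tokens in `test_min_p_sampling` follow directly from the min-p rule. A quick worked check in plain Python:

```python
probs = [0.9, 0.0, 0.0, 0.1]
top = max(probs)
keep_08 = [p >= 0.8 * top for p in probs]    # threshold 0.72 -> only token 0 survives
keep_005 = [p >= 0.05 * top for p in probs]  # threshold 0.045 -> tokens 0 and 3 survive
print(keep_08, keep_005)
```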
@@ -34,10 +34,11 @@ class TestTokenizers(unittest.TestCase):
             detokenizer = tokenizer.detokenizer
             detokenizer.reset()
             text = ""
-            for t in tokens:
+            for e, t in enumerate(tokens):
                 detokenizer.add_token(t)
                 seg = detokenizer.last_segment
                 text += seg
+                self.assertEqual(detokenizer.tokens, tokens[: e + 1])
             detokenizer.finalize()
             text += detokenizer.last_segment
             self.assertEqual(text, expected_text)