mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
some cleanup, warnings, tests
This commit is contained in:
parent
9986787303
commit
f82e49aad9
@ -100,8 +100,9 @@ To see a description of all the arguments you can do:
|
|||||||
|
|
||||||
#### Streaming
|
#### Streaming
|
||||||
|
|
||||||
For streaming generation, use the `stream_generate` function. This returns a
|
For streaming generation, use the `stream_generate` function. This yields
|
||||||
generator object which streams the output text, token, and log probabilities.
|
a generation response object.
|
||||||
|
|
||||||
For example,
|
For example,
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
@ -6,6 +6,7 @@ import json
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
from .models.cache import make_prompt_cache
|
from .models.cache import make_prompt_cache
|
||||||
|
from .sample_utils import make_sampler
|
||||||
from .utils import load, stream_generate
|
from .utils import load, stream_generate
|
||||||
|
|
||||||
DEFAULT_TEMP = 0.0
|
DEFAULT_TEMP = 0.0
|
||||||
@ -79,8 +80,7 @@ def main():
|
|||||||
tokenizer,
|
tokenizer,
|
||||||
prompt,
|
prompt,
|
||||||
args.max_tokens,
|
args.max_tokens,
|
||||||
temp=args.temp,
|
sampler=make_sampler(args.temp, args.top_p),
|
||||||
top_p=args.top_p,
|
|
||||||
prompt_cache=prompt_cache,
|
prompt_cache=prompt_cache,
|
||||||
):
|
):
|
||||||
print(response.text, flush=True, end="")
|
print(response.text, flush=True, end="")
|
||||||
|
@ -42,7 +42,6 @@ response = generate(
|
|||||||
tokenizer,
|
tokenizer,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
temp=0.0,
|
|
||||||
prompt_cache=prompt_cache,
|
prompt_cache=prompt_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -23,14 +23,6 @@ max_tokens = 1_000
|
|||||||
# Specify if tokens and timing information will be printed
|
# Specify if tokens and timing information will be printed
|
||||||
verbose = True
|
verbose = True
|
||||||
|
|
||||||
# Some optional arguments for causal language model generation
|
|
||||||
generation_args = {
|
|
||||||
"temp": 0.7,
|
|
||||||
"repetition_penalty": 1.2,
|
|
||||||
"repetition_context_size": 20,
|
|
||||||
"top_p": 0.95,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Generate a response with the specified settings
|
# Generate a response with the specified settings
|
||||||
response = generate(
|
response = generate(
|
||||||
model=model,
|
model=model,
|
||||||
@ -38,5 +30,4 @@ response = generate(
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
**generation_args,
|
|
||||||
)
|
)
|
||||||
|
@ -7,6 +7,7 @@ import sys
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
from .models.cache import QuantizedKVCache, load_prompt_cache
|
from .models.cache import QuantizedKVCache, load_prompt_cache
|
||||||
|
from .sample_utils import make_sampler
|
||||||
from .utils import generate, load
|
from .utils import generate, load
|
||||||
|
|
||||||
DEFAULT_PROMPT = "hello"
|
DEFAULT_PROMPT = "hello"
|
||||||
@ -218,16 +219,14 @@ def main():
|
|||||||
else:
|
else:
|
||||||
prompt = args.prompt
|
prompt = args.prompt
|
||||||
|
|
||||||
|
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
|
||||||
response = generate(
|
response = generate(
|
||||||
model,
|
model,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
prompt,
|
prompt,
|
||||||
max_tokens=args.max_tokens,
|
max_tokens=args.max_tokens,
|
||||||
verbose=args.verbose,
|
verbose=args.verbose,
|
||||||
temp=args.temp,
|
sampler=sampler,
|
||||||
top_p=args.top_p,
|
|
||||||
min_p=args.min_p,
|
|
||||||
min_tokens_to_keep=args.min_tokens_to_keep,
|
|
||||||
max_kv_size=args.max_kv_size,
|
max_kv_size=args.max_kv_size,
|
||||||
prompt_cache=prompt_cache if using_cache else None,
|
prompt_cache=prompt_cache if using_cache else None,
|
||||||
kv_bits=args.kv_bits,
|
kv_bits=args.kv_bits,
|
||||||
|
@ -27,6 +27,7 @@ from huggingface_hub import scan_cache_dir
|
|||||||
|
|
||||||
from ._version import __version__
|
from ._version import __version__
|
||||||
from .models.cache import make_prompt_cache
|
from .models.cache import make_prompt_cache
|
||||||
|
from .sample_utils import make_logits_processors, make_sampler
|
||||||
from .utils import load, stream_generate
|
from .utils import load, stream_generate
|
||||||
|
|
||||||
|
|
||||||
@ -464,15 +465,17 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
text = ""
|
text = ""
|
||||||
tic = time.perf_counter()
|
tic = time.perf_counter()
|
||||||
|
sampler = make_sampler(self.temperature)
|
||||||
|
logits_processors = make_logits_processors(
|
||||||
|
self.logit_bias, self.repetition_penalty, self.repetition_context_size
|
||||||
|
)
|
||||||
for gen_response in stream_generate(
|
for gen_response in stream_generate(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
max_tokens=self.max_tokens,
|
max_tokens=self.max_tokens,
|
||||||
temp=self.temperature,
|
sampler=sampler,
|
||||||
repetition_penalty=self.repetition_penalty,
|
logits_processors=logits_processors,
|
||||||
repetition_context_size=self.repetition_context_size,
|
|
||||||
logit_bias=self.logit_bias,
|
|
||||||
prompt_cache=self.prompt_cache.cache,
|
prompt_cache=self.prompt_cache.cache,
|
||||||
):
|
):
|
||||||
segment = gen_response.text
|
segment = gen_response.text
|
||||||
|
@ -73,16 +73,16 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
|
|||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.offset = 0
|
self.offset = 0
|
||||||
self._tokens = []
|
self.tokens = []
|
||||||
self._text = ""
|
self._text = ""
|
||||||
self._current_tokens = []
|
self._current_tokens = []
|
||||||
self._current_text = ""
|
self._current_text = ""
|
||||||
|
|
||||||
def add_token(self, token):
|
def add_token(self, token):
|
||||||
self._current_tokens.append(token)
|
self._current_tokens.append(token)
|
||||||
|
self.tokens.append(token)
|
||||||
|
|
||||||
def finalize(self):
|
def finalize(self):
|
||||||
self._tokens.extend(self._current_tokens)
|
|
||||||
self._text += self._tokenizer.decode(self._current_tokens)
|
self._text += self._tokenizer.decode(self._current_tokens)
|
||||||
self._current_tokens = []
|
self._current_tokens = []
|
||||||
self._current_text = ""
|
self._current_text = ""
|
||||||
@ -97,16 +97,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
|
|||||||
):
|
):
|
||||||
self._current_text = self._current_text[:-1]
|
self._current_text = self._current_text[:-1]
|
||||||
if self._current_text and self._current_text[-1] == "\n":
|
if self._current_text and self._current_text[-1] == "\n":
|
||||||
self._tokens.extend(self._current_tokens)
|
|
||||||
self._text += self._current_text
|
self._text += self._current_text
|
||||||
self._current_tokens.clear()
|
self._current_tokens.clear()
|
||||||
self._current_text = ""
|
self._current_text = ""
|
||||||
return self._text + self._current_text
|
return self._text + self._current_text
|
||||||
|
|
||||||
@property
|
|
||||||
def tokens(self):
|
|
||||||
return self._tokens
|
|
||||||
|
|
||||||
|
|
||||||
class SPMStreamingDetokenizer(StreamingDetokenizer):
|
class SPMStreamingDetokenizer(StreamingDetokenizer):
|
||||||
"""A streaming detokenizer for SPM models.
|
"""A streaming detokenizer for SPM models.
|
||||||
@ -143,6 +138,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
|
|||||||
self.text += text
|
self.text += text
|
||||||
|
|
||||||
def add_token(self, token):
|
def add_token(self, token):
|
||||||
|
self.tokens.append(token)
|
||||||
v = self.tokenmap[token]
|
v = self.tokenmap[token]
|
||||||
if v.startswith(self._sep):
|
if v.startswith(self._sep):
|
||||||
self._flush()
|
self._flush()
|
||||||
@ -200,6 +196,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
|
|||||||
return current_text
|
return current_text
|
||||||
|
|
||||||
def add_token(self, token):
|
def add_token(self, token):
|
||||||
|
self.tokens.append(token)
|
||||||
v = self.tokenmap[token]
|
v = self.tokenmap[token]
|
||||||
is_added = token in self._added_ids
|
is_added = token in self._added_ids
|
||||||
if is_added or self._byte_decoder[v[0]] == 32:
|
if is_added or self._byte_decoder[v[0]] == 32:
|
||||||
|
@ -182,20 +182,21 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
|
|||||||
def generate_step(
|
def generate_step(
|
||||||
prompt: mx.array,
|
prompt: mx.array,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
temp: float = 0.0,
|
*,
|
||||||
repetition_penalty: Optional[float] = None,
|
sampler: Optional[Callable[mx.array, mx.array]] = None,
|
||||||
repetition_context_size: Optional[int] = 20,
|
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
||||||
top_p: float = 1.0,
|
|
||||||
min_p: float = 0.0,
|
|
||||||
min_tokens_to_keep: int = 1,
|
|
||||||
prefill_step_size: int = 512,
|
|
||||||
max_kv_size: Optional[int] = None,
|
max_kv_size: Optional[int] = None,
|
||||||
prompt_cache: Optional[Any] = None,
|
prompt_cache: Optional[Any] = None,
|
||||||
logit_bias: Optional[Dict[int, float]] = None,
|
prefill_step_size: int = 512,
|
||||||
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
|
||||||
kv_bits: Optional[int] = None,
|
kv_bits: Optional[int] = None,
|
||||||
kv_group_size: int = 64,
|
kv_group_size: int = 64,
|
||||||
quantized_kv_start: int = 0,
|
quantized_kv_start: int = 0,
|
||||||
|
temp: Optional[float] = None,
|
||||||
|
repetition_penalty: Optional[float] = None,
|
||||||
|
repetition_context_size: Optional[int] = None,
|
||||||
|
top_p: Optional[float] = None,
|
||||||
|
min_p: Optional[float] = None,
|
||||||
|
min_tokens_to_keep: Optional[int] = None,
|
||||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||||
"""
|
"""
|
||||||
A generator producing token ids based on the given prompt from the model.
|
A generator producing token ids based on the given prompt from the model.
|
||||||
@ -203,32 +204,21 @@ def generate_step(
|
|||||||
Args:
|
Args:
|
||||||
prompt (mx.array): The input prompt.
|
prompt (mx.array): The input prompt.
|
||||||
model (nn.Module): The model to use for generation.
|
model (nn.Module): The model to use for generation.
|
||||||
temp (float): The temperature for sampling, if 0 the argmax is used.
|
|
||||||
Default: ``0``.
|
|
||||||
repetition_penalty (float, optional): The penalty factor for repeating
|
|
||||||
tokens.
|
|
||||||
repetition_context_size (int, optional): The number of tokens to
|
|
||||||
consider for repetition penalty. Default: ``20``.
|
|
||||||
top_p (float, optional): Nulceus sampling, higher means model considers
|
|
||||||
more less likely words.
|
|
||||||
min_p (float, optional): The minimum value (scaled by the top token's
|
|
||||||
probability) that a token probability must have to be considered.
|
|
||||||
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
|
|
||||||
be filtered by min_p sampling.
|
|
||||||
prefill_step_size (int): Step size for processing the prompt.
|
prefill_step_size (int): Step size for processing the prompt.
|
||||||
max_kv_size (int, optional): Maximum size of the key-value cache. Old
|
max_kv_size (int, optional): Maximum size of the key-value cache. Old
|
||||||
entries (except the first 4 tokens) will be overwritten.
|
entries (except the first 4 tokens) will be overwritten.
|
||||||
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
|
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
|
||||||
provided, the cache will be updated in place.
|
provided, the cache will be updated in place.
|
||||||
logit_bias (dictionary, optional): Additive logit bias.
|
sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
|
||||||
|
token from a vector of log probabilities. Default: ``None``.
|
||||||
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
||||||
A list of functions that take tokens and logits and return the processed
|
A list of functions that take tokens and logits and return the processed
|
||||||
logits. Default: ``None``.
|
logits. Default: ``None``.
|
||||||
kv_bits (int, optional): Number of bits to use for KV cache quantization.
|
kv_bits (int, optional): Number of bits to use for KV cache quantization.
|
||||||
None implies no cache quantization. Default: ``None``.
|
None implies no cache quantization. Default: ``None``.
|
||||||
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
|
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
|
||||||
quantized_kv_start (int): Step to begin using a quantized KV cache.
|
quantized_kv_start (int): Step to begin using a quantized KV cache.
|
||||||
when ``kv_bits`` is non-None. Default: ``0``.
|
when ``kv_bits`` is non-None. Default: ``0``.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
|
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
|
||||||
@ -246,10 +236,22 @@ def generate_step(
|
|||||||
elif len(prompt_cache) != len(model.layers):
|
elif len(prompt_cache) != len(model.layers):
|
||||||
raise ValueError("Wrong number of layers in the prompt cache.")
|
raise ValueError("Wrong number of layers in the prompt cache.")
|
||||||
|
|
||||||
sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
|
if temp is not None or top_p is not None or min_tokens_to_keep is not None:
|
||||||
logits_processors = logits_processors or []
|
print(
|
||||||
logits_processors.extend(
|
"[Warning] Specifying sampling arguments to ``generate_step`` is "
|
||||||
make_logits_processors(logit_bias, repetition_penalty, repetition_context_size)
|
"deprecated. Pass in a ``sampler`` instead."
|
||||||
|
)
|
||||||
|
if repetition_penalty is not None:
|
||||||
|
print(
|
||||||
|
"[Warning] Specifying ``repetition_penalty`` is deprecated. "
|
||||||
|
"Pass in ``logits_processors`` instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
sampler = sampler or make_sampler(
|
||||||
|
temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
|
||||||
|
)
|
||||||
|
logits_processors = logits_processors or make_logits_processors(
|
||||||
|
None, repetition_penalty, repetition_context_size or 20
|
||||||
)
|
)
|
||||||
|
|
||||||
def _step(y):
|
def _step(y):
|
||||||
@ -385,7 +387,10 @@ def generate(
|
|||||||
See :func:`stream_generate` for more details.
|
See :func:`stream_generate` for more details.
|
||||||
"""
|
"""
|
||||||
if formatter is not None:
|
if formatter is not None:
|
||||||
print("Text formatting is deprecated and will be removed in the next version.")
|
print(
|
||||||
|
"[Warning] Text formatting is deprecated and no longer used. "
|
||||||
|
"The argument will be removed in a future version."
|
||||||
|
)
|
||||||
if verbose:
|
if verbose:
|
||||||
print("=" * 10)
|
print("=" * 10)
|
||||||
print("Prompt:", prompt)
|
print("Prompt:", prompt)
|
||||||
|
@ -34,10 +34,11 @@ class TestTokenizers(unittest.TestCase):
|
|||||||
detokenizer = tokenizer.detokenizer
|
detokenizer = tokenizer.detokenizer
|
||||||
detokenizer.reset()
|
detokenizer.reset()
|
||||||
text = ""
|
text = ""
|
||||||
for t in tokens:
|
for e, t in enumerate(tokens):
|
||||||
detokenizer.add_token(t)
|
detokenizer.add_token(t)
|
||||||
seg = detokenizer.last_segment
|
seg = detokenizer.last_segment
|
||||||
text += seg
|
text += seg
|
||||||
|
self.assertEqual(detokenizer.tokens, tokens[: e + 1])
|
||||||
detokenizer.finalize()
|
detokenizer.finalize()
|
||||||
text += detokenizer.last_segment
|
text += detokenizer.last_segment
|
||||||
self.assertEqual(text, expected_text)
|
self.assertEqual(text, expected_text)
|
||||||
|
Loading…
Reference in New Issue
Block a user